324 lines
12 KiB
Rust

use std::pin::pin;
use futures::future::Future;
use log::{trace, info, warn};
use async_std::sync::{Arc, Mutex};
use async_std::io::prelude::{ReadExt, WriteExt};
use std::collections::HashMap;
use std::pin::Pin;
use libpso::crypto::{PSOCipher, NullCipher, CipherError};
use libpso::PacketParseError;
use crate::common::serverstate::ClientId;
use crate::common::serverstate::{RecvServerPacket, SendServerPacket, ServerState, OnConnect};
/// Failures that can surface while exchanging packets with a client.
#[derive(Debug)]
pub enum NetworkError {
    /// Not constructed in the visible part of this file; presumably reports
    /// that a packet could not be delivered (TODO confirm against callers).
    CouldNotSend,
    /// A cipher operation failed; carries the underlying [`CipherError`].
    CipherError(CipherError),
    /// Packet bytes could not be parsed; carries the underlying
    /// [`PacketParseError`].
    PacketParseError(PacketParseError),
    /// The underlying socket operation failed.
    IOError(std::io::Error),
    /// Not constructed in the visible part of this file; presumably means no
    /// data was available yet (TODO confirm against callers).
    DataNotReady,
    /// The peer closed the connection (a socket read returned zero bytes).
    ClientDisconnected,
}
impl From<CipherError> for NetworkError {
fn from(err: CipherError) -> NetworkError {
NetworkError::CipherError(err)
}
}
impl From<std::io::Error> for NetworkError {
fn from(err: std::io::Error) -> NetworkError {
NetworkError::IOError(err)
}
}
impl From<PacketParseError> for NetworkError {
fn from(err: PacketParseError) -> NetworkError {
NetworkError::PacketParseError(err)
}
}
/// Reads encrypted bytes from a client socket, decrypts them with `C`, and
/// splits the resulting byte stream into packets.
pub struct PacketReceiver<C: PSOCipher> {
    /// Connection the encrypted bytes are read from.
    socket: async_std::net::TcpStream,
    /// Cipher used to decrypt incoming bytes; its block size also determines
    /// how packets are padded.
    cipher: C,
    /// Bytes read from the socket that are still encrypted because they do
    /// not yet fill a whole cipher block.
    recv_buffer: Vec<u8>,
    /// Decrypted bytes that have not been split into packets yet.
    incoming_data: Vec<u8>,
}
impl<C: PSOCipher> PacketReceiver<C> {
pub fn new(socket: async_std::net::TcpStream, cipher: C) -> PacketReceiver<C> {
PacketReceiver {
socket,
cipher,
recv_buffer: Vec::new(),
incoming_data: Vec::new(),
}
}
async fn fill_recv_buffer(&mut self) -> Result<(), NetworkError> {
let mut data = [0u8; 0x8000];
let mut socket = self.socket.clone();
let len = socket.read(&mut data).await?;
if len == 0 {
return Err(NetworkError::ClientDisconnected);
}
self.recv_buffer.extend_from_slice(&data[..len]);
let mut dec_buf = {
//let mut cipher = self.cipher.lock().await;
let block_chunk_len = self.recv_buffer.len() / self.cipher.block_size() * self.cipher.block_size();
let buf = self.recv_buffer.drain(..block_chunk_len).collect();
self.cipher.decrypt(&buf)?
};
self.incoming_data.append(&mut dec_buf);
Ok(())
}
pub async fn recv_pkts<R: RecvServerPacket + std::fmt::Debug>(&mut self) -> Result<Vec<R>, NetworkError> {
self.fill_recv_buffer().await?;
let mut result = Vec::new();
loop {
if self.incoming_data.len() < 2 {
break;
}
let pkt_size = u16::from_le_bytes([self.incoming_data[0], self.incoming_data[1]]) as usize;
let mut pkt_len = pkt_size;
while pkt_len % self.cipher.block_size() != 0 {
pkt_len += 1;
}
if pkt_len > self.incoming_data.len() {
break;
}
let pkt_data = self.incoming_data.drain(..pkt_len).collect::<Vec<_>>();
trace!("[recv buf] {:?}", pkt_data);
let pkt = match R::from_bytes(&pkt_data[..pkt_size]) {
Ok(p) => p,
Err(err) => {
warn!("error RecvServerPacket::from_bytes: {:?}", err);
continue
},
};
result.push(pkt);
}
Ok(result)
}
}
/*
async fn send_pkt<S: SendServerPacket + Send + std::fmt::Debug>(socket: Arc<async_std::net::TcpStream>,
cipher: Arc<Mutex<Box<dyn PSOCipher + Send>>>, pkt: S)
-> Result<(), NetworkError>
{
let buf = pkt.as_bytes();
trace!("[send buf] {:?}", buf);
let cbuf = cipher.lock().await.encrypt(&buf)?;
let mut ssock = &*socket;
ssock.write_all(&cbuf).await?;
Ok(())
}
enum ClientAction<S, R> {
NewClient(ClientId, async_std::channel::Sender<S>),
Packet(ClientId, R),
Disconnect(ClientId),
}
enum ServerStateAction<S> {
Cipher(Box<dyn PSOCipher + Send + Sync>, Box<dyn PSOCipher + Send + Sync>),
Packet(S),
Disconnect,
}
fn client_recv_loop<S, R>(client_id: ClientId,
socket: Arc<async_std::net::TcpStream>,
cipher: Arc<Mutex<Box<dyn PSOCipher + Send>>>,
server_sender: async_std::channel::Sender<ClientAction<ServerStateAction<S>, R>>,
client_sender: async_std::channel::Sender<ServerStateAction<S>>)
where
S: SendServerPacket + std::fmt::Debug + Send + 'static,
R: RecvServerPacket + std::fmt::Debug + Send + 'static,
{
async_std::task::spawn(async move {
server_sender.send(ClientAction::NewClient(client_id, client_sender)).await.unwrap();
/*
let mut pkt_receiver = PacketReceiver::new(*socket, cipher);
loop {
match pkt_receiver.recv_pkts().await {
Ok(pkts) => {
for pkt in pkts {
info!("[recv from {:?}] {:#?}", client_id, pkt);
server_sender.send(ClientAction::Packet(client_id, pkt)).await.unwrap();
}
},
Err(err) => {
match err {
NetworkError::ClientDisconnected => {
trace!("[client disconnected] {:?}", client_id);
server_sender.send(ClientAction::Disconnect(client_id)).await.unwrap();
break;
}
_ => {
warn!("[client {:?} recv error] {:?}", client_id, err);
}
}
}
}
}
*/
});
}
fn client_send_loop<S>(client_id: ClientId,
socket: Arc<async_std::net::TcpStream>,
cipher_in: Arc<Mutex<Box<dyn PSOCipher + Send>>>,
cipher_out: Arc<Mutex<Box<dyn PSOCipher + Send>>>,
client_receiver: async_std::channel::Receiver<ServerStateAction<S>>)
where
S: SendServerPacket + std::fmt::Debug + Send + 'static,
{
async_std::task::spawn(async move {
loop {
let action = client_receiver.recv().await.unwrap();
match action {
ServerStateAction::Cipher(inc, outc) => {
*cipher_in.lock().await = inc;
*cipher_out.lock().await = outc;
}
ServerStateAction::Packet(pkt) => {
info!("[send to {:?}] {:#?}", client_id, pkt);
if let Err(err) = send_pkt(socket.clone(), cipher_out.clone(), pkt).await {
warn!("[client {:?} send error ] {:?}", client_id, err);
}
},
ServerStateAction::Disconnect => {
break;
}
};
}
});
}
fn state_client_loop<STATE, S, R, E>(state: Arc<Mutex<STATE>>,
server_state_receiver: async_std::channel::Receiver<ClientAction<ServerStateAction<S>, R>>) where
STATE: ServerState<SendPacket=S, RecvPacket=R, PacketError=E> + Send + 'static,
S: SendServerPacket + std::fmt::Debug + Send + 'static,
R: RecvServerPacket + std::fmt::Debug + Send + 'static,
E: std::fmt::Debug + Send,
{
async_std::task::spawn(async move {
let mut clients = HashMap::new();
loop {
let action = server_state_receiver.recv().await.unwrap();
let mut state = state.lock().await;
match action {
ClientAction::NewClient(client_id, sender) => {
let actions = state.on_connect(client_id).await;
match actions {
Ok(actions) => {
for action in actions {
match action {
OnConnect::Cipher((inc, outc)) => {
sender.send(ServerStateAction::Cipher(inc, outc)).await.unwrap();
},
OnConnect::Packet(pkt) => {
sender.send(ServerStateAction::Packet(pkt)).await.unwrap();
}
}
}
},
Err(err) => {
warn!("[client {:?} state on_connect error] {:?}", client_id, err);
}
}
clients.insert(client_id, sender);
},
ClientAction::Packet(client_id, pkt) => {
let pkts = state.handle(client_id, &pkt).await;
match pkts {
Ok(pkts) => {
for (client_id, pkt) in pkts {
if let Some(client) = clients.get_mut(&client_id) {
client.send(ServerStateAction::Packet(pkt)).await.unwrap();
}
}
},
Err(err) => {
warn!("[client {:?} state handler error] {:?}", client_id, err);
}
}
},
ClientAction::Disconnect(client_id) => {
let pkts = state.on_disconnect(client_id).await;
match pkts {
Ok(pkts) => {
for (client_id, pkt) in pkts {
if let Some(client) = clients.get_mut(&client_id) {
client.send(ServerStateAction::Packet(pkt)).await.unwrap();
}
}
if let Some(client) = clients.get_mut(&client_id) {
client.send(ServerStateAction::Disconnect).await.unwrap();
}
}
Err(err) => {
warn!("[client {:?} state on_disconnect error] {:?}", client_id, err);
}
}
}
}
}
});
}
pub fn client_accept_mainloop<STATE, S, R, E>(state: Arc<Mutex<STATE>>, client_port: u16) -> Pin<Box<dyn Future<Output = ()>>>
where
STATE: ServerState<SendPacket=S, RecvPacket=R, PacketError=E> + Send + 'static,
S: SendServerPacket + std::fmt::Debug + Send + Sync + 'static,
R: RecvServerPacket + std::fmt::Debug + Send + Sync + 'static,
E: std::fmt::Debug + Send,
{
Box::pin(async_std::task::spawn(async move {
let listener = async_std::net::TcpListener::bind(&std::net::SocketAddr::from((std::net::Ipv4Addr::new(0,0,0,0), client_port))).await.unwrap();
let mut id = 0;
let (server_state_sender, server_state_receiver) = async_std::channel::bounded(1024);
state_client_loop(state, server_state_receiver);
loop {
let (sock, addr) = listener.accept().await.unwrap();
id += 1;
let client_id = crate::common::serverstate::ClientId(id);
info!("new client {:?} {:?} {:?}", client_id, sock, addr);
let (client_sender, client_receiver) = async_std::channel::bounded(64);
let socket = Arc::new(sock);
let cipher_in: Arc<Mutex<Box<dyn PSOCipher + Send>>> = Arc::new(Mutex::new(Box::new(NullCipher {})));
let cipher_out: Arc<Mutex<Box<dyn PSOCipher + Send>>> = Arc::new(Mutex::new(Box::new(NullCipher {})));
client_recv_loop(client_id, socket.clone(), cipher_in.clone(), server_state_sender.clone(), client_sender);
client_send_loop(client_id, socket.clone(), cipher_in.clone(), cipher_out.clone(), client_receiver);
}
}))
}
*/