You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

227 lines
7.9 KiB

4 years ago
4 years ago
4 years ago
4 years ago
  1. use std::time::Duration;
  2. use std::pin::Pin;
  3. use futures::future::Future;
  4. use log::{info, warn};
  5. use async_std::sync::{Arc, RwLock};
  6. use async_std::io::prelude::{ReadExt, WriteExt};
  7. use std::collections::HashMap;
  8. use serde::Serialize;
  9. use serde::de::DeserializeOwned;
  10. use crate::interserver::{ServerId, InterserverActor};
  11. use libpso::crypto::{PSOCipher, NullCipher, CipherError};
  12. use crate::serverstate::{ServerState, SendServerPacket, RecvServerPacket};
  13. use entity::gateway::entitygateway::EntityGateway;
  14. use async_std::channel;
  15. use std::fmt::Debug;
  /// Failure modes of [`MessageReceiver::recv`].
  #[derive(Debug)]
  enum MessageReceiverError {
      /// A complete frame was read, but its body was not valid UTF-8 or did not deserialize
      /// into the requested message type. The frame has been fully consumed.
      InvalidPayload,
      /// Reading from the socket failed. Any I/O error, including EOF, is reported this way.
      Disconnected,
  }
  23. struct MessageReceiver {
  24. socket: async_std::net::TcpStream,
  25. }
  26. impl MessageReceiver {
  27. fn new(socket: async_std::net::TcpStream) -> MessageReceiver {
  28. MessageReceiver {
  29. socket,
  30. }
  31. }
  32. async fn recv<R: serde::de::DeserializeOwned + std::fmt::Debug>(&mut self) -> Result<R, MessageReceiverError> {
  33. let mut size_buf = [0u8; 4];
  34. self.socket.read_exact(&mut size_buf).await.map_err(|_| MessageReceiverError::Disconnected)?;
  35. let size = u32::from_le_bytes(size_buf) as usize;
  36. let mut payload = vec![0u8; size];
  37. self.socket.read_exact(&mut payload).await.map_err(|_| MessageReceiverError::Disconnected)?;
  38. let payload = String::from_utf8(payload).map_err(|_| MessageReceiverError::InvalidPayload)?;
  39. let msg = serde_json::from_str(&payload).map_err(|_| MessageReceiverError::InvalidPayload)?;
  40. Ok(msg)
  41. }
  42. }
  43. async fn interserver_recv_loop<STATE, S, R, E>(mut state: STATE, server_id: ServerId, socket: async_std::net::TcpStream, ships: Arc<RwLock<HashMap<ServerId, channel::Sender<S>>>>)
  44. where
  45. STATE: InterserverActor<SendMessage=S, RecvMessage=R, Error=E> + Send,
  46. S: serde::Serialize + Debug + Send,
  47. R: serde::de::DeserializeOwned + Debug + Send,
  48. E: Debug + Send,
  49. {
  50. let mut msg_receiver = MessageReceiver::new(socket);
  51. loop {
  52. match msg_receiver.recv::<R>().await {
  53. Ok(msg) => {
  54. info!("[interserver recv {:?}] {:?}", server_id, msg);
  55. match state.on_action(server_id, msg).await {
  56. Ok(response) => {
  57. for resp in response {
  58. ships
  59. .read()
  60. .await
  61. .get(&resp.0)
  62. .unwrap()
  63. .send(resp.1)
  64. .await
  65. .unwrap();
  66. }
  67. },
  68. Err(err) => {
  69. warn!("[interserver recv {:?}] error {:?}", server_id, err);
  70. }
  71. }
  72. },
  73. Err(err) => {
  74. if let MessageReceiverError::Disconnected = err {
  75. info!("[interserver recv {:?}] disconnected", server_id);
  76. for (_, _sender) in ships.read().await.iter() {
  77. for pkt in state.on_disconnect(server_id).await {
  78. ships
  79. .read()
  80. .await
  81. .get(&pkt.0)
  82. .unwrap()
  83. .send(pkt.1)
  84. .await
  85. .unwrap();
  86. }
  87. }
  88. ships
  89. .write()
  90. .await
  91. .remove(&server_id);
  92. break;
  93. }
  94. info!("[interserver recv {:?}] error {:?}", server_id, err);
  95. }
  96. }
  97. }
  98. }
  99. async fn interserver_send_loop<S>(server_id: ServerId, mut socket: async_std::net::TcpStream, to_send: channel::Receiver<S>)
  100. where
  101. S: serde::Serialize + std::fmt::Debug,
  102. {
  103. loop {
  104. let msg = to_send.recv().await.unwrap();
  105. let payload = serde_json::to_string(&msg);
  106. if let Ok(payload) = payload {
  107. let len_bytes = u32::to_le_bytes(payload.len() as u32);
  108. if let Err(err) = socket.write_all(&len_bytes).await {
  109. warn!("[interserver send {:?}] failed: {:?}", server_id, err);
  110. break;
  111. }
  112. if let Err(err) = socket.write_all(payload.as_bytes()).await {
  113. warn!("[interserver send {:?}] failed: {:?}", server_id, err);
  114. break;
  115. }
  116. }
  117. }
  118. }
  119. pub async fn run_interserver_listen<STATE, S, R, E>(mut state: STATE, port: u16)
  120. where
  121. STATE: InterserverActor<SendMessage=S, RecvMessage=R, Error=E> + Send + 'static,
  122. S: serde::Serialize + Debug + Send + 'static,
  123. R: serde::de::DeserializeOwned + Debug + Send,
  124. E: Debug + Send,
  125. {
  126. let listener = async_std::net::TcpListener::bind(&std::net::SocketAddr::from((std::net::Ipv4Addr::new(0,0,0,0), port))).await.unwrap();
  127. let mut id = 0;
  128. let ships = Arc::new(RwLock::new(HashMap::new()));
  129. loop {
  130. let (socket, addr) = listener.accept().await.unwrap();
  131. info!("[interserver listen] new server: {:?} {:?}", socket, addr);
  132. id += 1;
  133. let server_id = crate::interserver::ServerId(id);
  134. let (client_tx, client_rx) = async_std::channel::unbounded();
  135. state.set_sender(server_id, client_tx.clone()).await;
  136. ships
  137. .write()
  138. .await
  139. .insert(server_id, client_tx.clone());
  140. for msg in state.on_connect(server_id).await {
  141. if let Some(ship_sender) = ships.read().await.get(&msg.0) {
  142. ship_sender.send(msg.1).await.unwrap();
  143. }
  144. }
  145. let rstate = state.clone();
  146. let rsocket = socket.clone();
  147. let rships = ships.clone();
  148. async_std::task::spawn(async move {
  149. interserver_recv_loop(rstate, server_id, rsocket, rships).await;
  150. });
  151. async_std::task::spawn(async move {
  152. interserver_send_loop(server_id, socket, client_rx).await;
  153. });
  154. }
  155. }
  156. pub async fn run_interserver_connect<STATE, S, R, E>(mut state: STATE, ip: std::net::Ipv4Addr, port: u16)
  157. where
  158. STATE: InterserverActor<SendMessage=S, RecvMessage=R, Error=E> + Send + 'static,
  159. S: serde::Serialize + Debug + Send + 'static,
  160. R: serde::de::DeserializeOwned + Debug + Send,
  161. E: Debug + Send,
  162. {
  163. let mut id = 0;
  164. loop {
  165. info!("[interserver connect] trying to connect to server");
  166. let socket = match async_std::net::TcpStream::connect((ip, port)).await {
  167. Ok(socket) => socket,
  168. Err(err) => {
  169. info!("err trying to connect to loginserv {:?}", err);
  170. async_std::task::sleep(std::time::Duration::from_secs(10)).await;
  171. continue;
  172. }
  173. };
  174. id += 1;
  175. let server_id = crate::interserver::ServerId(id);
  176. info!("[interserver connect] found loginserv: {:?} {:?}", server_id, socket);
  177. let (client_tx, client_rx) = async_std::channel::unbounded();
  178. state.set_sender(server_id, client_tx.clone()).await;
  179. for msg in state.on_connect(server_id).await {
  180. client_tx.send(msg.1).await.unwrap();
  181. }
  182. let other_server = vec![(server_id, client_tx.clone())].into_iter().collect();
  183. let rstate = state.clone();
  184. let rsocket = socket.clone();
  185. async_std::task::spawn(async move {
  186. interserver_recv_loop(rstate, server_id, rsocket, Arc::new(RwLock::new(other_server))).await;
  187. });
  188. let ssocket = socket.clone();
  189. async_std::task::spawn(async move {
  190. interserver_send_loop(server_id, ssocket, client_rx).await;
  191. });
  192. let mut buf = [0u8; 1];
  193. loop {
  194. let peek = socket.peek(&mut buf).await;
  195. if let Ok(0) = peek {
  196. break
  197. }
  198. }
  199. }
  200. }