From eca6ddf3b8f9e5c5037e9d1a2122ca1969624ba3 Mon Sep 17 00:00:00 2001 From: Donald Knuth Date: Mon, 7 Feb 2022 13:38:16 -0500 Subject: [PATCH] Scaffold message passing --- src/client.rs | 5 ++- src/message.rs | 1 + src/server.rs | 112 +++++++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 109 insertions(+), 9 deletions(-) diff --git a/src/client.rs b/src/client.rs index d30ae68..9bcc479 100644 --- a/src/client.rs +++ b/src/client.rs @@ -21,7 +21,10 @@ pub async fn send(paths: &Vec) -> Result<(), Box let mut buf = BytesMut::with_capacity(1024); buf.put(&b[..]); let body = buf.freeze(); - let m = Message { body: body }; + let m = Message { + from_sender: true, + body: body, + }; stream.send(m).await.unwrap(); } Ok(()) diff --git a/src/message.rs b/src/message.rs index 91636c5..d38a803 100644 --- a/src/message.rs +++ b/src/message.rs @@ -3,5 +3,6 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct Message { + pub from_sender: bool, pub body: Bytes, } diff --git a/src/server.rs b/src/server.rs index 0ee2f9a..455f53b 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,17 +1,83 @@ use crate::message::Message; use futures::prelude::*; +use std::collections::HashMap; +use std::io; +use std::net::SocketAddr; +use std::sync::Arc; use tokio::net::{TcpListener, TcpStream}; -use tokio_util::codec::{FramedRead, LengthDelimitedCodec}; +use tokio::sync::{mpsc, Mutex}; +use tokio_serde::SymmetricallyFramed; +use tokio_util::codec::{Framed, LengthDelimitedCodec}; + +type Tx = mpsc::UnboundedSender; +type Rx = mpsc::UnboundedReceiver; + +type MessageStream = SymmetricallyFramed< + Framed, + Message, + tokio_serde::formats::SymmetricalBincode, +>; + +pub struct Shared { + rooms: HashMap, +} +type State = Arc>; + +struct RoomInfo { + sender_tx: Tx, +} + +struct Client { + is_sender: bool, + messages: MessageStream, + rx: Rx, +} + +impl Shared { + fn new() -> Self { + Shared { + rooms: HashMap::new(), + } + } + + // async fn broadcast(&mut self, sender: SocketAddr, message: Message) { + // for peer in self.peers.iter_mut() { + // if *peer.0 != sender { + // let _ = peer.1.send(message.clone()); + // } + // } + // } +} + +impl Client { + async fn new(is_sender: bool, state: State, messages: MessageStream) -> io::Result { + let (tx, rx) = mpsc::unbounded_channel(); + let room_info = RoomInfo { sender_tx: tx }; + state + .lock() + .await + .rooms + .insert("abc".to_string(), room_info); + + Ok(Client { + is_sender, + messages, + rx, + }) + } +} pub async fn serve() -> Result<(), Box> { let addr = "127.0.0.1:8080".to_string(); - let server = TcpListener::bind(&addr).await?; + let listener = TcpListener::bind(&addr).await?; + let state = Arc::new(Mutex::new(Shared::new())); println!("Listening on: {}", addr); loop { - let (stream, _) = server.accept().await?; + let (stream, address) = listener.accept().await?; + let state = Arc::clone(&state); tokio::spawn(async move { - match process(stream).await { + match handle_connection(state, stream, address).await { Ok(_) => println!("ok"), Err(_) => println!("err"), } @@ -19,14 +85,44 @@ pub async fn serve() -> Result<(), Box> { } } -pub async fn process(socket: TcpStream) -> Result<(), Box> { - let length_delimited = FramedRead::new(socket, LengthDelimitedCodec::new()); +pub async fn handle_connection( + state: Arc>, + socket: TcpStream, + addr: SocketAddr, +) -> Result<(), Box> { + let length_delimited = Framed::new(socket, LengthDelimitedCodec::new()); let mut stream = tokio_serde::SymmetricallyFramed::new( length_delimited, tokio_serde::formats::SymmetricalBincode::::default(), ); - while let Some(message) = stream.try_next().await? { - println!("GOT: {:?}", message); + let first_message = match stream.next().await { + Some(Ok(msg)) => { + println!("first msg: {:?}", msg); + msg + } + _ => { + println!("no first message"); + return Ok(()); + } + }; + let mut client = Client::new(first_message.from_sender, state.clone(), stream).await?; + // add client to state here + loop { + tokio::select! { + Some(msg) = client.rx.recv() => { + println!("message received to client.rx {:?}", msg); + } + result = client.messages.next() => match result { + Some(Ok(msg)) => { + println!("GOT: {:?}", msg); + } + Some(Err(e)) => { + println!("Error {:?}", e); + } + None => break, + } + } } + // client is disconnected, let's remove them from the state Ok(()) }