mirror of
https://github.com/CompeyDev/ruck.git
synced 2025-01-08 11:49:09 +00:00
Simplify connection stapling
This commit is contained in:
parent
7befc61ab3
commit
e701d69946
4 changed files with 52 additions and 102 deletions
|
@ -37,7 +37,6 @@ pub async fn send(file_paths: &Vec<PathBuf>, password: &String) -> Result<()> {
|
||||||
// Complete handshake, returning cipher used for encryption
|
// Complete handshake, returning cipher used for encryption
|
||||||
let (stream, cipher) = handshake(
|
let (stream, cipher) = handshake(
|
||||||
&mut stream,
|
&mut stream,
|
||||||
true,
|
|
||||||
Bytes::from(password.to_string()),
|
Bytes::from(password.to_string()),
|
||||||
pass_to_bytes(password),
|
pass_to_bytes(password),
|
||||||
)
|
)
|
||||||
|
@ -57,7 +56,6 @@ pub async fn receive(password: &String) -> Result<()> {
|
||||||
let mut stream = Message::to_stream(socket);
|
let mut stream = Message::to_stream(socket);
|
||||||
let (stream, cipher) = handshake(
|
let (stream, cipher) = handshake(
|
||||||
&mut stream,
|
&mut stream,
|
||||||
false,
|
|
||||||
Bytes::from(password.to_string()),
|
Bytes::from(password.to_string()),
|
||||||
pass_to_bytes(password),
|
pass_to_bytes(password),
|
||||||
)
|
)
|
||||||
|
|
|
@ -10,7 +10,6 @@ use spake2::{Ed25519Group, Identity, Password, Spake2};
|
||||||
|
|
||||||
pub async fn handshake(
|
pub async fn handshake(
|
||||||
stream: &mut MessageStream,
|
stream: &mut MessageStream,
|
||||||
up: bool,
|
|
||||||
password: Bytes,
|
password: Bytes,
|
||||||
id: Bytes,
|
id: Bytes,
|
||||||
) -> Result<(&mut MessageStream, Aes256Gcm)> {
|
) -> Result<(&mut MessageStream, Aes256Gcm)> {
|
||||||
|
@ -18,7 +17,6 @@ pub async fn handshake(
|
||||||
Spake2::<Ed25519Group>::start_symmetric(&Password::new(password), &Identity::new(&id));
|
Spake2::<Ed25519Group>::start_symmetric(&Password::new(password), &Identity::new(&id));
|
||||||
println!("client - sending handshake msg");
|
println!("client - sending handshake msg");
|
||||||
let handshake_msg = Message::HandshakeMessage(HandshakePayload {
|
let handshake_msg = Message::HandshakeMessage(HandshakePayload {
|
||||||
up,
|
|
||||||
id,
|
id,
|
||||||
msg: Bytes::from(outbound_msg),
|
msg: Bytes::from(outbound_msg),
|
||||||
});
|
});
|
||||||
|
|
|
@ -21,7 +21,6 @@ pub enum Message {
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
pub struct HandshakePayload {
|
pub struct HandshakePayload {
|
||||||
pub up: bool,
|
|
||||||
pub id: Bytes,
|
pub id: Bytes,
|
||||||
pub msg: Bytes,
|
pub msg: Bytes,
|
||||||
}
|
}
|
||||||
|
|
149
src/server.rs
149
src/server.rs
|
@ -13,109 +13,74 @@ type Tx = mpsc::UnboundedSender<Message>;
|
||||||
type Rx = mpsc::UnboundedReceiver<Message>;
|
type Rx = mpsc::UnboundedReceiver<Message>;
|
||||||
|
|
||||||
pub struct Shared {
|
pub struct Shared {
|
||||||
handshakes: HashMap<Bytes, Rx>,
|
handshake_cache: HashMap<Bytes, Tx>,
|
||||||
senders: HashMap<Bytes, Tx>,
|
|
||||||
receivers: HashMap<Bytes, Tx>,
|
|
||||||
}
|
}
|
||||||
type State = Arc<Mutex<Shared>>;
|
type State = Arc<Mutex<Shared>>;
|
||||||
|
|
||||||
struct Client<'a> {
|
struct Client {
|
||||||
up: bool,
|
messages: MessageStream,
|
||||||
id: Bytes,
|
|
||||||
messages: &'a mut MessageStream,
|
|
||||||
rx: Rx,
|
rx: Rx,
|
||||||
|
peer_tx: Option<Tx>,
|
||||||
|
}
|
||||||
|
struct StapledClient {
|
||||||
|
messages: MessageStream,
|
||||||
|
rx: Rx,
|
||||||
|
peer_tx: Tx,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Shared {
|
impl Shared {
|
||||||
fn new() -> Self {
|
fn new() -> Self {
|
||||||
Shared {
|
Shared {
|
||||||
handshakes: HashMap::new(),
|
handshake_cache: HashMap::new(),
|
||||||
senders: HashMap::new(),
|
|
||||||
receivers: HashMap::new(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
async fn relay<'a>(&self, client: &Client<'a>, message: Message) -> Result<()> {
|
|
||||||
println!("in relay - got client={:?}, msg {:?}", client.id, message);
|
|
||||||
match client.up {
|
|
||||||
true => match self.receivers.get(&client.id) {
|
|
||||||
Some(tx) => {
|
|
||||||
tx.send(message)?;
|
|
||||||
}
|
|
||||||
None => {
|
|
||||||
return Err(anyhow!(RuckError::PairDisconnected));
|
|
||||||
}
|
|
||||||
},
|
|
||||||
false => match self.senders.get(&client.id) {
|
|
||||||
Some(tx) => {
|
|
||||||
tx.send(message)?;
|
|
||||||
}
|
|
||||||
None => {
|
|
||||||
return Err(anyhow!(RuckError::PairDisconnected));
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> Client<'a> {
|
impl Client {
|
||||||
async fn new(
|
async fn new(id: Bytes, state: State, messages: MessageStream) -> Result<Client> {
|
||||||
up: bool,
|
|
||||||
id: Bytes,
|
|
||||||
state: State,
|
|
||||||
messages: &'a mut MessageStream,
|
|
||||||
) -> Result<Client<'a>> {
|
|
||||||
let (tx, rx) = mpsc::unbounded_channel();
|
let (tx, rx) = mpsc::unbounded_channel();
|
||||||
println!("server - creating client up={:?}, id={:?}", up, id);
|
let mut shared = state.lock().await;
|
||||||
let shared = &mut state.lock().await;
|
let client = Client {
|
||||||
match shared.senders.get(&id) {
|
|
||||||
Some(_) if up => {
|
|
||||||
messages
|
|
||||||
.send(Message::ErrorMessage(RuckError::SenderAlreadyConnected))
|
|
||||||
.await?;
|
|
||||||
}
|
|
||||||
Some(_) => {
|
|
||||||
println!("server - adding client to receivers");
|
|
||||||
shared.receivers.insert(id.clone(), tx);
|
|
||||||
}
|
|
||||||
None if up => {
|
|
||||||
println!("server - adding client to senders");
|
|
||||||
shared.senders.insert(id.clone(), tx);
|
|
||||||
}
|
|
||||||
None => {
|
|
||||||
messages
|
|
||||||
.send(Message::ErrorMessage(RuckError::SenderNotConnected))
|
|
||||||
.await?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(Client {
|
|
||||||
up,
|
|
||||||
id,
|
|
||||||
messages,
|
messages,
|
||||||
rx,
|
rx,
|
||||||
})
|
peer_tx: shared.handshake_cache.remove(&id),
|
||||||
|
};
|
||||||
|
shared.handshake_cache.insert(id, tx);
|
||||||
|
Ok(client)
|
||||||
}
|
}
|
||||||
async fn complete_handshake(&mut self, state: State, msg: Message) -> Result<()> {
|
|
||||||
match self.up {
|
async fn upgrade(
|
||||||
true => {
|
client: Client,
|
||||||
let (tx, rx) = mpsc::unbounded_channel();
|
state: State,
|
||||||
tx.send(msg)?;
|
handshake_payload: HandshakePayload,
|
||||||
state.lock().await.handshakes.insert(self.id.clone(), rx);
|
) -> Result<StapledClient> {
|
||||||
}
|
let mut client = client;
|
||||||
false => {
|
let peer_tx = match client.peer_tx {
|
||||||
let shared = &mut state.lock().await;
|
// Receiver - already stapled at creation
|
||||||
if let Some(tx) = shared.senders.get(&self.id) {
|
Some(peer_tx) => peer_tx,
|
||||||
tx.send(msg)?;
|
// Sender - needs to wait for the incoming msg to look up peer_tx
|
||||||
}
|
None => {
|
||||||
if let Some(mut rx) = shared.handshakes.remove(&self.id) {
|
match client.rx.recv().await {
|
||||||
drop(shared);
|
Some(msg) => client.messages.send(msg).await?,
|
||||||
if let Some(msg) = rx.recv().await {
|
None => return Err(anyhow!("Connection not stapled")),
|
||||||
self.messages.send(msg).await?;
|
};
|
||||||
}
|
match state
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.handshake_cache
|
||||||
|
.remove(&handshake_payload.id)
|
||||||
|
{
|
||||||
|
Some(peer_tx) => peer_tx,
|
||||||
|
None => return Err(anyhow!("Connection not stapled")),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
Ok(())
|
peer_tx.send(Message::HandshakeMessage(handshake_payload))?;
|
||||||
|
Ok(StapledClient {
|
||||||
|
messages: client.messages,
|
||||||
|
rx: client.rx,
|
||||||
|
peer_tx,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -157,16 +122,8 @@ pub async fn handle_connection(
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
println!("server - received msg from {:?}", addr);
|
println!("server - received msg from {:?}", addr);
|
||||||
let mut client = Client::new(
|
let client = Client::new(handshake_payload.id.clone(), state.clone(), stream).await?;
|
||||||
handshake_payload.up,
|
let mut client = Client::upgrade(client, state.clone(), handshake_payload).await?;
|
||||||
handshake_payload.id.clone(),
|
|
||||||
state.clone(),
|
|
||||||
&mut stream,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
client
|
|
||||||
.complete_handshake(state.clone(), Message::HandshakeMessage(handshake_payload))
|
|
||||||
.await?;
|
|
||||||
// add client to state here
|
// add client to state here
|
||||||
loop {
|
loop {
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
|
@ -176,9 +133,7 @@ pub async fn handle_connection(
|
||||||
}
|
}
|
||||||
result = client.messages.next() => match result {
|
result = client.messages.next() => match result {
|
||||||
Some(Ok(msg)) => {
|
Some(Ok(msg)) => {
|
||||||
println!("GOT: {:?}", msg);
|
client.peer_tx.send(msg)?
|
||||||
let state = state.lock().await;
|
|
||||||
state.relay(&client, msg).await?;
|
|
||||||
}
|
}
|
||||||
Some(Err(e)) => {
|
Some(Err(e)) => {
|
||||||
println!("Error {:?}", e);
|
println!("Error {:?}", e);
|
||||||
|
|
Loading…
Reference in a new issue