From 83c7a954144dafae7e2cec7db7007980de866b90 Mon Sep 17 00:00:00 2001 From: "Federico Pasqua (eisterman)" Date: Wed, 21 Feb 2024 16:40:49 +0100 Subject: [PATCH] Implement opening of the DataStream. Just the broker copy task/manager is missing --- bonknet_broker/src/dataconnmanager.rs | 2 + bonknet_broker/src/main.rs | 22 ++-- bonknet_broker/src/pendingdataconndb.rs | 135 +++++++++++++++++++----- bonknet_broker/src/servermanager.rs | 3 +- bonknet_client/Cargo.toml | 1 + bonknet_client/src/bin/client.rs | 102 ++++++++++++++++-- bonknet_server/Cargo.toml | 3 +- bonknet_server/src/bin/server.rs | 93 ++++++++++++++-- libbonknet/src/lib.rs | 23 +++- 9 files changed, 328 insertions(+), 56 deletions(-) create mode 100644 bonknet_broker/src/dataconnmanager.rs diff --git a/bonknet_broker/src/dataconnmanager.rs b/bonknet_broker/src/dataconnmanager.rs new file mode 100644 index 0000000..0071c7e --- /dev/null +++ b/bonknet_broker/src/dataconnmanager.rs @@ -0,0 +1,2 @@ +use actix::prelude::*; + diff --git a/bonknet_broker/src/main.rs b/bonknet_broker/src/main.rs index 5c1bb06..58428e2 100644 --- a/bonknet_broker/src/main.rs +++ b/bonknet_broker/src/main.rs @@ -1,14 +1,13 @@ mod servercertdb; mod pendingdataconndb; mod servermanager; +mod dataconnmanager; use servercertdb::*; use pendingdataconndb::*; use servermanager::*; use actix::prelude::*; -use std::collections::HashMap; -use std::sync::{Arc}; -use std::time::{Instant, Duration}; +use std::sync::Arc; use libbonknet::*; use rustls::{RootCertStore, ServerConfig}; use rustls::server::WebPkiClientVerifier; @@ -17,16 +16,10 @@ use actix_server::Server; use actix_rt::net::TcpStream; use actix_service::{ServiceFactoryExt as _}; use futures::{StreamExt, SinkExt}; -use rand::random; use tokio_util::codec::{Framed, FramedRead, FramedWrite, LengthDelimitedCodec}; use tracing::{info, error, warn}; use rcgen::{Certificate, CertificateParams, DnType, KeyPair}; -use thiserror::Error; use tokio::io::{ReadHalf, WriteHalf}; -use tokio_util::bytes::{Bytes, BytesMut}; -use tokio::io::Error; -use tokio::sync::{oneshot, Mutex}; -use uuid::Uuid; type TransportStream = Framed, LengthDelimitedCodec>; type TransportStreamTx = FramedWrite>, LengthDelimitedCodec>; @@ -160,7 +153,8 @@ async fn main() { } OpenDataStream(conn_id) => { info!("OpenDataStream with {:?}", conn_id); - // TODO: OpenDataStream + let msg = RegisterStream::server(conn_id, transport); + pdcm_addr.send(msg).await.unwrap().unwrap(); } } } @@ -180,7 +174,7 @@ async fn main() { info!("Client connection"); let codec = LengthDelimitedCodec::new(); let transport = Framed::new(stream, codec); - client_handler(transport, sm_addr).await; + client_handler(transport, sm_addr, pdcm_addr).await; } else { error!("Unknown Root Certificate"); } @@ -312,7 +306,7 @@ async fn guestserver_handler(mut transport: TransportStream, server_db_addr: Add } } -async fn client_handler(mut transport: TransportStream, sm_addr: Addr) { +async fn client_handler(mut transport: TransportStream, sm_addr: Addr, pdcm_addr: Addr) { loop { match transport.next().await { None => { @@ -347,7 +341,9 @@ async fn client_handler(mut transport: TransportStream, sm_addr: Addr { info!("Upgrade to DataStream with conn_id {:?}", conn_id); - // TODO: Upgrade to DataStream + let msg = RegisterStream::client(conn_id, transport); + pdcm_addr.send(msg).await.unwrap().unwrap(); + break; } } } diff --git a/bonknet_broker/src/pendingdataconndb.rs b/bonknet_broker/src/pendingdataconndb.rs index 592844a..9158783 100644 --- a/bonknet_broker/src/pendingdataconndb.rs +++ b/bonknet_broker/src/pendingdataconndb.rs @@ -1,11 +1,12 @@ use actix::prelude::*; use actix_tls::accept::rustls_0_22::TlsStream; +use futures::SinkExt; use thiserror::Error; -use tokio::io::{ReadHalf, WriteHalf}; use tokio::net::TcpStream; -use tokio_util::codec::{Framed, FramedRead, FramedWrite, LengthDelimitedCodec}; +use tokio_util::codec::{Framed, LengthDelimitedCodec}; use tracing::{error, info}; use uuid::Uuid; +use libbonknet::*; type TransportStream = Framed, LengthDelimitedCodec>; @@ -73,8 +74,8 @@ struct Record { impl Record { fn new(server_conn_id: Uuid, client_conn_id: Uuid) -> Self { - let server = SideRecord { conn_id: client_conn_id, transport: None }; - let client = SideRecord { conn_id: server_conn_id, transport: None }; + let server = SideRecord { conn_id: server_conn_id, transport: None }; + let client = SideRecord { conn_id: client_conn_id, transport: None }; Record { server, client } } } @@ -88,6 +89,21 @@ impl PendingDataConnManager { pub fn new() -> Self { PendingDataConnManager { db: Vec::new() } } + + fn retrieve_siderecord(&mut self, kind: &RegisterKind, conn_id: &Uuid) -> Option<&mut SideRecord> { + use RegisterKind::*; + let record = match match kind { + Server => self.db.iter_mut().find(|x| x.server.conn_id == *conn_id), + Client => self.db.iter_mut().find(|x| x.client.conn_id == *conn_id), + } { + None => return None, + Some(item) => item, + }; + Some(match kind { + Server => &mut record.server, + Client => &mut record.client, + }) + } } impl Actor for PendingDataConnManager { @@ -113,37 +129,104 @@ impl Handler for PendingDataConnManager { } } -impl Handler for PendingDataConnManager { +/* +Esegui tutti i test normali in Sync. +Quando devi inviare il Accepted, notificati la cosa come AcceptStream con lo stream in suo possesso +ma stavolta con un ResponseActFuture. +Se c'e un fallimento, sposta il transport in un ctx::spawn che invii un FAILED. +Se tutto OK, checka DI NUOVO tutto, e se i check sono positivi, registra lo stream nell'Actor. + +Per gestire lo Spawn di una connection, l'unica risposta e' gestire lo spawn connection come un +Message Handler a sua volta. Quando uno stream completa l'invio del suo Accepted, esso appare nel Record. +Quando il secondo stream arriva e completa il suo accepted, anch'esso viene registrato nel Record, quindi +siamo nella condizione di spawn, perche ci sono entrambi i transport nel Record. +Quindi se alla fine di un check and register ci sono entrambi gli stream, spostali entrambi fuori, droppa +il record e invia i due transport ad un terzo actor, su un altro Arbiter, che esegua il tokio::io::copy e +gestisca le connessioni aperte. + */ + +#[derive(Message)] +#[rtype(result = "Result<(),PendingDataConnError>")] +struct TryStartDataStream { + kind: RegisterKind, + conn_id: Uuid, +} + +impl Handler for PendingDataConnManager { type Result = Result<(), PendingDataConnError>; - fn handle(&mut self, msg: RegisterStream, _ctx: &mut Self::Context) -> Self::Result { + fn handle(&mut self, msg: TryStartDataStream, _ctx: &mut Self::Context) -> Self::Result { use RegisterKind::*; - let conn_id = msg.conn_id; - let record = match match msg.kind { - Server => self.db.iter_mut().find(|x| x.server.conn_id == conn_id), - Client => self.db.iter_mut().find(|x| x.client.conn_id == conn_id), - } { - None => { - error!("Found no connection with {:?} conn_id {:?}", msg.kind, conn_id); - return Err(PendingDataConnError::GenericFailure); - }, - Some(item) => item, + let idx = match msg.kind { + Server => self.db.iter().enumerate().find(|(i, x)| x.server.conn_id == msg.conn_id), + Client => self.db.iter().enumerate().find(|(i, x)| x.client.conn_id == msg.conn_id), }; - let side_record = match msg.kind { - Server => &mut record.server, - Client => &mut record.client, + if let Some((_idx, record)) = idx { + if record.client.transport.is_some() && record.server.transport.is_some() { + // TODO: Launch the "thing" that will manage the data mirroring + info!("LAUNCHING DATA MIRRORING"); + // TODO: we can drop record and use idx to remove the record itself from the vector + // and then send it to another manager for the spawn of the real connection + } + Ok(()) + } else { + Err(PendingDataConnError::GenericFailure) + } + } +} + +impl Handler for PendingDataConnManager { + type Result = ResponseActFuture>; + + fn handle(&mut self, msg: RegisterStream, _ctx: &mut Self::Context) -> Self::Result { + let side_record = match self.retrieve_siderecord(&msg.kind, &msg.conn_id) { + None => { + error!("Found no connection with {:?} conn_id {:?}", msg.kind, msg.conn_id); + return Box::pin(fut::err(PendingDataConnError::GenericFailure)); + } + Some(item) => item, }; if side_record.transport.is_some() { // TODO: It can be good to check if the connection is still open, if not, drop and use the new one. error!("Connection already with a socket!"); - return Err(PendingDataConnError::GenericFailure); + Box::pin(fut::err(PendingDataConnError::GenericFailure)) } else { - side_record.transport = Some(msg.transport); + // This Fut will send the Accepted and only then, register the transport stream + // in the Manager. If during the registration there are all the transport in places, + // you can start the datapiping + Box::pin(async move { + let mut transport = msg.transport; + let reply = ToPeerDataStream::OkDataStreamRequestAccepted; + let res = transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await; + (transport, res) + }.into_actor(self).map(move |(transport, res), a, c| { + match res { + Ok(_) => { + // TODO: to not do the check twice I can put a "lock variable" inside the record + // to prevent the put in Accept of another stream while we are waiting to send the + // accept message + let side_record = match a.retrieve_siderecord(&msg.kind, &msg.conn_id) { + None => { + error!("Found no connection with {:?} conn_id {:?}", msg.kind, msg.conn_id); + return Err(PendingDataConnError::GenericFailure); + } + Some(item) => item, + }; + if side_record.transport.is_some() { + // TODO: It can be good to check if the connection is still open, if not, drop and use the new one. + error!("Connection already with a socket!"); + return Err(PendingDataConnError::GenericFailure); + } + side_record.transport = Some(transport); + c.notify(TryStartDataStream { kind: msg.kind, conn_id: msg.conn_id }); + Ok(()) + } + Err(e) => { + error!("Error during OkDataStreamRequestAccepted sending: {:?}", e); + Err(PendingDataConnError::GenericFailure) + } + } + })) } - if record.client.transport.is_some() && record.server.transport.is_some() { - // TODO: Launch the "thing" that will manage the data mirroring - info!("LAUNCHING DATA MIRRORING"); - } - Ok(()) } } diff --git a/bonknet_broker/src/servermanager.rs b/bonknet_broker/src/servermanager.rs index 5b2d89e..5d4e39c 100644 --- a/bonknet_broker/src/servermanager.rs +++ b/bonknet_broker/src/servermanager.rs @@ -3,10 +3,9 @@ use std::io::Error; use std::sync::{Arc}; use std::time::{Duration, Instant}; use actix::prelude::*; -use actix_server::Server; use rand::random; use thiserror::Error; -use futures::{StreamExt, SinkExt}; +use futures::{SinkExt}; use tokio::sync::{Mutex, oneshot}; use tokio_util::bytes::{Bytes, BytesMut}; use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec}; diff --git a/bonknet_client/Cargo.toml b/bonknet_client/Cargo.toml index a67572b..ed30a32 100644 --- a/bonknet_client/Cargo.toml +++ b/bonknet_client/Cargo.toml @@ -16,3 +16,4 @@ rustls-pemfile = "2.0.0" rmp-serde = "1.1.2" tracing = "0.1" tracing-subscriber = "0.3" +uuid = { version = "1.7.0", features = ["serde"] } diff --git a/bonknet_client/src/bin/client.rs b/bonknet_client/src/bin/client.rs index ec892a4..3f9113b 100644 --- a/bonknet_client/src/bin/client.rs +++ b/bonknet_client/src/bin/client.rs @@ -1,17 +1,102 @@ -use std::io::Error; +use std::io::{Error, ErrorKind}; use std::sync::Arc; use std::time::Duration; use futures::{StreamExt, SinkExt}; use tokio::net::TcpStream; use tokio_rustls::rustls::{ClientConfig, RootCertStore}; -use tokio_rustls::rustls::pki_types::{ServerName, CertificateDer, PrivatePkcs8KeyDer}; +use tokio_rustls::rustls::pki_types::{ServerName}; use tokio_rustls::TlsConnector; -use tokio_util::bytes::BytesMut; use tokio_util::codec::{Framed, LengthDelimitedCodec}; use libbonknet::*; +use uuid::Uuid; use tracing::{info, error}; +async fn datastream(tlsconfig: ClientConfig, conn_id: Uuid) -> std::io::Result<()> { + let connector = TlsConnector::from(Arc::new(tlsconfig.clone())); + let dnsname = ServerName::try_from("localhost").unwrap(); + let stream = TcpStream::connect("localhost:2541").await?; + let stream = connector.connect(dnsname, stream).await?; + let mut transport = Framed::new(stream, LengthDelimitedCodec::new()); + + let msg = FromClientCommand::UpgradeToDataStream(conn_id); + transport.send(rmp_serde::to_vec(&msg).unwrap().into()).await.unwrap(); + match transport.next().await { + None => panic!("None in the transport"), + Some(item) => match item { + Ok(buf) => { + use ToPeerDataStream::*; + let msg: ToPeerDataStream = rmp_serde::from_slice(&buf).unwrap(); + match msg { + OkDataStreamRequestAccepted => { + info!("Data Stream Accepted. Waiting for Open..."); + } + Refused => { + error!("Refused"); + return Err(Error::new(ErrorKind::ConnectionRefused, "Refused")); + } + other => { + error!("Unexpected response: {:?}", other); + return Err(Error::new(ErrorKind::ConnectionRefused, "Unexpected response")); + } + } + } + Err(e) => { + error!("Error: {:?}", e); + return Err(e); + } + } + } + match transport.next().await { + None => panic!("None in the transport"), + Some(item) => match item { + Ok(buf) => { + use ToPeerDataStream::*; + let msg: ToPeerDataStream = rmp_serde::from_slice(&buf).unwrap(); + match msg { + OkDataStreamOpen => { + info!("Data Stream Open!. Connecting Streams."); + } + Revoked => { + error!("Data Stream Revoked!"); + return Err(Error::new(ErrorKind::ConnectionAborted, "Revoked")); + } + Refused => { + error!("Refused"); + return Err(Error::new(ErrorKind::ConnectionRefused, "Refused")); + } + other => { + error!("Unexpected response: {:?}", other); + return Err(Error::new(ErrorKind::ConnectionRefused, "Unexpected response")); + } + } + } + Err(e) => { + error!("Error: {:?}", e); + return Err(e); + } + } + } + let (mut rx, mut tx) = tokio::io::split(transport.into_inner()); + let mut stdout = tokio::io::stdout(); + let mut stdin = tokio::io::stdin(); + let stdout_task = async move { + match tokio::io::copy(&mut rx, &mut stdout).await { + Ok(bytes_copied) => info!("{bytes_copied}"), + Err(e) => error!("Error during copy: {e}"), + } + }; + let stdin_task = async move { + match tokio::io::copy(&mut stdin, &mut tx).await { + Ok(bytes_copied) => info!("{bytes_copied}"), + Err(e) => error!("Error during copy: {e}"), + } + }; + tokio::join!(stdout_task, stdin_task); + Ok(()) +} + + #[tokio::main] async fn main() -> std::io::Result<()> { // Tracing Subscriber @@ -31,7 +116,7 @@ async fn main() -> std::io::Result<()> { .with_root_certificates(broker_root_cert_store.clone()) .with_client_auth_cert(vec![client_cert, root_client_cert], client_cert_prkey.into()) .unwrap(); - let connector = TlsConnector::from(Arc::new(tlsconfig)); + let connector = TlsConnector::from(Arc::new(tlsconfig.clone())); let dnsname = ServerName::try_from("localhost").unwrap(); let stream = TcpStream::connect("localhost:2541").await?; @@ -47,9 +132,11 @@ async fn main() -> std::io::Result<()> { use ToClientResponse::*; let msg: ToClientResponse = rmp_serde::from_slice(&buf).unwrap(); match msg { - OkRequest { .. } => error!("Wrong reply!"), OkServerList { data } => info!("{}", data.join("\n")), GenericError => error!("Generic Error during remote command execution"), + others => { + panic!("Unexpected Message type: {:?}", others); + } } } Err(e) => { @@ -69,9 +156,12 @@ async fn main() -> std::io::Result<()> { match msg { OkRequest { conn_id } => { info!("Received Client Connection ID: {:?}", conn_id); + datastream(tlsconfig, conn_id).await.unwrap(); } - OkServerList { .. } => error!("Wrong reply!"), GenericError => error!("Generic Error during remote command execution"), + others => { + panic!("Unexpected Message type: {:?}", others); + } } } Err(e) => { diff --git a/bonknet_server/Cargo.toml b/bonknet_server/Cargo.toml index 17e97d0..6365333 100644 --- a/bonknet_server/Cargo.toml +++ b/bonknet_server/Cargo.toml @@ -15,4 +15,5 @@ tokio-rustls = "0.25.0" rustls-pemfile = "2.0.0" rmp-serde = "1.1.2" tracing = "0.1" -tracing-subscriber = "0.3" \ No newline at end of file +tracing-subscriber = "0.3" +uuid = { version = "1.7.0", features = ["serde"] } \ No newline at end of file diff --git a/bonknet_server/src/bin/server.rs b/bonknet_server/src/bin/server.rs index 6ab04d3..06fac18 100644 --- a/bonknet_server/src/bin/server.rs +++ b/bonknet_server/src/bin/server.rs @@ -1,14 +1,91 @@ +use std::io::{Error, ErrorKind}; use std::sync::Arc; use futures::{StreamExt, SinkExt}; use tokio::net::TcpStream; use tokio_rustls::rustls::{ClientConfig, RootCertStore}; use tokio_rustls::rustls::pki_types::{ServerName, CertificateDer, PrivatePkcs8KeyDer}; use tokio_rustls::TlsConnector; +use tokio_util::bytes::BytesMut; use tokio_util::codec::{Framed, LengthDelimitedCodec}; use libbonknet::*; +use uuid::Uuid; use tracing::{info, error}; +use libbonknet::ToPeerDataStream::{OkDataStreamOpen, OkDataStreamRequestAccepted, Refused, Revoked}; +async fn datastream(tlsconfig: Arc, conn_id: Uuid) -> std::io::Result<()> { + let connector = TlsConnector::from(tlsconfig); + let dnsname = ServerName::try_from("localhost").unwrap(); + let stream = TcpStream::connect("localhost:2541").await?; + let stream = connector.connect(dnsname, stream).await?; + let mut transport = Framed::new(stream, LengthDelimitedCodec::new()); + + let msg = FromServerConnTypeMessage::OpenDataStream(conn_id); + transport.send(rmp_serde::to_vec(&msg).unwrap().into()).await.unwrap(); + match transport.next().await { + None => panic!("None in the transport"), + Some(item) => match item { + Ok(buf) => { + use ToPeerDataStream::*; + let msg: ToPeerDataStream = rmp_serde::from_slice(&buf).unwrap(); + match msg { + OkDataStreamRequestAccepted => { + info!("Data Stream Accepted. Waiting for Open..."); + } + Refused => { + error!("Refused"); + return Err(Error::new(ErrorKind::ConnectionRefused, "Refused")); + } + other => { + error!("Unexpected response: {:?}", other); + return Err(Error::new(ErrorKind::ConnectionRefused, "Unexpected response")); + } + } + } + Err(e) => { + error!("Error: {:?}", e); + return Err(e); + } + } + } + match transport.next().await { + None => panic!("None in the transport"), + Some(item) => match item { + Ok(buf) => { + use ToPeerDataStream::*; + let msg: ToPeerDataStream = rmp_serde::from_slice(&buf).unwrap(); + match msg { + OkDataStreamOpen => { + info!("Data Stream Open!. Connecting Streams."); + } + Revoked => { + error!("Data Stream Revoked!"); + return Err(Error::new(ErrorKind::ConnectionAborted, "Revoked")); + } + Refused => { + error!("Refused"); + return Err(Error::new(ErrorKind::ConnectionRefused, "Refused")); + } + other => { + error!("Unexpected response: {:?}", other); + return Err(Error::new(ErrorKind::ConnectionRefused, "Unexpected response")); + } + } + } + Err(e) => { + error!("Error: {:?}", e); + return Err(e); + } + } + } + let (mut rx, mut tx) = tokio::io::split(transport.into_inner()); + match tokio::io::copy(&mut rx, &mut tx).await { + Ok(bytes_copied) => info!("{bytes_copied}"), + Err(e) => error!("Error during copy: {e}"), + } + Ok(()) +} + #[tokio::main] async fn main() -> std::io::Result<()> { // Tracing Subscriber @@ -96,12 +173,12 @@ async fn main() -> std::io::Result<()> { OkSendCommand => { info!("Stream set in SendCommand mode"); } - OkSubscribe => { - panic!("Unexpected OkSubscribe"); - } GenericFailure => { panic!("Generic Failure during SendCommand"); } + others => { + panic!("Unexpected Message type: {:?}", others); + } } } Err(e) => { @@ -161,12 +238,12 @@ async fn main() -> std::io::Result<()> { OkSubscribe => { info!("Stream set in Subscribe mode"); } - OkSendCommand => { - panic!("Unexpected OkSendCommand"); - } GenericFailure => { panic!("Generic Failure during SendCommand"); } + others => { + panic!("Unexpected Message type: {:?}", others); + } } } Err(e) => { @@ -195,7 +272,9 @@ async fn main() -> std::io::Result<()> { out = Some(FromServerReply::Msg { reply_id, body: FromServerReplyBody::RequestAccepted, - }) + }); + // TODO: SPAWN DATASTREAM + tokio::spawn(datastream(tlsconfig.clone(), conn_id)); } } } diff --git a/libbonknet/src/lib.rs b/libbonknet/src/lib.rs index c0977a5..1d935c9 100644 --- a/libbonknet/src/lib.rs +++ b/libbonknet/src/lib.rs @@ -6,7 +6,7 @@ use uuid::Uuid; pub fn load_cert(filename: &str) -> std::io::Result { let cert_file = std::fs::File::open(filename).unwrap(); - let mut buf = std::io::BufReader::new(cert_file); + let mut buf = BufReader::new(cert_file); if let Item::X509Certificate(cert) = read_one(&mut buf).unwrap().unwrap() { Ok(cert) } else { @@ -30,12 +30,17 @@ pub fn load_prkey(filename: &str) -> std::io::Result { pub enum FromServerConnTypeMessage { SendCommand, Subscribe, + OpenDataStream(Uuid), } #[derive(Debug, Serialize, Deserialize)] pub enum ToServerConnTypeReply { OkSendCommand, OkSubscribe, + // You are now a DataStream, wait the Open message + OkDataStreamRequestAccepted, + // The stream is open, you can pipe in-out the content you want! + OkDataStreamOpen, GenericFailure, } @@ -120,11 +125,27 @@ impl ToGuestServerMessage { pub enum FromClientCommand { RequestServer { name: String }, ServerList, + UpgradeToDataStream(Uuid), } #[derive(Debug, Serialize, Deserialize)] pub enum ToClientResponse { OkRequest { conn_id: Uuid }, OkServerList { data: Vec }, + // You are now a DataStream, wait the Open message + OkDataStreamRequestAccepted, + // The stream is open, you can pipe in-out the content you want! + OkDataStreamOpen, + GenericError, +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum ToPeerDataStream { + // You are now a DataStream, wait the Open message + OkDataStreamRequestAccepted, + // The stream is open, you can pipe in-out the content you want! + OkDataStreamOpen, + Refused, + Revoked, GenericError, }