diff --git a/bonknet_broker/src/certutils.rs b/bonknet_broker/src/certutils.rs new file mode 100644 index 0000000..acb2799 --- /dev/null +++ b/bonknet_broker/src/certutils.rs @@ -0,0 +1,20 @@ +use rcgen::{Certificate, CertificateParams, DnType}; + +pub struct ServerCert { + pub cert: Vec, + pub prkey: Vec, +} + +pub fn generate_server_cert(root_cert: &Certificate, name: &str) -> ServerCert { + let mut params = CertificateParams::new(vec!["entity.other.host".into(), format!("bonk.server.{name}")]); + params.distinguished_name.push(DnType::CommonName, name); + params.use_authority_key_identifier_extension = true; + params.key_usages.push(rcgen::KeyUsagePurpose::DigitalSignature); + params + .extended_key_usages + .push(rcgen::ExtendedKeyUsagePurpose::ClientAuth); + let certificate = Certificate::from_params(params).unwrap(); + let cert = certificate.serialize_der_with_signer(root_cert).unwrap(); + let prkey = certificate.serialize_private_key_der(); + ServerCert { cert, prkey } +} diff --git a/bonknet_broker/src/dataconnmanager.rs b/bonknet_broker/src/dataconnmanager.rs index 7737add..81111cb 100644 --- a/bonknet_broker/src/dataconnmanager.rs +++ b/bonknet_broker/src/dataconnmanager.rs @@ -5,7 +5,7 @@ use futures::SinkExt; use thiserror::Error; use tracing::{info, error, warn}; use libbonknet::ToPeerDataStream; -use crate::TransportStream; +use crate::streamutils::*; #[allow(dead_code)] #[derive(Error, Debug)] diff --git a/bonknet_broker/src/main.rs b/bonknet_broker/src/main.rs index 3d351b8..cd1fe53 100644 --- a/bonknet_broker/src/main.rs +++ b/bonknet_broker/src/main.rs @@ -2,11 +2,14 @@ mod servercertdb; mod pendingdataconndb; mod servermanager; mod dataconnmanager; +mod streamutils; +mod certutils; use servercertdb::*; use pendingdataconndb::*; use servermanager::*; use dataconnmanager::*; +use streamutils::*; use actix::prelude::*; use std::sync::Arc; use libbonknet::*; @@ -19,32 +22,20 @@ use actix_server::Server; use actix_rt::net::TcpStream; use actix_service::ServiceFactoryExt as _; use futures::{SinkExt, StreamExt}; -use tokio_util::codec::{Framed, FramedRead, FramedWrite, LengthDelimitedCodec}; +use tokio_util::codec::{Framed, LengthDelimitedCodec}; use tracing::{error, info, warn}; -use rcgen::{Certificate, CertificateParams, DnType, KeyPair}; -use tokio::io::{ReadHalf, WriteHalf}; +use rcgen::{Certificate, CertificateParams, KeyPair}; +use rustls::pki_types::CertificateDer; -type TransportStream = Framed, LengthDelimitedCodec>; -type TransportStreamTx = FramedWrite>, LengthDelimitedCodec>; -type TransportStreamRx = FramedRead>, LengthDelimitedCodec>; -struct ServerCert { - cert: Vec, - prkey: Vec, -} - -fn generate_server_cert(root_cert: &Certificate, name: &str) -> ServerCert { - let mut params = CertificateParams::new(vec!["entity.other.host".into(), format!("bonk.server.{name}")]); - params.distinguished_name.push(DnType::CommonName, name); - params.use_authority_key_identifier_extension = true; - params.key_usages.push(rcgen::KeyUsagePurpose::DigitalSignature); - params - .extended_key_usages - .push(rcgen::ExtendedKeyUsagePurpose::ClientAuth); - let certificate = Certificate::from_params(params).unwrap(); - let cert = certificate.serialize_der_with_signer(root_cert).unwrap(); - let prkey = certificate.serialize_private_key_der(); - ServerCert { cert, prkey } +#[derive(Clone)] +struct BrokerContext { + server_root_cert_der: CertificateDer<'static>, + client_root_cert_der: CertificateDer<'static>, + guestserver_root_cert_der: CertificateDer<'static>, + scdb_addr: Addr, + pdcm_addr: Addr, + sm_addr: Addr, } @@ -72,53 +63,47 @@ async fn main() { let server_verifier = WebPkiClientVerifier::builder(Arc::new(broker_root_store)).build().unwrap(); // Configure TLS let server_tlsconfig = ServerConfig::builder() - // .with_no_client_auth() .with_client_cert_verifier(server_verifier) .with_single_cert(vec![broker_cert_der.clone(), broker_root_cert_der.clone()], broker_prkey_der.into()) .unwrap(); let server_acceptor = RustlsAcceptor::new(server_tlsconfig); - let server_root_cert_der = Arc::new(server_root_cert_der); + let scdb_addr = ServerCertDB::new().start(); + let dcm_addr = DataConnManager::new().start(); + let pdcm_addr = PendingDataConnManager::new(dcm_addr).start(); + let sm_addr = ServerManager::new(pdcm_addr.clone()).start(); + let server_root_prkey = KeyPair::from_der(server_root_prkey_der.secret_pkcs8_der()).unwrap(); - let client_root_cert_der = Arc::new(client_root_cert_der); - let guestserver_root_cert_der = Arc::new(guestserver_root_cert_der); let server_root_cert = Arc::new(Certificate::from_params(CertificateParams::from_ca_cert_der( &server_root_cert_der, server_root_prkey ).unwrap()).unwrap()); - let scdb_addr = ServerCertDB::new().start(); - let dcm_addr = DataConnManager::new().start(); - let pdcm_addr = PendingDataConnManager::new(dcm_addr.clone()).start(); - let sm_addr = ServerManager::new(pdcm_addr.clone()).start(); + let ctx = Arc::new(BrokerContext { + server_root_cert_der, + client_root_cert_der, + guestserver_root_cert_der, + scdb_addr, + pdcm_addr, + sm_addr, + }); Server::build() .bind("server-command", ("localhost", 2541), move || { - let server_root_cert_der = Arc::clone(&server_root_cert_der); - let client_root_cert_der = Arc::clone(&client_root_cert_der); - let guestserver_root_cert_der = Arc::clone(&guestserver_root_cert_der); + let ctx = Arc::clone(&ctx); let server_root_cert = Arc::clone(&server_root_cert); - let scdb_addr = scdb_addr.clone(); - let pdcm_addr = pdcm_addr.clone(); - let sm_addr = sm_addr.clone(); - // Set up TLS service factory server_acceptor .clone() .map_err(|err| println!("Rustls error: {:?}", err)) .and_then(move |stream: TlsStream| { - let server_root_cert_der = Arc::clone(&server_root_cert_der); - let client_root_cert_der = Arc::clone(&client_root_cert_der); - let guestserver_root_cert_der = Arc::clone(&guestserver_root_cert_der); + let ctx = Arc::clone(&ctx); let server_root_cert = Arc::clone(&server_root_cert); - let scdb_addr = scdb_addr.clone(); - let pdcm_addr = pdcm_addr.clone(); - let sm_addr = sm_addr.clone(); async move { let peer_certs = stream.get_ref().1.peer_certificates().unwrap(); let peer_cert_bytes = peer_certs.first().unwrap().to_vec(); let peer_root_cert_der = peer_certs.last().unwrap().clone(); - if peer_root_cert_der == *server_root_cert_der { + if peer_root_cert_der == ctx.server_root_cert_der { info!("Server connection"); let mut transport = Framed::new(stream, LengthDelimitedCodec::new()); match transport.next().await { @@ -127,8 +112,7 @@ async fn main() { } Some(item) => match item { Ok(buf) => { - use libbonknet::servermsg::{FromServerConnTypeMessage, ToServerConnTypeReply}; - use libbonknet::servermsg::FromServerConnTypeMessage::*; + use FromServerConnTypeMessage::*; let msg: FromServerConnTypeMessage = rmp_serde::from_slice(&buf).unwrap(); info!("{:?}", msg); match msg { @@ -136,11 +120,11 @@ async fn main() { info!("SendCommand Stream"); let reply = ToServerConnTypeReply::OkSendCommand; transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap(); - server_command_handler(transport, peer_cert_bytes, scdb_addr).await; + server_command_handler(&ctx, transport, peer_cert_bytes).await; } Subscribe => { info!("Subscribe Stream"); - let name = match scdb_addr.send(FetchName { cert: peer_cert_bytes }).await.unwrap() { + let name = match ctx.scdb_addr.send(FetchName { cert: peer_cert_bytes }).await.unwrap() { None => { error!("Cert has no name assigned!"); let reply = ToServerConnTypeReply::GenericFailure; @@ -151,12 +135,12 @@ async fn main() { }; let reply = ToServerConnTypeReply::OkSubscribe; transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap(); - server_subscribe_handler(transport, name, sm_addr).await; + server_subscribe_handler(&ctx, transport, name).await; } OpenDataStream(conn_id) => { info!("OpenDataStream with {:?}", conn_id); let msg = RegisterStream::server(conn_id, transport); - pdcm_addr.send(msg).await.unwrap().unwrap(); + ctx.pdcm_addr.send(msg).await.unwrap().unwrap(); } } } @@ -166,17 +150,16 @@ async fn main() { } } info!("Server Task terminated!"); - } else if peer_root_cert_der == *guestserver_root_cert_der { + } else if peer_root_cert_der == ctx.guestserver_root_cert_der { info!("GuestServer connection"); - let server_root_cert = Arc::clone(&server_root_cert); let codec = LengthDelimitedCodec::new(); let transport = Framed::new(stream, codec); - guestserver_handler(transport, scdb_addr, &server_root_cert).await; - } else if peer_root_cert_der == *client_root_cert_der { + guestserver_handler(&ctx, transport, &server_root_cert).await; + } else if peer_root_cert_der == ctx.client_root_cert_der { info!("Client connection"); let codec = LengthDelimitedCodec::new(); let transport = Framed::new(stream, codec); - client_handler(transport, sm_addr, pdcm_addr).await; + client_handler(&ctx, transport).await; } else { error!("Unknown Root Certificate"); } @@ -190,8 +173,8 @@ async fn main() { .unwrap(); } -async fn server_subscribe_handler(transport: TransportStream, name: String, sm_addr: Addr) { - match sm_addr.send(StartTransporter { name, transport }).await.unwrap() { +async fn server_subscribe_handler(ctx: &BrokerContext, transport: TransportStream, name: String) { + match ctx.sm_addr.send(StartTransporter { name, transport }).await.unwrap() { Ok(_) => { info!("Stream sent to the manager"); } @@ -201,7 +184,7 @@ async fn server_subscribe_handler(transport: TransportStream, name: String, sm_a } } -async fn server_command_handler(mut transport: TransportStream, peer_cert_bytes: Vec, server_db_addr: Addr) { +async fn server_command_handler(ctx: &BrokerContext, mut transport: TransportStream, peer_cert_bytes: Vec) { loop { match transport.next().await { None => { @@ -216,7 +199,7 @@ async fn server_command_handler(mut transport: TransportStream, peer_cert_bytes: match msg { ChangeName { name } => { info!("Changing name to {}", name); - match server_db_addr.send(UnregisterServer { cert: peer_cert_bytes.clone() }).await.unwrap() { + match ctx.scdb_addr.send(UnregisterServer { cert: peer_cert_bytes.clone() }).await.unwrap() { None => { info!("Nothing to unregister"); } @@ -224,7 +207,7 @@ async fn server_command_handler(mut transport: TransportStream, peer_cert_bytes: warn!("Unregistered from old name {}", old_name); } } - let reply = match server_db_addr.send(RegisterServer { cert: peer_cert_bytes.clone(), name }).await.unwrap() { + let reply = match ctx.scdb_addr.send(RegisterServer { cert: peer_cert_bytes.clone(), name }).await.unwrap() { Ok(_) => { info!("Registered!"); ToServerCommandReply::NameChanged @@ -238,7 +221,7 @@ async fn server_command_handler(mut transport: TransportStream, peer_cert_bytes: } WhoAmI => { info!("Asked who I am"); - let reply = match server_db_addr.send(FetchName { cert: peer_cert_bytes.clone() }).await.unwrap() { + let reply = match ctx.scdb_addr.send(FetchName { cert: peer_cert_bytes.clone() }).await.unwrap() { None => { info!("I'm not registered anymore!? WTF"); ToServerCommandReply::GenericFailure @@ -261,8 +244,7 @@ async fn server_command_handler(mut transport: TransportStream, peer_cert_bytes: } } -// TODO: Considera creare un context dove vengono contenute tutte le chiavi e gli address da dare a tutti gli handler -async fn guestserver_handler(mut transport: TransportStream, server_db_addr: Addr, server_root_cert: &Arc) { +async fn guestserver_handler(ctx: &BrokerContext, mut transport: TransportStream, server_root_cert: &Certificate) { loop { match transport.next().await { None => { @@ -278,14 +260,14 @@ async fn guestserver_handler(mut transport: TransportStream, server_db_addr: Add match msg { Announce { name } => { info!("Announced with name {}", name); - if server_db_addr.send(IsNameRegistered { name: name.clone() }).await.unwrap() { + if ctx.scdb_addr.send(IsNameRegistered { name: name.clone() }).await.unwrap() { info!("Name {} already registered!", name); let reply = ToGuestServerMessage::FailedNameAlreadyOccupied; transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap(); break; // Stop reading } else { - let cert = generate_server_cert(server_root_cert, name.as_str()); - server_db_addr.send(RegisterServer { + let cert = certutils::generate_server_cert(server_root_cert, name.as_str()); + ctx.scdb_addr.send(RegisterServer { cert: cert.cert.clone(), name, }).await.unwrap().unwrap(); @@ -308,7 +290,7 @@ async fn guestserver_handler(mut transport: TransportStream, server_db_addr: Add } } -async fn client_handler(mut transport: TransportStream, sm_addr: Addr, pdcm_addr: Addr) { +async fn client_handler(ctx: &BrokerContext, mut transport: TransportStream) { loop { match transport.next().await { None => { @@ -323,7 +305,7 @@ async fn client_handler(mut transport: TransportStream, sm_addr: Addr { info!("REQUESTED SERVER {}", name); - let data = sm_addr.send(RequestServer { name }).await.unwrap(); + let data = ctx.sm_addr.send(RequestServer { name }).await.unwrap(); match data { Ok(client_conn_id) => { let reply = ToClientResponse::OkRequest { conn_id: client_conn_id }; @@ -337,14 +319,14 @@ async fn client_handler(mut transport: TransportStream, sm_addr: Addr { info!("Requested ServerList"); - let data = sm_addr.send(GetServerList {}).await.unwrap(); + let data = ctx.sm_addr.send(GetServerList {}).await.unwrap(); let reply = ToClientResponse::OkServerList { data }; transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap(); } FromClientCommand::UpgradeToDataStream(conn_id) => { info!("Upgrade to DataStream with conn_id {:?}", conn_id); let msg = RegisterStream::client(conn_id, transport); - pdcm_addr.send(msg).await.unwrap().unwrap(); + ctx.pdcm_addr.send(msg).await.unwrap().unwrap(); break; } } diff --git a/bonknet_broker/src/pendingdataconndb.rs b/bonknet_broker/src/pendingdataconndb.rs index 3ca2b62..152efad 100644 --- a/bonknet_broker/src/pendingdataconndb.rs +++ b/bonknet_broker/src/pendingdataconndb.rs @@ -1,15 +1,11 @@ use actix::prelude::*; -use actix_tls::accept::rustls_0_22::TlsStream; use futures::SinkExt; use thiserror::Error; -use tokio::net::TcpStream; -use tokio_util::codec::{Framed, LengthDelimitedCodec}; use tracing::{error, info}; use uuid::Uuid; use libbonknet::*; use crate::dataconnmanager::{DataConnManager, StartDataBridge}; - -type TransportStream = Framed, LengthDelimitedCodec>; +use crate::streamutils::*; /* L'idea e' che il database deve avere una riga per ogni connessione dati in nascita. diff --git a/bonknet_broker/src/servermanager.rs b/bonknet_broker/src/servermanager.rs index 964f544..c9c514c 100644 --- a/bonknet_broker/src/servermanager.rs +++ b/bonknet_broker/src/servermanager.rs @@ -10,7 +10,7 @@ use tokio::sync::{Mutex, oneshot}; use tokio_util::bytes::{Bytes, BytesMut}; use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec}; use tracing::{debug, error, info}; -use crate::{TransportStream, TransportStreamRx, TransportStreamTx}; +use crate::streamutils::*; use uuid::Uuid; use libbonknet::servermsg::*; use crate::pendingdataconndb::*; diff --git a/bonknet_broker/src/streamutils.rs b/bonknet_broker/src/streamutils.rs new file mode 100644 index 0000000..487e710 --- /dev/null +++ b/bonknet_broker/src/streamutils.rs @@ -0,0 +1,8 @@ +use actix_tls::accept::rustls_0_22::TlsStream; +use tokio::io::{ReadHalf, WriteHalf}; +use tokio::net::TcpStream; +use tokio_util::codec::{Framed, FramedRead, FramedWrite, LengthDelimitedCodec}; + +pub type TransportStream = Framed, LengthDelimitedCodec>; +pub type TransportStreamTx = FramedWrite>, LengthDelimitedCodec>; +pub type TransportStreamRx = FramedRead>, LengthDelimitedCodec>; \ No newline at end of file