diff --git a/bonknet_broker/src/main.rs b/bonknet_broker/src/main.rs index b794ab4..fc62b3f 100644 --- a/bonknet_broker/src/main.rs +++ b/bonknet_broker/src/main.rs @@ -1,7 +1,7 @@ use actix::prelude::*; use std::collections::HashMap; use std::sync::Arc; -use libbonknet::{load_cert, load_prkey, FromServerMessage, RequiredReplyValues, FromGuestServerMessage, ToGuestServerMessage}; +use libbonknet::*; use rustls::{RootCertStore, ServerConfig}; use rustls::server::WebPkiClientVerifier; use actix_tls::accept::rustls_0_22::{Acceptor as RustlsAcceptor, TlsStream}; @@ -14,6 +14,8 @@ use tokio_util::codec::{Framed, LengthDelimitedCodec}; use tracing::{info, error}; use rcgen::{Certificate, CertificateParams, DnType, KeyPair}; +type TransportStream = Framed, LengthDelimitedCodec>; + struct ServerCert { cert: Vec, prkey: Vec, @@ -21,7 +23,7 @@ struct ServerCert { fn generate_server_cert(root_cert: &Certificate, name: &str) -> ServerCert { let mut params = CertificateParams::new(vec!["entity.other.host".into(), format!("bonk.server.{name}")]); - params.distinguished_name.push(DnType::CommonName, format!("{name}")); + params.distinguished_name.push(DnType::CommonName, name); params.use_authority_key_identifier_extension = true; params.key_usages.push(rcgen::KeyUsagePurpose::DigitalSignature); params @@ -54,6 +56,12 @@ struct RegisterServer { name: String, } +#[derive(Message)] +#[rtype(result = "Option")] +struct FetchName { + cert: Vec, +} + // TODO: Move into Sqlite DB with unique check on col1 and col2!!!! Right now name is not unique struct ServerCertDB { db: HashMap, String>, // Cert to Name @@ -87,32 +95,14 @@ impl Handler for ServerCertDB { } } -struct GuestServerConnection { - stream: TlsStream, -} +impl Handler for ServerCertDB { + type Result = Option; -impl Actor for GuestServerConnection { - type Context = Context; -} - -struct ServerConnection { - stream: Framed, T>, - name: String -} - -impl ServerConnection { - fn new(stream: TlsStream, codec: T) -> Self { - let stream = Framed::new(stream, codec); - ServerConnection { - stream, - name: "Polnareffland1".into(), - } + fn handle(&mut self, msg: FetchName, _ctx: &mut Self::Context) -> Self::Result { + self.db.get(&msg.cert).map(|s| s.to_owned()) } } -impl Actor for ServerConnection { - type Context = Context; -} #[actix_rt::main] async fn main() { @@ -149,7 +139,7 @@ async fn main() { let client_root_cert_der = Arc::new(client_root_cert_der); let guestserver_root_cert_der = Arc::new(guestserver_root_cert_der); let server_root_cert = Arc::new(Certificate::from_params(CertificateParams::from_ca_cert_der( - &*server_root_cert_der, + &server_root_cert_der, server_root_prkey ).unwrap()).unwrap()); @@ -176,87 +166,46 @@ async fn main() { let server_root_cert = Arc::clone(&server_root_cert); let server_db_addr = server_db_addr.clone(); async move { - let peer_cert_der = stream.get_ref().1.peer_certificates().unwrap().last().unwrap().clone(); - if peer_cert_der == *server_root_cert_der { + let peer_certs = stream.get_ref().1.peer_certificates().unwrap(); + let peer_cert_bytes = peer_certs.first().unwrap().to_vec(); + let peer_root_cert_der = peer_certs.last().unwrap().clone(); + if peer_root_cert_der == *server_root_cert_der { info!("Server connection"); - let framed = Framed::new(stream, LengthDelimitedCodec::new()); - framed.for_each(|item| async move { - match item { + let mut transport = Framed::new(stream, LengthDelimitedCodec::new()); + match transport.next().await { + None => { + info!("Connection closed by peer"); + } + Some(item) => match item { Ok(buf) => { - use FromServerMessage::*; - let msg: FromServerMessage = rmp_serde::from_slice(&buf).unwrap(); + use FromServerConnTypeMessage::*; + let msg: FromServerConnTypeMessage = rmp_serde::from_slice(&buf).unwrap(); info!("{:?}", msg); match msg { - RequiredReply(v) => match v { - RequiredReplyValues::Ok => { - info!("Required Reply OK") - } - RequiredReplyValues::GenericFailure { .. } => { - info!("Required Reply Generic Failure") - } + SendCommand => { + info!("SendCommand Stream"); + let reply = ToServerConnTypeReply::OkSendCommand; + transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap(); + server_command_handler(transport, peer_cert_bytes, &server_db_addr).await; } - ChangeName { name } => { - info!("Requested Change Name to Name {}", name); - } - WhoAmI => { - info!("Requested WhoAmI"); + Subscribe => { + info!("Subscribe Stream") } } - }, + } Err(e) => { info!("Disconnection: {:?}", e); } } - }).await; - info!("Disconnection!"); - } else if peer_cert_der == *guestserver_root_cert_der { + } + info!("Server Task terminated!"); + } else if peer_root_cert_der == *guestserver_root_cert_der { info!("GuestServer connection"); let server_root_cert = Arc::clone(&server_root_cert); let codec = LengthDelimitedCodec::new(); - let mut transport = Framed::new(stream, codec); - loop { - match transport.next().await { - None => { - info!("Transport returned None"); - } - Some(item) => { - match item { - Ok(buf) => { - use FromGuestServerMessage::*; - let msg: FromGuestServerMessage = rmp_serde::from_slice(&buf).unwrap(); - info!("{:?}", msg); - match msg { - Announce { name } => { - info!("Announced with name {}", name); - if server_db_addr.send(IsNameRegistered { name: name.clone() }).await.unwrap() { - info!("Name {} already registered!", name); - let reply = ToGuestServerMessage::FailedNameAlreadyOccupied; - transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap(); - break; // Stop reading - } else { - let cert = generate_server_cert(&server_root_cert, name.as_str()); - server_db_addr.send(RegisterServer { - cert: cert.cert.clone(), - name, - }).await.unwrap().unwrap(); - let reply = ToGuestServerMessage::OkAnnounce { - server_cert: cert.cert, - server_prkey: cert.prkey - }; - transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap(); - } - } - } - } - Err(e) => { - info!("Disconnection: {:?}", e); - break; - } - } - } - } - } - } else if peer_cert_der == *client_root_cert_der { + let transport = Framed::new(stream, codec); + guestserver_handler(transport, &server_db_addr, &server_root_cert).await; + } else if peer_root_cert_der == *client_root_cert_der { info!("Client connection"); } else { error!("Unknown Root Certificate"); @@ -270,3 +219,92 @@ async fn main() { .await .unwrap(); } + +async fn server_command_handler(mut transport: TransportStream, peer_cert_bytes: Vec, server_db_addr: &Addr) { + loop { + match transport.next().await { + None => { + info!("Transport returned None"); + break; + } + Some(item) => match item { + Ok(buf) => { + use FromServerCommandMessage::*; + let msg: FromServerCommandMessage = rmp_serde::from_slice(&buf).unwrap(); + info!("{:?}", msg); + match msg { + ChangeName { name } => { + info!("Changing name to {}", name); + // TODO + } + WhoAmI => { + info!("Asked who I am"); + let reply = match server_db_addr.send(FetchName { cert: peer_cert_bytes.clone() }).await.unwrap() { + None => { + info!("I'm not registered anymore!? WTF"); + ToServerCommandReply::GenericFailure + } + Some(name) => { + info!("I am {}", name); + ToServerCommandReply::YouAre { name } + } + }; + transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap(); + } + } + } + Err(e) => { + info!("Disconnection: {:?}", e); + break; + } + } + } + } +} + +// TODO: Considera creare un context dove vengono contenute tutte le chiavi e gli address da dare a tutti gli handler +async fn guestserver_handler(mut transport: TransportStream, server_db_addr: &Addr, server_root_cert: &Arc) { + loop { + match transport.next().await { + None => { + info!("Transport returned None"); + break; + } + Some(item) => { + match item { + Ok(buf) => { + use FromGuestServerMessage::*; + let msg: FromGuestServerMessage = rmp_serde::from_slice(&buf).unwrap(); + info!("{:?}", msg); + match msg { + Announce { name } => { + info!("Announced with name {}", name); + if server_db_addr.send(IsNameRegistered { name: name.clone() }).await.unwrap() { + info!("Name {} already registered!", name); + let reply = ToGuestServerMessage::FailedNameAlreadyOccupied; + transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap(); + break; // Stop reading + } else { + let cert = generate_server_cert(server_root_cert, name.as_str()); + server_db_addr.send(RegisterServer { + cert: cert.cert.clone(), + name, + }).await.unwrap().unwrap(); + let reply = ToGuestServerMessage::OkAnnounce { + server_cert: cert.cert, + server_prkey: cert.prkey + }; + transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap(); + } + } + } + } + Err(e) => { + info!("Disconnection: {:?}", e); + break; + } + } + } + } + } +} diff --git a/bonknet_server/src/bin/server.rs b/bonknet_server/src/bin/server.rs index d96609d..3574a69 100644 --- a/bonknet_server/src/bin/server.rs +++ b/bonknet_server/src/bin/server.rs @@ -44,7 +44,7 @@ async fn main() -> std::io::Result<()> { let mut myserver_prkey: Option = None; match transport.next().await { None => { - info!("None in the transport.next() ???"); + panic!("None in the transport"); } Some(item) => match item { Ok(buf) => { @@ -69,17 +69,112 @@ async fn main() -> std::io::Result<()> { } } } + transport.close().await.unwrap(); if let (Some(server_cert), Some(server_prkey)) = (myserver_cert, myserver_prkey) { - let tlsconfig = ClientConfig::builder() + let tlsconfig = Arc::new(ClientConfig::builder() .with_root_certificates(broker_root_cert_store) .with_client_auth_cert(vec![server_cert, root_server_cert], server_prkey.into()) - .unwrap(); - let connector = TlsConnector::from(Arc::new(tlsconfig)); + .unwrap()); + let connector = TlsConnector::from(Arc::clone(&tlsconfig)); let dnsname = ServerName::try_from("localhost").unwrap(); let stream = TcpStream::connect("localhost:2541").await?; let stream = connector.connect(dnsname, stream).await?; - let transport = Framed::new(stream, LengthDelimitedCodec::new()); + let mut transport = Framed::new(stream, LengthDelimitedCodec::new()); + let msg = FromServerConnTypeMessage::SendCommand; + transport.send(rmp_serde::to_vec(&msg).unwrap().into()).await.unwrap(); + match transport.next().await { + None => { + panic!("None in the transport"); + } + Some(item) => match item { + Ok(buf) => { + use ToServerConnTypeReply::*; + let msg: ToServerConnTypeReply = rmp_serde::from_slice(&buf).unwrap(); + info!("{:?}", msg); + match msg { + OkSendCommand => { + info!("Stream set in SendCommand mode"); + } + OkSubscribe => { + panic!("Unexpected OkSubscribe"); + } + GenericFailure => { + panic!("Generic Failure during SendCommand"); + } + } + } + Err(e) => { + info!("Disconnection: {:?}", e); + } + } + } + // Begin WhoAmI + let msg = FromServerCommandMessage::WhoAmI; + transport.send(rmp_serde::to_vec(&msg).unwrap().into()).await.unwrap(); + match transport.next().await { + None => { + panic!("None in the transport"); + } + Some(item) => match item { + Ok(buf) => { + use ToServerCommandReply::*; + let msg: ToServerCommandReply = rmp_serde::from_slice(&buf).unwrap(); + info!("{:?}", msg); + match msg { + YouAre { name } => { + info!("I am {}", name); + } + GenericFailure => { + panic!("Generic failure during WhoAmI"); + } + _ => { + panic!("Unexpected reply"); + } + } + } + Err(e) => { + info!("Disconnection: {:?}", e); + } + } + } + transport.close().await.expect("Error during transport stream close"); + // Start Subscribe Stream + let connector = TlsConnector::from(Arc::clone(&tlsconfig)); + let dnsname = ServerName::try_from("localhost").unwrap(); + + let stream = TcpStream::connect("localhost:2541").await?; + let stream = connector.connect(dnsname, stream).await?; + let mut transport = Framed::new(stream, LengthDelimitedCodec::new()); + let msg = FromServerConnTypeMessage::Subscribe; + transport.send(rmp_serde::to_vec(&msg).unwrap().into()).await.unwrap(); + match transport.next().await { + None => { + panic!("None in the transport"); + } + Some(item) => match item { + Ok(buf) => { + use ToServerConnTypeReply::*; + let msg: ToServerConnTypeReply = rmp_serde::from_slice(&buf).unwrap(); + info!("{:?}", msg); + match msg { + OkSubscribe => { + info!("Stream set in Subscribe mode"); + } + OkSendCommand => { + panic!("Unexpected OkSendCommand"); + } + GenericFailure => { + panic!("Generic Failure during SendCommand"); + } + } + } + Err(e) => { + info!("Disconnection: {:?}", e); + } + } + } + // Subscribe consume transport.for_each(|item| async move { match item { Ok(buf) => { @@ -89,14 +184,6 @@ async fn main() -> std::io::Result<()> { Required { id } => { info!("I'm required with Connection ID {}", id); } - YouAre(name) => match name { - YouAreValues::Registered { name } => { - info!("I am {}", name); - } - YouAreValues::NotRegistered => { - info!("I'm not registered"); - } - } } } Err(e) => { diff --git a/libbonknet/src/lib.rs b/libbonknet/src/lib.rs index a7c2d2b..c6f067f 100644 --- a/libbonknet/src/lib.rs +++ b/libbonknet/src/lib.rs @@ -32,12 +32,32 @@ pub enum RequiredReplyValues { } #[derive(Debug, Serialize, Deserialize)] -pub enum FromServerMessage { - RequiredReply(RequiredReplyValues), +pub enum FromServerConnTypeMessage { + SendCommand, + Subscribe, +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum ToServerConnTypeReply { + OkSendCommand, + OkSubscribe, + GenericFailure, +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum FromServerCommandMessage { ChangeName { name: String }, WhoAmI, } +#[derive(Debug, Serialize, Deserialize)] +pub enum ToServerCommandReply { + NameChanged, + NameNotAvailable, + YouAre { name: String }, + GenericFailure, +} + #[derive(Debug, Serialize, Deserialize)] pub enum YouAreValues { Registered { name: String }, @@ -47,7 +67,6 @@ pub enum YouAreValues { #[derive(Debug, Serialize, Deserialize)] pub enum ToServerMessage { Required { id: String }, - YouAre(YouAreValues), } #[derive(Debug, Serialize, Deserialize)] @@ -57,7 +76,7 @@ pub enum FromGuestServerMessage { #[derive(Debug, Serialize, Deserialize)] pub enum ToGuestServerMessage { - OkAnnounce {server_cert: Vec, server_prkey: Vec}, + OkAnnounce { server_cert: Vec, server_prkey: Vec }, FailedNameAlreadyOccupied, } @@ -69,7 +88,7 @@ pub fn okannounce_to_cert<'a>(server_cert: Vec, server_prkey: Vec) -> (C impl ToGuestServerMessage { pub fn make_okannounce(server_cert: CertificateDer, server_prkey: PrivatePkcs8KeyDer) -> Self { - ToGuestServerMessage::OkAnnounce{ + ToGuestServerMessage::OkAnnounce { server_cert: server_cert.to_vec(), server_prkey: server_prkey.secret_pkcs8_der().to_vec() }