Start refactoring for a better code inside the broker

This commit is contained in:
2024-02-26 23:56:13 +01:00
parent 88aeb25fdf
commit 4604beed36
6 changed files with 84 additions and 78 deletions

View File

@@ -0,0 +1,20 @@
use rcgen::{Certificate, CertificateParams, DnType};
pub struct ServerCert {
pub cert: Vec<u8>,
pub prkey: Vec<u8>,
}
pub fn generate_server_cert(root_cert: &Certificate, name: &str) -> ServerCert {
let mut params = CertificateParams::new(vec!["entity.other.host".into(), format!("bonk.server.{name}")]);
params.distinguished_name.push(DnType::CommonName, name);
params.use_authority_key_identifier_extension = true;
params.key_usages.push(rcgen::KeyUsagePurpose::DigitalSignature);
params
.extended_key_usages
.push(rcgen::ExtendedKeyUsagePurpose::ClientAuth);
let certificate = Certificate::from_params(params).unwrap();
let cert = certificate.serialize_der_with_signer(root_cert).unwrap();
let prkey = certificate.serialize_private_key_der();
ServerCert { cert, prkey }
}

View File

@@ -5,7 +5,7 @@ use futures::SinkExt;
use thiserror::Error; use thiserror::Error;
use tracing::{info, error, warn}; use tracing::{info, error, warn};
use libbonknet::ToPeerDataStream; use libbonknet::ToPeerDataStream;
use crate::TransportStream; use crate::streamutils::*;
#[allow(dead_code)] #[allow(dead_code)]
#[derive(Error, Debug)] #[derive(Error, Debug)]

View File

@@ -2,11 +2,14 @@ mod servercertdb;
mod pendingdataconndb; mod pendingdataconndb;
mod servermanager; mod servermanager;
mod dataconnmanager; mod dataconnmanager;
mod streamutils;
mod certutils;
use servercertdb::*; use servercertdb::*;
use pendingdataconndb::*; use pendingdataconndb::*;
use servermanager::*; use servermanager::*;
use dataconnmanager::*; use dataconnmanager::*;
use streamutils::*;
use actix::prelude::*; use actix::prelude::*;
use std::sync::Arc; use std::sync::Arc;
use libbonknet::*; use libbonknet::*;
@@ -19,32 +22,20 @@ use actix_server::Server;
use actix_rt::net::TcpStream; use actix_rt::net::TcpStream;
use actix_service::ServiceFactoryExt as _; use actix_service::ServiceFactoryExt as _;
use futures::{SinkExt, StreamExt}; use futures::{SinkExt, StreamExt};
use tokio_util::codec::{Framed, FramedRead, FramedWrite, LengthDelimitedCodec}; use tokio_util::codec::{Framed, LengthDelimitedCodec};
use tracing::{error, info, warn}; use tracing::{error, info, warn};
use rcgen::{Certificate, CertificateParams, DnType, KeyPair}; use rcgen::{Certificate, CertificateParams, KeyPair};
use tokio::io::{ReadHalf, WriteHalf}; use rustls::pki_types::CertificateDer;
type TransportStream = Framed<TlsStream<TcpStream>, LengthDelimitedCodec>;
type TransportStreamTx = FramedWrite<WriteHalf<TlsStream<TcpStream>>, LengthDelimitedCodec>;
type TransportStreamRx = FramedRead<ReadHalf<TlsStream<TcpStream>>, LengthDelimitedCodec>;
struct ServerCert { #[derive(Clone)]
cert: Vec<u8>, struct BrokerContext {
prkey: Vec<u8>, server_root_cert_der: CertificateDer<'static>,
} client_root_cert_der: CertificateDer<'static>,
guestserver_root_cert_der: CertificateDer<'static>,
fn generate_server_cert(root_cert: &Certificate, name: &str) -> ServerCert { scdb_addr: Addr<ServerCertDB>,
let mut params = CertificateParams::new(vec!["entity.other.host".into(), format!("bonk.server.{name}")]); pdcm_addr: Addr<PendingDataConnManager>,
params.distinguished_name.push(DnType::CommonName, name); sm_addr: Addr<ServerManager>,
params.use_authority_key_identifier_extension = true;
params.key_usages.push(rcgen::KeyUsagePurpose::DigitalSignature);
params
.extended_key_usages
.push(rcgen::ExtendedKeyUsagePurpose::ClientAuth);
let certificate = Certificate::from_params(params).unwrap();
let cert = certificate.serialize_der_with_signer(root_cert).unwrap();
let prkey = certificate.serialize_private_key_der();
ServerCert { cert, prkey }
} }
@@ -72,53 +63,47 @@ async fn main() {
let server_verifier = WebPkiClientVerifier::builder(Arc::new(broker_root_store)).build().unwrap(); let server_verifier = WebPkiClientVerifier::builder(Arc::new(broker_root_store)).build().unwrap();
// Configure TLS // Configure TLS
let server_tlsconfig = ServerConfig::builder() let server_tlsconfig = ServerConfig::builder()
// .with_no_client_auth()
.with_client_cert_verifier(server_verifier) .with_client_cert_verifier(server_verifier)
.with_single_cert(vec![broker_cert_der.clone(), broker_root_cert_der.clone()], broker_prkey_der.into()) .with_single_cert(vec![broker_cert_der.clone(), broker_root_cert_der.clone()], broker_prkey_der.into())
.unwrap(); .unwrap();
let server_acceptor = RustlsAcceptor::new(server_tlsconfig); let server_acceptor = RustlsAcceptor::new(server_tlsconfig);
let server_root_cert_der = Arc::new(server_root_cert_der); let scdb_addr = ServerCertDB::new().start();
let dcm_addr = DataConnManager::new().start();
let pdcm_addr = PendingDataConnManager::new(dcm_addr).start();
let sm_addr = ServerManager::new(pdcm_addr.clone()).start();
let server_root_prkey = KeyPair::from_der(server_root_prkey_der.secret_pkcs8_der()).unwrap(); let server_root_prkey = KeyPair::from_der(server_root_prkey_der.secret_pkcs8_der()).unwrap();
let client_root_cert_der = Arc::new(client_root_cert_der);
let guestserver_root_cert_der = Arc::new(guestserver_root_cert_der);
let server_root_cert = Arc::new(Certificate::from_params(CertificateParams::from_ca_cert_der( let server_root_cert = Arc::new(Certificate::from_params(CertificateParams::from_ca_cert_der(
&server_root_cert_der, &server_root_cert_der,
server_root_prkey server_root_prkey
).unwrap()).unwrap()); ).unwrap()).unwrap());
let scdb_addr = ServerCertDB::new().start(); let ctx = Arc::new(BrokerContext {
let dcm_addr = DataConnManager::new().start(); server_root_cert_der,
let pdcm_addr = PendingDataConnManager::new(dcm_addr.clone()).start(); client_root_cert_der,
let sm_addr = ServerManager::new(pdcm_addr.clone()).start(); guestserver_root_cert_der,
scdb_addr,
pdcm_addr,
sm_addr,
});
Server::build() Server::build()
.bind("server-command", ("localhost", 2541), move || { .bind("server-command", ("localhost", 2541), move || {
let server_root_cert_der = Arc::clone(&server_root_cert_der); let ctx = Arc::clone(&ctx);
let client_root_cert_der = Arc::clone(&client_root_cert_der);
let guestserver_root_cert_der = Arc::clone(&guestserver_root_cert_der);
let server_root_cert = Arc::clone(&server_root_cert); let server_root_cert = Arc::clone(&server_root_cert);
let scdb_addr = scdb_addr.clone();
let pdcm_addr = pdcm_addr.clone();
let sm_addr = sm_addr.clone();
// Set up TLS service factory // Set up TLS service factory
server_acceptor server_acceptor
.clone() .clone()
.map_err(|err| println!("Rustls error: {:?}", err)) .map_err(|err| println!("Rustls error: {:?}", err))
.and_then(move |stream: TlsStream<TcpStream>| { .and_then(move |stream: TlsStream<TcpStream>| {
let server_root_cert_der = Arc::clone(&server_root_cert_der); let ctx = Arc::clone(&ctx);
let client_root_cert_der = Arc::clone(&client_root_cert_der);
let guestserver_root_cert_der = Arc::clone(&guestserver_root_cert_der);
let server_root_cert = Arc::clone(&server_root_cert); let server_root_cert = Arc::clone(&server_root_cert);
let scdb_addr = scdb_addr.clone();
let pdcm_addr = pdcm_addr.clone();
let sm_addr = sm_addr.clone();
async move { async move {
let peer_certs = stream.get_ref().1.peer_certificates().unwrap(); let peer_certs = stream.get_ref().1.peer_certificates().unwrap();
let peer_cert_bytes = peer_certs.first().unwrap().to_vec(); let peer_cert_bytes = peer_certs.first().unwrap().to_vec();
let peer_root_cert_der = peer_certs.last().unwrap().clone(); let peer_root_cert_der = peer_certs.last().unwrap().clone();
if peer_root_cert_der == *server_root_cert_der { if peer_root_cert_der == ctx.server_root_cert_der {
info!("Server connection"); info!("Server connection");
let mut transport = Framed::new(stream, LengthDelimitedCodec::new()); let mut transport = Framed::new(stream, LengthDelimitedCodec::new());
match transport.next().await { match transport.next().await {
@@ -127,8 +112,7 @@ async fn main() {
} }
Some(item) => match item { Some(item) => match item {
Ok(buf) => { Ok(buf) => {
use libbonknet::servermsg::{FromServerConnTypeMessage, ToServerConnTypeReply}; use FromServerConnTypeMessage::*;
use libbonknet::servermsg::FromServerConnTypeMessage::*;
let msg: FromServerConnTypeMessage = rmp_serde::from_slice(&buf).unwrap(); let msg: FromServerConnTypeMessage = rmp_serde::from_slice(&buf).unwrap();
info!("{:?}", msg); info!("{:?}", msg);
match msg { match msg {
@@ -136,11 +120,11 @@ async fn main() {
info!("SendCommand Stream"); info!("SendCommand Stream");
let reply = ToServerConnTypeReply::OkSendCommand; let reply = ToServerConnTypeReply::OkSendCommand;
transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap(); transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap();
server_command_handler(transport, peer_cert_bytes, scdb_addr).await; server_command_handler(&ctx, transport, peer_cert_bytes).await;
} }
Subscribe => { Subscribe => {
info!("Subscribe Stream"); info!("Subscribe Stream");
let name = match scdb_addr.send(FetchName { cert: peer_cert_bytes }).await.unwrap() { let name = match ctx.scdb_addr.send(FetchName { cert: peer_cert_bytes }).await.unwrap() {
None => { None => {
error!("Cert has no name assigned!"); error!("Cert has no name assigned!");
let reply = ToServerConnTypeReply::GenericFailure; let reply = ToServerConnTypeReply::GenericFailure;
@@ -151,12 +135,12 @@ async fn main() {
}; };
let reply = ToServerConnTypeReply::OkSubscribe; let reply = ToServerConnTypeReply::OkSubscribe;
transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap(); transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap();
server_subscribe_handler(transport, name, sm_addr).await; server_subscribe_handler(&ctx, transport, name).await;
} }
OpenDataStream(conn_id) => { OpenDataStream(conn_id) => {
info!("OpenDataStream with {:?}", conn_id); info!("OpenDataStream with {:?}", conn_id);
let msg = RegisterStream::server(conn_id, transport); let msg = RegisterStream::server(conn_id, transport);
pdcm_addr.send(msg).await.unwrap().unwrap(); ctx.pdcm_addr.send(msg).await.unwrap().unwrap();
} }
} }
} }
@@ -166,17 +150,16 @@ async fn main() {
} }
} }
info!("Server Task terminated!"); info!("Server Task terminated!");
} else if peer_root_cert_der == *guestserver_root_cert_der { } else if peer_root_cert_der == ctx.guestserver_root_cert_der {
info!("GuestServer connection"); info!("GuestServer connection");
let server_root_cert = Arc::clone(&server_root_cert);
let codec = LengthDelimitedCodec::new(); let codec = LengthDelimitedCodec::new();
let transport = Framed::new(stream, codec); let transport = Framed::new(stream, codec);
guestserver_handler(transport, scdb_addr, &server_root_cert).await; guestserver_handler(&ctx, transport, &server_root_cert).await;
} else if peer_root_cert_der == *client_root_cert_der { } else if peer_root_cert_der == ctx.client_root_cert_der {
info!("Client connection"); info!("Client connection");
let codec = LengthDelimitedCodec::new(); let codec = LengthDelimitedCodec::new();
let transport = Framed::new(stream, codec); let transport = Framed::new(stream, codec);
client_handler(transport, sm_addr, pdcm_addr).await; client_handler(&ctx, transport).await;
} else { } else {
error!("Unknown Root Certificate"); error!("Unknown Root Certificate");
} }
@@ -190,8 +173,8 @@ async fn main() {
.unwrap(); .unwrap();
} }
async fn server_subscribe_handler(transport: TransportStream, name: String, sm_addr: Addr<ServerManager>) { async fn server_subscribe_handler(ctx: &BrokerContext, transport: TransportStream, name: String) {
match sm_addr.send(StartTransporter { name, transport }).await.unwrap() { match ctx.sm_addr.send(StartTransporter { name, transport }).await.unwrap() {
Ok(_) => { Ok(_) => {
info!("Stream sent to the manager"); info!("Stream sent to the manager");
} }
@@ -201,7 +184,7 @@ async fn server_subscribe_handler(transport: TransportStream, name: String, sm_a
} }
} }
async fn server_command_handler(mut transport: TransportStream, peer_cert_bytes: Vec<u8>, server_db_addr: Addr<ServerCertDB>) { async fn server_command_handler(ctx: &BrokerContext, mut transport: TransportStream, peer_cert_bytes: Vec<u8>) {
loop { loop {
match transport.next().await { match transport.next().await {
None => { None => {
@@ -216,7 +199,7 @@ async fn server_command_handler(mut transport: TransportStream, peer_cert_bytes:
match msg { match msg {
ChangeName { name } => { ChangeName { name } => {
info!("Changing name to {}", name); info!("Changing name to {}", name);
match server_db_addr.send(UnregisterServer { cert: peer_cert_bytes.clone() }).await.unwrap() { match ctx.scdb_addr.send(UnregisterServer { cert: peer_cert_bytes.clone() }).await.unwrap() {
None => { None => {
info!("Nothing to unregister"); info!("Nothing to unregister");
} }
@@ -224,7 +207,7 @@ async fn server_command_handler(mut transport: TransportStream, peer_cert_bytes:
warn!("Unregistered from old name {}", old_name); warn!("Unregistered from old name {}", old_name);
} }
} }
let reply = match server_db_addr.send(RegisterServer { cert: peer_cert_bytes.clone(), name }).await.unwrap() { let reply = match ctx.scdb_addr.send(RegisterServer { cert: peer_cert_bytes.clone(), name }).await.unwrap() {
Ok(_) => { Ok(_) => {
info!("Registered!"); info!("Registered!");
ToServerCommandReply::NameChanged ToServerCommandReply::NameChanged
@@ -238,7 +221,7 @@ async fn server_command_handler(mut transport: TransportStream, peer_cert_bytes:
} }
WhoAmI => { WhoAmI => {
info!("Asked who I am"); info!("Asked who I am");
let reply = match server_db_addr.send(FetchName { cert: peer_cert_bytes.clone() }).await.unwrap() { let reply = match ctx.scdb_addr.send(FetchName { cert: peer_cert_bytes.clone() }).await.unwrap() {
None => { None => {
info!("I'm not registered anymore!? WTF"); info!("I'm not registered anymore!? WTF");
ToServerCommandReply::GenericFailure ToServerCommandReply::GenericFailure
@@ -261,8 +244,7 @@ async fn server_command_handler(mut transport: TransportStream, peer_cert_bytes:
} }
} }
// TODO: Considera creare un context dove vengono contenute tutte le chiavi e gli address da dare a tutti gli handler async fn guestserver_handler(ctx: &BrokerContext, mut transport: TransportStream, server_root_cert: &Certificate) {
async fn guestserver_handler(mut transport: TransportStream, server_db_addr: Addr<ServerCertDB>, server_root_cert: &Arc<Certificate>) {
loop { loop {
match transport.next().await { match transport.next().await {
None => { None => {
@@ -278,14 +260,14 @@ async fn guestserver_handler(mut transport: TransportStream, server_db_addr: Add
match msg { match msg {
Announce { name } => { Announce { name } => {
info!("Announced with name {}", name); info!("Announced with name {}", name);
if server_db_addr.send(IsNameRegistered { name: name.clone() }).await.unwrap() { if ctx.scdb_addr.send(IsNameRegistered { name: name.clone() }).await.unwrap() {
info!("Name {} already registered!", name); info!("Name {} already registered!", name);
let reply = ToGuestServerMessage::FailedNameAlreadyOccupied; let reply = ToGuestServerMessage::FailedNameAlreadyOccupied;
transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap(); transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap();
break; // Stop reading break; // Stop reading
} else { } else {
let cert = generate_server_cert(server_root_cert, name.as_str()); let cert = certutils::generate_server_cert(server_root_cert, name.as_str());
server_db_addr.send(RegisterServer { ctx.scdb_addr.send(RegisterServer {
cert: cert.cert.clone(), cert: cert.cert.clone(),
name, name,
}).await.unwrap().unwrap(); }).await.unwrap().unwrap();
@@ -308,7 +290,7 @@ async fn guestserver_handler(mut transport: TransportStream, server_db_addr: Add
} }
} }
async fn client_handler(mut transport: TransportStream, sm_addr: Addr<ServerManager>, pdcm_addr: Addr<PendingDataConnManager>) { async fn client_handler(ctx: &BrokerContext, mut transport: TransportStream) {
loop { loop {
match transport.next().await { match transport.next().await {
None => { None => {
@@ -323,7 +305,7 @@ async fn client_handler(mut transport: TransportStream, sm_addr: Addr<ServerMana
match msg { match msg {
FromClientCommand::RequestServer { name } => { FromClientCommand::RequestServer { name } => {
info!("REQUESTED SERVER {}", name); info!("REQUESTED SERVER {}", name);
let data = sm_addr.send(RequestServer { name }).await.unwrap(); let data = ctx.sm_addr.send(RequestServer { name }).await.unwrap();
match data { match data {
Ok(client_conn_id) => { Ok(client_conn_id) => {
let reply = ToClientResponse::OkRequest { conn_id: client_conn_id }; let reply = ToClientResponse::OkRequest { conn_id: client_conn_id };
@@ -337,14 +319,14 @@ async fn client_handler(mut transport: TransportStream, sm_addr: Addr<ServerMana
} }
FromClientCommand::ServerList => { FromClientCommand::ServerList => {
info!("Requested ServerList"); info!("Requested ServerList");
let data = sm_addr.send(GetServerList {}).await.unwrap(); let data = ctx.sm_addr.send(GetServerList {}).await.unwrap();
let reply = ToClientResponse::OkServerList { data }; let reply = ToClientResponse::OkServerList { data };
transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap(); transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap();
} }
FromClientCommand::UpgradeToDataStream(conn_id) => { FromClientCommand::UpgradeToDataStream(conn_id) => {
info!("Upgrade to DataStream with conn_id {:?}", conn_id); info!("Upgrade to DataStream with conn_id {:?}", conn_id);
let msg = RegisterStream::client(conn_id, transport); let msg = RegisterStream::client(conn_id, transport);
pdcm_addr.send(msg).await.unwrap().unwrap(); ctx.pdcm_addr.send(msg).await.unwrap().unwrap();
break; break;
} }
} }

View File

@@ -1,15 +1,11 @@
use actix::prelude::*; use actix::prelude::*;
use actix_tls::accept::rustls_0_22::TlsStream;
use futures::SinkExt; use futures::SinkExt;
use thiserror::Error; use thiserror::Error;
use tokio::net::TcpStream;
use tokio_util::codec::{Framed, LengthDelimitedCodec};
use tracing::{error, info}; use tracing::{error, info};
use uuid::Uuid; use uuid::Uuid;
use libbonknet::*; use libbonknet::*;
use crate::dataconnmanager::{DataConnManager, StartDataBridge}; use crate::dataconnmanager::{DataConnManager, StartDataBridge};
use crate::streamutils::*;
type TransportStream = Framed<TlsStream<TcpStream>, LengthDelimitedCodec>;
/* /*
L'idea e' che il database deve avere una riga per ogni connessione dati in nascita. L'idea e' che il database deve avere una riga per ogni connessione dati in nascita.

View File

@@ -10,7 +10,7 @@ use tokio::sync::{Mutex, oneshot};
use tokio_util::bytes::{Bytes, BytesMut}; use tokio_util::bytes::{Bytes, BytesMut};
use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec}; use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
use tracing::{debug, error, info}; use tracing::{debug, error, info};
use crate::{TransportStream, TransportStreamRx, TransportStreamTx}; use crate::streamutils::*;
use uuid::Uuid; use uuid::Uuid;
use libbonknet::servermsg::*; use libbonknet::servermsg::*;
use crate::pendingdataconndb::*; use crate::pendingdataconndb::*;

View File

@@ -0,0 +1,8 @@
use actix_tls::accept::rustls_0_22::TlsStream;
use tokio::io::{ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use tokio_util::codec::{Framed, FramedRead, FramedWrite, LengthDelimitedCodec};
pub type TransportStream = Framed<TlsStream<TcpStream>, LengthDelimitedCodec>;
pub type TransportStreamTx = FramedWrite<WriteHalf<TlsStream<TcpStream>>, LengthDelimitedCodec>;
pub type TransportStreamRx = FramedRead<ReadHalf<TlsStream<TcpStream>>, LengthDelimitedCodec>;