Refactor TransportStream in Server

This commit is contained in:
2024-03-22 22:05:33 +01:00
parent a1b4865b3f
commit 1e4e4bdb53
9 changed files with 331 additions and 35 deletions

View File

@@ -98,17 +98,6 @@ async fn main() -> std::io::Result<()> {
// Load Identity files
let guestserver_ident = LeafCertPair::load_from_file("certs_pem/guestserver.pem").unwrap();
let broker_root = BrokerRootCerts::load_from_file("certs_pem/broker_root_ca_cert.pem").unwrap();
// // Root certs to verify the server is the right one
// let mut broker_root_cert_store = RootCertStore::empty();
// let broker_root_cert_der = load_cert("certs/broker_root_cert.pem").unwrap();
// broker_root_cert_store.add(broker_root_cert_der).unwrap();
// // Public CA that will be used to generate the Server certificate
// let root_server_cert = load_cert("certs/server_root_cert.pem").unwrap();
// // Guest CA
// let root_guestserver_cert = load_cert("certs/guestserver_root_cert.pem").unwrap();
// // Certificate used to do the first authentication
// let guestserver_cert = load_cert("certs/guestserver_cert.pem").unwrap();
// let guestserver_prkey = load_prkey("certs/guestserver_key.pem").unwrap();
// Load TLS Config
let guest_cert_chain = guestserver_ident.fullchain();
let tlsconfig = ClientConfig::builder()

View File

@@ -1,29 +1,227 @@
use std::io::{Error, ErrorKind};
mod transportstream;
use std::path::Path;
use std::sync::Arc;
use futures::{SinkExt, StreamExt};
use tokio::net::TcpStream;
use tokio_rustls::rustls::{ClientConfig, RootCertStore};
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer, ServerName};
use tokio_rustls::TlsConnector;
use tokio_util::codec::{Framed, LengthDelimitedCodec};
use tokio_rustls::rustls::{ClientConfig};
use libbonknet::*;
use libbonknet::servermsg::*;
use libbonknet::cert::*;
use uuid::Uuid;
use tracing::{error, info};
use libbonknet::cert::{BrokerRootCerts, LeafCertPair};
use crate::transportstream::*;
#[derive(Clone)]
struct ServerContext<'a> {
identity: LeafCertPair<'a>,
broker_root: BrokerRootCerts<'a>,
my_name: String,
}
impl ServerContext<'_> {
fn tlsconfig(&self) -> ClientConfig {
self.identity.to_tlsclientconfig(&self.broker_root)
}
}
async fn subscribe(ctx: &ServerContext<'_>) -> Result<(), TransportError> {
use ToServerConnTypeReply::*;
let tlsconfig = Arc::new(ctx.tlsconfig());
let mut transport = TransportStream::new(Arc::clone(&tlsconfig)).await?;
let msg = FromServerConnTypeMessage::Subscribe;
match transport.send_and_listen::<_, ToServerConnTypeReply>(&msg).await? {
OkSubscribe => {
info!("Stream set in Subscribe mode");
}
GenericFailure => {
panic!("Generic Failure during SendCommand");
}
others => {
panic!("Unexpected Message type: {:?}", others);
}
}
loop {
use ToServerMessage::*;
let out: Option<FromServerReply>;
match transport.listen_one::<ToServerMessage>().await? {
Msg { reply_id, body } => {
use libbonknet::servermsg::ToServerMessageBody::*;
match body {
Request { conn_id } => {
info!("I'm required with Connection ID {}", conn_id);
out = Some(FromServerReply::Msg {
reply_id,
body: FromServerReplyBody::RequestAccepted,
});
// TODO: Spawn Datastream
tokio::spawn(datastream(ctx.tlsconfig(), conn_id));
}
}
}
Ping => {
info!("Ping!");
out = Some(FromServerReply::Pong);
}
}
if let Some(msg) = out {
transport.send(&msg).await?;
}
}
}
async fn datastream(tlsconfig: ClientConfig, conn_id: Uuid) -> Result<(), TransportError> {
use TransportError::StreamError;
use ToPeerDataStream::*;
let mut transport = TransportStream::new(Arc::new(tlsconfig)).await?;
let msg = FromServerConnTypeMessage::OpenDataStream(conn_id);
match transport.send_and_listen::<_, ToPeerDataStream>(&msg).await? {
OkDataStreamRequestAccepted => {
info!("Data Stream Accepted. Waiting for Open...");
}
Refused => {
panic!("Refused");
}
other => {
panic!("Unexpected response: {:?}", other);
}
}
match transport.listen_one().await? {
OkDataStreamOpen => {
info!("Data Stream Open!. Connecting Streams.");
}
Revoked => {
panic!("Data Stream Revoked!");
}
Refused => {
panic!("Refused");
}
other => {
panic!("Unexpected response: {:?}", other);
}
}
// Initialize outbound stream
let mut inbound = transport.into_inner().into_inner();
let mut outbound = TcpStream::connect("127.0.0.1:22").await.map_err(StreamError)?;
match tokio::io::copy_bidirectional(&mut inbound, &mut outbound).await {
Ok(bytes_copied) => info!("{bytes_copied:?}"),
Err(e) => error!("Error during copy: {e}"),
}
Ok(())
}
async fn announce<'a>(ctx: &ServerContext<'_>) -> Result<LeafCertPair<'a>, TransportError> {
use ToGuestServerMessage::*;
let mut transport = TransportStream::new(Arc::new(ctx.tlsconfig())).await.unwrap();
let msg = FromGuestServerMessage::Announce { name: ctx.my_name.clone() };
transport.send(&msg).await?;
for i in 0..10 {
match transport.listen_one().await? {
OkAnnounce(payload) => {
info!("Ok Announce");
transport.close().await?;
return Ok(payload.parse());
}
FailedNameAlreadyOccupied => {
let new_name = format!("ERROR_{}_{}", &ctx.my_name, i + 1);
error!("Failed Announce, name already occupied. Using {}", &new_name);
let msg = FromGuestServerMessage::Announce { name: new_name };
transport.send(&msg).await?;
}
}
}
panic!("Out of retry");
}
async fn server_name_confirmation<'a>(ctx: &ServerContext<'_>) -> Result<(), TransportError> {
use ToServerConnTypeReply::*;
use ToServerCommandReply::*;
let mut transport = TransportStream::new(Arc::new(ctx.tlsconfig())).await?;
// Declare Conn Type
let msg = FromServerConnTypeMessage::SendCommand;
match transport.send_and_listen(&msg).await? {
OkSendCommand => {}
e => {
panic!("Error during ConnType Declare: {:?}", e);
}
}
// Ask Name
let msg = FromServerCommandMessage::WhoAmI;
match transport.send_and_listen(&msg).await? {
YouAre { name } => {
if ctx.my_name == name {
return Ok(());
}
}
other => {
panic!("Unexpected response: {:?}", other);
}
}
// If name doesn't correspond, try to ChangeName. 10 retry. If they fail, keep the actual one without panic.
let msg = FromServerCommandMessage::ChangeName { name: ctx.my_name.clone() };
transport.send(&msg).await?;
for i in 0..10 {
match transport.listen_one().await? {
NameChanged => {
return Ok(());
}
NameNotAvailable => {
let msg = FromServerCommandMessage::ChangeName { name: format!("ERROR_{}_{}", ctx.my_name, i + 1) };
transport.send(&msg).await?;
}
other => {
panic!("Unexpected response: {:?}", other);
}
}
}
panic!("Exhausted Announce Retry");
}
#[tokio::main]
async fn main() -> std::io::Result<()> {
// Tracing Subscriber
let subscriber = tracing_subscriber::FmtSubscriber::new();
tracing::subscriber::set_global_default(subscriber).unwrap();
// Server Name
// TODO: from config
let my_name = "cicciopizza";
// Load Identity files
let guestserver_ident = LeafCertPair::load_from_file("certs_pem/guestserver.pem").unwrap();
// TODO: from config using std
let my_name = String::from("cicciopizza");
let serverident_path = Path::new("server/serveridentity.pem"); // "/etc/bonknet/identity.pem"
let guestserverident_path = Path::new("server/guestidentity.pem"); // "/etc/bonknet/guestidentity.pem"
let broker_root_path = Path::new("certs_pem/broker_root_ca_cert.pem"); // "/etc/bonknet/broker_root_ca_cert.pem"
// Load Broker Root file
if !(broker_root_path.try_exists().unwrap() && broker_root_path.is_file()) {
panic!("No Broker Root file");
}
let broker_root = BrokerRootCerts::load_from_file("certs_pem/broker_root_ca_cert.pem").unwrap();
// TODO: ACTOR MODEL per gestione transport in maniera intelligente?
Ok(())
// Load Identity Files (if needed, contact the broker for generation)
let exists_serverident = serverident_path.try_exists().unwrap() && serverident_path.is_file();
let exists_guestident = guestserverident_path.try_exists().unwrap() && guestserverident_path.is_file();
// Do Guest registration and Name confirmation
let ctx = if !exists_serverident && exists_guestident {
info!("No Server Identity. Starting Guest Announce...");
// No Server, Yes Guest -> Use Guest to retrieve Server identity
let guest_ident = LeafCertPair::load_from_file(guestserverident_path).unwrap();
let ctx = ServerContext { identity: guest_ident, broker_root: broker_root.clone(), my_name: my_name.clone() };
let server_ident = announce(&ctx).await.unwrap();
server_ident.save_into_file(serverident_path).unwrap();
ServerContext { identity: server_ident, broker_root: broker_root.clone(), my_name: my_name.clone() }
} else if exists_serverident {
// Yes Server -> Use Server file as identity
let server_ident = LeafCertPair::load_from_file(serverident_path).unwrap();
let ctx = ServerContext { identity: server_ident, broker_root: broker_root.clone(), my_name: my_name.clone() };
server_name_confirmation(&ctx).await.unwrap();
ctx
} else {
// No identity file present
panic!("No Identity file found");
};
// Start Server Main
let ctx = Arc::new(ctx);
loop {
if let Err(e) = subscribe(&ctx).await {
error!("Subscribe Task aborted due to {}", e);
error!("Restoring Subscribe Task...");
}
}
}

View File

@@ -0,0 +1,80 @@
use std::sync::Arc;
use futures::{SinkExt, StreamExt};
use tokio::net::TcpStream;
use tokio_rustls::client::TlsStream;
use tokio_rustls::rustls::pki_types::ServerName;
use tokio_rustls::TlsConnector;
use tokio_util::codec::{Framed, LengthDelimitedCodec};
use serde::{Serialize};
use serde::de::DeserializeOwned;
use thiserror::Error;
use tokio_rustls::rustls::ClientConfig;
#[derive(Error, Debug)]
pub enum TransportError {
#[error("Stream Terminated, next() returned None")]
StreamTerminated,
#[error("Stream Error")]
StreamError(std::io::Error),
#[error("Serialization Error")]
SerializeError(rmp_serde::encode::Error),
#[error("Deserialization Error")]
DeserializeError(rmp_serde::decode::Error),
}
pub struct TransportStream {
transport: Framed<TlsStream<TcpStream>, LengthDelimitedCodec>,
}
impl TransportStream {
pub async fn new(tlsconfig: Arc<ClientConfig>) -> Result<Self, TransportError> {
use TransportError::StreamError;
let connector = TlsConnector::from(tlsconfig);
let dnsname = ServerName::try_from("localhost").unwrap();
let stream = TcpStream::connect("localhost:2541").await.map_err(StreamError)?;
let stream = connector.connect(dnsname, stream).await.map_err(StreamError)?;
let transport = Framed::new(stream, LengthDelimitedCodec::new());
Ok(TransportStream { transport })
}
pub fn into_inner(self) -> Framed<TlsStream<TcpStream>, LengthDelimitedCodec> {
self.transport
}
pub async fn send<T: Serialize>(&mut self, msg: &T) -> Result<(), TransportError> {
use TransportError::*;
self.transport.send(rmp_serde::to_vec(&msg).map_err(SerializeError)?.into()).await.map_err(StreamError)?;
Ok(())
}
pub async fn send_and_listen<T: Serialize, U: DeserializeOwned>(&mut self, msg: &T) -> Result<U, TransportError> {
self.send(msg).await?;
self.listen_one().await
}
pub async fn listen_one<T: DeserializeOwned>(&mut self) -> Result<T, TransportError> {
use TransportError::*;
match self.transport.next().await {
None => {
// Stream Terminated
Err(StreamTerminated)
}
Some(item) => match item {
Ok(buf) => {
let msg: T = rmp_serde::from_slice(&buf).map_err(DeserializeError)?;
Ok(msg)
}
Err(e) => {
// Disconnection
Err(StreamError(e))
}
}
}
}
pub async fn close(mut self) -> Result<(), TransportError> {
self.transport.close().await.map_err(TransportError::StreamError)
}
}