Compare commits

...

8 Commits

15 changed files with 1490 additions and 1654 deletions

1365
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -13,14 +13,18 @@ actix-rt = "2.9.0"
actix-server = "2.3.0" actix-server = "2.3.0"
actix-service = "2.0.2" actix-service = "2.0.2"
actix-tls = { version = "3.3.0", features = ["rustls-0_22"] } actix-tls = { version = "3.3.0", features = ["rustls-0_22"] }
tokio = { version = "1", features = ["io-util", "sync", "time", "macros"] }
rustls = "0.22.2" rustls = "0.22.2"
tracing = "0.1" tracing = "0.1"
tracing-subscriber = "0.3" tracing-subscriber = "0.3"
futures = "0.3" futures = "0.3"
thiserror = "1.0.56" thiserror = "1.0.56"
tokio-util = { version = "0.7.10", features = ["codec"] } tokio-util = { version = "0.7.10", features = ["codec"] }
serde = "1"
rmp-serde = "1.1.2" rmp-serde = "1.1.2"
rcgen = { version = "0.12.1", features = ["x509-parser"] } rcgen = { version = "0.12.1", features = ["x509-parser"] }
rand = "0.8.5"
uuid = { version = "1.7.0", features = ["v4", "serde"] }
[[bin]] [[bin]]
name = "init_certs" name = "init_certs"

View File

@@ -0,0 +1,104 @@
use actix::prelude::*;
use uuid::Uuid;
use std::collections::HashMap;
use futures::SinkExt;
use thiserror::Error;
use tracing::{info, error, warn};
use libbonknet::ToPeerDataStream;
use crate::TransportStream;
// Errors produced by DataConnManager handlers.
#[allow(dead_code)]
#[derive(Error, Debug)]
pub enum DataConnManagerError {
    /// Catch-all failure; no finer-grained variants exist yet.
    #[error("Generic Failure")]
    GenericFailure,
}
/// Ask the DataConnManager to start piping bytes between a client transport
/// and a server transport, keyed by the client-side connection id.
#[derive(Message)]
#[rtype(result = "Result<(),DataConnManagerError>")]
pub struct StartDataBridge {
    pub client_conn_id: Uuid,
    pub server_transport: TransportStream,
    pub client_transport: TransportStream,
}

/// Ask the DataConnManager to tear down the bridge for `client_conn_id`.
#[derive(Message)]
#[rtype(result = "()")]
pub struct StopDataBridge {
    pub client_conn_id: Uuid,
}
/// Bridges are keyed by the client-side connection id.
type ClientConnId = Uuid;

/// A running bridge: the handle of the spawned proxy future, kept so the
/// bridge can be cancelled later.
struct Connection {
    proxyhandler: SpawnHandle,
}

/// Actor owning every active data bridge and the future driving each one.
pub struct DataConnManager {
    conns: HashMap<ClientConnId, Connection>
}

impl DataConnManager {
    /// Manager with no active bridges.
    pub fn new() -> Self {
        Self { conns: HashMap::new() }
    }
}

impl Actor for DataConnManager {
    // Plain single-threaded actix context.
    type Context = Context<Self>;
}
impl Handler<StopDataBridge> for DataConnManager {
    type Result = ();

    /// Drops the bridge registered under `client_conn_id`, cancelling its
    /// proxy task when it is still alive.
    fn handle(&mut self, msg: StopDataBridge, ctx: &mut Self::Context) -> Self::Result {
        let conn_id = msg.client_conn_id;
        if let Some(conn) = self.conns.remove(&conn_id) {
            // cancel_future() reports false when the task already completed by itself.
            let was_running = ctx.cancel_future(conn.proxyhandler);
            if was_running {
                info!("Stopped Data Bridge {}", conn_id);
            } else {
                info!("Stopped Data Bridge {} was with dead task", conn_id);
            }
        } else {
            warn!("Stopped Data Bridge {} was not in memory", conn_id);
        }
    }
}
impl Handler<StartDataBridge> for DataConnManager {
    type Result = Result<(), DataConnManagerError>;

    /// Spawns the proxy task for a new bridge: first notifies both peers with
    /// OkDataStreamOpen, then splices the raw TLS streams together until EOF
    /// or error, and finally notifies itself with StopDataBridge to clean up.
    fn handle(&mut self, mut msg: StartDataBridge, ctx: &mut Self::Context) -> Self::Result {
        let client_conn_id = msg.client_conn_id;
        let handler = ctx.spawn(async move {
            // Send to the streams the OK DATA OPEN message
            let okmsg = ToPeerDataStream::OkDataStreamOpen;
            if let Err(e) = tokio::try_join!(
                msg.client_transport.send(rmp_serde::to_vec(&okmsg).unwrap().into()),
                msg.server_transport.send(rmp_serde::to_vec(&okmsg).unwrap().into()),
            ) {
                error!("Error during OkDataStreamOpen send: {:?}", e);
                // TODO: I might want to turn this handler into a ResponseActFuture so
                // that a failure on this send can be reported directly to the
                // PendingDataConnDb without manual handling. Worth studying, since the
                // Pending side does not necessarily need to know this phase failed.
            } else {
                // Drop the length-delimited framing and pipe raw bytes both ways.
                let mut client_stream = msg.client_transport.into_inner();
                let mut server_stream = msg.server_transport.into_inner();
                match tokio::io::copy_bidirectional(&mut client_stream, &mut server_stream).await {
                    Ok((to_server, to_client)) => info!("DataConn closed with {to_server}B to server and {to_client}B to client"),
                    Err(e) => error!("Error during DataConn: {e:?}"),
                }
            }
            // Yield the id so the map() below can schedule the cleanup message.
            msg.client_conn_id
        }.into_actor(self).map(|res, _a, c| {
            c.notify(StopDataBridge { client_conn_id: res });
        }));
        if let Some(other_conn) = self.conns.insert(client_conn_id, Connection { proxyhandler: handler }) {
            // A bridge with the same id already existed: cancel the old task.
            ctx.cancel_future(other_conn.proxyhandler);
            warn!("During init of Conn {client_conn_id} another connection has been found and is now closed.")
        }
        Ok(())
    }
}

View File

@@ -1,18 +1,32 @@
mod servercertdb;
mod pendingdataconndb;
mod servermanager;
mod dataconnmanager;
use servercertdb::*;
use pendingdataconndb::*;
use servermanager::*;
use dataconnmanager::*;
use actix::prelude::*; use actix::prelude::*;
use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use libbonknet::{load_cert, load_prkey, FromServerMessage, RequiredReplyValues, FromGuestServerMessage, ToGuestServerMessage}; use libbonknet::*;
use libbonknet::servermsg::*;
use libbonknet::clientmsg::*;
use rustls::{RootCertStore, ServerConfig}; use rustls::{RootCertStore, ServerConfig};
use rustls::server::WebPkiClientVerifier; use rustls::server::WebPkiClientVerifier;
use actix_tls::accept::rustls_0_22::{Acceptor as RustlsAcceptor, TlsStream}; use actix_tls::accept::rustls_0_22::{Acceptor as RustlsAcceptor, TlsStream};
use actix_server::Server; use actix_server::Server;
use actix_rt::net::TcpStream; use actix_rt::net::TcpStream;
use actix_service::{ServiceFactoryExt as _}; use actix_service::ServiceFactoryExt as _;
use futures::{StreamExt, SinkExt}; use futures::{SinkExt, StreamExt};
use thiserror::Error; use tokio_util::codec::{Framed, FramedRead, FramedWrite, LengthDelimitedCodec};
use tokio_util::codec::{Framed, LengthDelimitedCodec}; use tracing::{error, info, warn};
use tracing::{info, error};
use rcgen::{Certificate, CertificateParams, DnType, KeyPair}; use rcgen::{Certificate, CertificateParams, DnType, KeyPair};
use tokio::io::{ReadHalf, WriteHalf};
type TransportStream = Framed<TlsStream<TcpStream>, LengthDelimitedCodec>;
type TransportStreamTx = FramedWrite<WriteHalf<TlsStream<TcpStream>>, LengthDelimitedCodec>;
type TransportStreamRx = FramedRead<ReadHalf<TlsStream<TcpStream>>, LengthDelimitedCodec>;
struct ServerCert { struct ServerCert {
cert: Vec<u8>, cert: Vec<u8>,
@@ -21,7 +35,7 @@ struct ServerCert {
fn generate_server_cert(root_cert: &Certificate, name: &str) -> ServerCert { fn generate_server_cert(root_cert: &Certificate, name: &str) -> ServerCert {
let mut params = CertificateParams::new(vec!["entity.other.host".into(), format!("bonk.server.{name}")]); let mut params = CertificateParams::new(vec!["entity.other.host".into(), format!("bonk.server.{name}")]);
params.distinguished_name.push(DnType::CommonName, format!("{name}")); params.distinguished_name.push(DnType::CommonName, name);
params.use_authority_key_identifier_extension = true; params.use_authority_key_identifier_extension = true;
params.key_usages.push(rcgen::KeyUsagePurpose::DigitalSignature); params.key_usages.push(rcgen::KeyUsagePurpose::DigitalSignature);
params params
@@ -33,86 +47,6 @@ fn generate_server_cert(root_cert: &Certificate, name: &str) -> ServerCert {
ServerCert { cert, prkey } ServerCert { cert, prkey }
} }
#[derive(Error, Debug)]
enum DBError {
#[error("Certificate is already registered with name {0}")]
CertAlreadyRegistered(String),
// #[error("Generic Failure")]
// GenericFailure,
}
#[derive(Message)]
#[rtype(result = "bool")]
struct IsNameRegistered {
name: String,
}
#[derive(Message)]
#[rtype(result = "Result<(), DBError>")]
struct RegisterServer {
cert: Vec<u8>,
name: String,
}
// TODO: Move into Sqlite DB with unique check on col1 and col2!!!! Right now name is not unique
struct ServerCertDB {
db: HashMap<Vec<u8>, String>, // Cert to Name
}
impl Actor for ServerCertDB {
type Context = Context<Self>;
}
impl Handler<RegisterServer> for ServerCertDB {
type Result = Result<(), DBError>;
fn handle(&mut self, msg: RegisterServer, _ctx: &mut Self::Context) -> Self::Result {
match self.db.get(&msg.cert) {
None => {
self.db.insert(msg.cert, msg.name);
Ok(())
}
Some(name) => {
Err(DBError::CertAlreadyRegistered(name.clone()))
}
}
}
}
impl Handler<IsNameRegistered> for ServerCertDB {
type Result = bool;
fn handle(&mut self, msg: IsNameRegistered, _ctx: &mut Self::Context) -> Self::Result {
self.db.values().any(|x| *x == msg.name)
}
}
struct GuestServerConnection {
stream: TlsStream<TcpStream>,
}
impl Actor for GuestServerConnection {
type Context = Context<Self>;
}
struct ServerConnection<T: 'static> {
stream: Framed<TlsStream<TcpStream>, T>,
name: String
}
impl<T> ServerConnection<T> {
fn new(stream: TlsStream<TcpStream>, codec: T) -> Self {
let stream = Framed::new(stream, codec);
ServerConnection {
stream,
name: "Polnareffland1".into(),
}
}
}
impl<T> Actor for ServerConnection<T> {
type Context = Context<Self>;
}
#[actix_rt::main] #[actix_rt::main]
async fn main() { async fn main() {
@@ -149,13 +83,14 @@ async fn main() {
let client_root_cert_der = Arc::new(client_root_cert_der); let client_root_cert_der = Arc::new(client_root_cert_der);
let guestserver_root_cert_der = Arc::new(guestserver_root_cert_der); let guestserver_root_cert_der = Arc::new(guestserver_root_cert_der);
let server_root_cert = Arc::new(Certificate::from_params(CertificateParams::from_ca_cert_der( let server_root_cert = Arc::new(Certificate::from_params(CertificateParams::from_ca_cert_der(
&*server_root_cert_der, &server_root_cert_der,
server_root_prkey server_root_prkey
).unwrap()).unwrap()); ).unwrap()).unwrap());
let server_db_addr = ServerCertDB { let scdb_addr = ServerCertDB::new().start();
db: HashMap::new(), let dcm_addr = DataConnManager::new().start();
}.start(); let pdcm_addr = PendingDataConnManager::new(dcm_addr.clone()).start();
let sm_addr = ServerManager::new(pdcm_addr.clone()).start();
Server::build() Server::build()
.bind("server-command", ("localhost", 2541), move || { .bind("server-command", ("localhost", 2541), move || {
@@ -163,7 +98,9 @@ async fn main() {
let client_root_cert_der = Arc::clone(&client_root_cert_der); let client_root_cert_der = Arc::clone(&client_root_cert_der);
let guestserver_root_cert_der = Arc::clone(&guestserver_root_cert_der); let guestserver_root_cert_der = Arc::clone(&guestserver_root_cert_der);
let server_root_cert = Arc::clone(&server_root_cert); let server_root_cert = Arc::clone(&server_root_cert);
let server_db_addr = server_db_addr.clone(); let scdb_addr = scdb_addr.clone();
let pdcm_addr = pdcm_addr.clone();
let sm_addr = sm_addr.clone();
// Set up TLS service factory // Set up TLS service factory
server_acceptor server_acceptor
@@ -174,90 +111,72 @@ async fn main() {
let client_root_cert_der = Arc::clone(&client_root_cert_der); let client_root_cert_der = Arc::clone(&client_root_cert_der);
let guestserver_root_cert_der = Arc::clone(&guestserver_root_cert_der); let guestserver_root_cert_der = Arc::clone(&guestserver_root_cert_der);
let server_root_cert = Arc::clone(&server_root_cert); let server_root_cert = Arc::clone(&server_root_cert);
let server_db_addr = server_db_addr.clone(); let scdb_addr = scdb_addr.clone();
let pdcm_addr = pdcm_addr.clone();
let sm_addr = sm_addr.clone();
async move { async move {
let peer_cert_der = stream.get_ref().1.peer_certificates().unwrap().last().unwrap().clone(); let peer_certs = stream.get_ref().1.peer_certificates().unwrap();
if peer_cert_der == *server_root_cert_der { let peer_cert_bytes = peer_certs.first().unwrap().to_vec();
let peer_root_cert_der = peer_certs.last().unwrap().clone();
if peer_root_cert_der == *server_root_cert_der {
info!("Server connection"); info!("Server connection");
let framed = Framed::new(stream, LengthDelimitedCodec::new()); let mut transport = Framed::new(stream, LengthDelimitedCodec::new());
framed.for_each(|item| async move { match transport.next().await {
match item { None => {
info!("Connection closed by peer");
}
Some(item) => match item {
Ok(buf) => { Ok(buf) => {
use FromServerMessage::*; use libbonknet::servermsg::{FromServerConnTypeMessage, ToServerConnTypeReply};
let msg: FromServerMessage = rmp_serde::from_slice(&buf).unwrap(); use libbonknet::servermsg::FromServerConnTypeMessage::*;
let msg: FromServerConnTypeMessage = rmp_serde::from_slice(&buf).unwrap();
info!("{:?}", msg); info!("{:?}", msg);
match msg { match msg {
RequiredReply(v) => match v { SendCommand => {
RequiredReplyValues::Ok => { info!("SendCommand Stream");
info!("Required Reply OK") let reply = ToServerConnTypeReply::OkSendCommand;
} transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap();
RequiredReplyValues::GenericFailure { .. } => { server_command_handler(transport, peer_cert_bytes, scdb_addr).await;
info!("Required Reply Generic Failure")
}
} }
ChangeName { name } => { Subscribe => {
info!("Requested Change Name to Name {}", name); info!("Subscribe Stream");
let name = match scdb_addr.send(FetchName { cert: peer_cert_bytes }).await.unwrap() {
None => {
error!("Cert has no name assigned!");
let reply = ToServerConnTypeReply::GenericFailure;
transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap();
return Ok(());
}
Some(name) => name,
};
let reply = ToServerConnTypeReply::OkSubscribe;
transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap();
server_subscribe_handler(transport, name, sm_addr).await;
} }
WhoAmI => { OpenDataStream(conn_id) => {
info!("Requested WhoAmI"); info!("OpenDataStream with {:?}", conn_id);
let msg = RegisterStream::server(conn_id, transport);
pdcm_addr.send(msg).await.unwrap().unwrap();
} }
} }
}, }
Err(e) => { Err(e) => {
info!("Disconnection: {:?}", e); info!("Disconnection: {:?}", e);
} }
} }
}).await; }
info!("Disconnection!"); info!("Server Task terminated!");
} else if peer_cert_der == *guestserver_root_cert_der { } else if peer_root_cert_der == *guestserver_root_cert_der {
info!("GuestServer connection"); info!("GuestServer connection");
let server_root_cert = Arc::clone(&server_root_cert); let server_root_cert = Arc::clone(&server_root_cert);
let codec = LengthDelimitedCodec::new(); let codec = LengthDelimitedCodec::new();
let mut transport = Framed::new(stream, codec); let transport = Framed::new(stream, codec);
loop { guestserver_handler(transport, scdb_addr, &server_root_cert).await;
match transport.next().await { } else if peer_root_cert_der == *client_root_cert_der {
None => {
info!("Transport returned None");
}
Some(item) => {
match item {
Ok(buf) => {
use FromGuestServerMessage::*;
let msg: FromGuestServerMessage = rmp_serde::from_slice(&buf).unwrap();
info!("{:?}", msg);
match msg {
Announce { name } => {
info!("Announced with name {}", name);
if server_db_addr.send(IsNameRegistered { name: name.clone() }).await.unwrap() {
info!("Name {} already registered!", name);
let reply = ToGuestServerMessage::FailedNameAlreadyOccupied;
transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap();
break; // Stop reading
} else {
let cert = generate_server_cert(&server_root_cert, name.as_str());
server_db_addr.send(RegisterServer {
cert: cert.cert.clone(),
name,
}).await.unwrap().unwrap();
let reply = ToGuestServerMessage::OkAnnounce {
server_cert: cert.cert,
server_prkey: cert.prkey
};
transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap();
}
}
}
}
Err(e) => {
info!("Disconnection: {:?}", e);
break;
}
}
}
}
}
} else if peer_cert_der == *client_root_cert_der {
info!("Client connection"); info!("Client connection");
let codec = LengthDelimitedCodec::new();
let transport = Framed::new(stream, codec);
client_handler(transport, sm_addr, pdcm_addr).await;
} else { } else {
error!("Unknown Root Certificate"); error!("Unknown Root Certificate");
} }
@@ -270,3 +189,172 @@ async fn main() {
.await .await
.unwrap(); .unwrap();
} }
/// Hands a subscribing server's framed stream over to the ServerManager and
/// logs whether the hand-off was accepted.
async fn server_subscribe_handler(transport: TransportStream, name: String, sm_addr: Addr<ServerManager>) {
    let outcome = sm_addr.send(StartTransporter { name, transport }).await.unwrap();
    if let Err(e) = outcome {
        error!("Error from manager: {:?}", e);
    } else {
        info!("Stream sent to the manager");
    }
}
/// Drives the "SendCommand" control channel of an already-registered server:
/// reads framed msgpack commands in a loop until the peer disconnects.
/// `peer_cert_bytes` is the leaf certificate the peer authenticated with.
async fn server_command_handler(mut transport: TransportStream, peer_cert_bytes: Vec<u8>, server_db_addr: Addr<ServerCertDB>) {
    loop {
        match transport.next().await {
            None => {
                // Stream exhausted: peer closed the connection cleanly.
                info!("Transport returned None");
                break;
            }
            Some(item) => match item {
                Ok(buf) => {
                    use libbonknet::servermsg::FromServerCommandMessage::*;
                    let msg: FromServerCommandMessage = rmp_serde::from_slice(&buf).unwrap();
                    info!("{:?}", msg);
                    match msg {
                        ChangeName { name } => {
                            info!("Changing name to {}", name);
                            // Drop any existing binding for this certificate first.
                            match server_db_addr.send(UnregisterServer { cert: peer_cert_bytes.clone() }).await.unwrap() {
                                None => {
                                    info!("Nothing to unregister");
                                }
                                Some(old_name) => {
                                    warn!("Unregistered from old name {}", old_name);
                                }
                            }
                            // Then bind the certificate to the requested name.
                            let reply = match server_db_addr.send(RegisterServer { cert: peer_cert_bytes.clone(), name }).await.unwrap() {
                                Ok(_) => {
                                    info!("Registered!");
                                    ToServerCommandReply::NameChanged
                                }
                                Err(e) => {
                                    error!("Error registering: {:?}", e);
                                    ToServerCommandReply::GenericFailure
                                }
                            };
                            transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap();
                        }
                        WhoAmI => {
                            info!("Asked who I am");
                            // Answer with the name currently bound to the peer's certificate.
                            let reply = match server_db_addr.send(FetchName { cert: peer_cert_bytes.clone() }).await.unwrap() {
                                None => {
                                    info!("I'm not registered anymore!? WTF");
                                    ToServerCommandReply::GenericFailure
                                }
                                Some(name) => {
                                    info!("I am {}", name);
                                    ToServerCommandReply::YouAre { name }
                                }
                            };
                            transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap();
                        }
                    }
                }
                Err(e) => {
                    info!("Disconnection: {:?}", e);
                    break;
                }
            }
        }
    }
}
// TODO: consider creating a context struct holding all the keys and actor
// addresses, to be handed to every handler.
/// Handles a guest (not yet registered) server: waits for Announce messages
/// and, when the requested name is free, issues a freshly generated
/// certificate/key pair signed by `server_root_cert`.
async fn guestserver_handler(mut transport: TransportStream, server_db_addr: Addr<ServerCertDB>, server_root_cert: &Arc<Certificate>) {
    loop {
        match transport.next().await {
            None => {
                // Stream exhausted: peer closed the connection cleanly.
                info!("Transport returned None");
                break;
            }
            Some(item) => {
                match item {
                    Ok(buf) => {
                        use libbonknet::servermsg::FromGuestServerMessage::*;
                        let msg: FromGuestServerMessage = rmp_serde::from_slice(&buf).unwrap();
                        info!("{:?}", msg);
                        match msg {
                            Announce { name } => {
                                info!("Announced with name {}", name);
                                if server_db_addr.send(IsNameRegistered { name: name.clone() }).await.unwrap() {
                                    // Name collision: refuse and end the session.
                                    info!("Name {} already registered!", name);
                                    let reply = ToGuestServerMessage::FailedNameAlreadyOccupied;
                                    transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap();
                                    break; // Stop reading
                                } else {
                                    // Mint a server certificate and record it in the DB
                                    // before handing cert + private key back to the guest.
                                    let cert = generate_server_cert(server_root_cert, name.as_str());
                                    server_db_addr.send(RegisterServer {
                                        cert: cert.cert.clone(),
                                        name,
                                    }).await.unwrap().unwrap();
                                    let reply = ToGuestServerMessage::OkAnnounce {
                                        server_cert: cert.cert,
                                        server_prkey: cert.prkey
                                    };
                                    transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap();
                                }
                            }
                        }
                    }
                    Err(e) => {
                        info!("Disconnection: {:?}", e);
                        break;
                    }
                }
            }
        }
    }
}
/// Handles a client control connection: serves the server list, sets up
/// connection requests towards named servers, and can upgrade this very
/// transport into a data stream (after which this task stops reading it).
async fn client_handler(mut transport: TransportStream, sm_addr: Addr<ServerManager>, pdcm_addr: Addr<PendingDataConnManager>) {
    loop {
        match transport.next().await {
            None => {
                // Stream exhausted: peer closed the connection cleanly.
                info!("Transport returned None");
                break;
            }
            Some(item) => {
                match item {
                    Ok(buf) => {
                        let msg: FromClientCommand = rmp_serde::from_slice(&buf).unwrap();
                        info!("{:?}", msg);
                        match msg {
                            FromClientCommand::RequestServer { name } => {
                                info!("REQUESTED SERVER {}", name);
                                // The ServerManager answers with the conn id the client
                                // must use when opening the data stream.
                                let data = sm_addr.send(RequestServer { name }).await.unwrap();
                                match data {
                                    Ok(client_conn_id) => {
                                        let reply = ToClientResponse::OkRequest { conn_id: client_conn_id };
                                        transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap();
                                        info!("Sent OkRequest");
                                    }
                                    Err(e) => {
                                        // NOTE(review): no reply is sent back on failure,
                                        // leaving the client waiting — confirm intended.
                                        error!("Error! {:?}", e);
                                    }
                                }
                            }
                            FromClientCommand::ServerList => {
                                info!("Requested ServerList");
                                let data = sm_addr.send(GetServerList {}).await.unwrap();
                                let reply = ToClientResponse::OkServerList { data };
                                transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap();
                            }
                            FromClientCommand::UpgradeToDataStream(conn_id) => {
                                info!("Upgrade to DataStream with conn_id {:?}", conn_id);
                                // Ownership of the transport moves to the pending-conn
                                // manager; this loop must stop reading from it.
                                let msg = RegisterStream::client(conn_id, transport);
                                pdcm_addr.send(msg).await.unwrap().unwrap();
                                break;
                            }
                        }
                    }
                    Err(e) => {
                        info!("Disconnection: {:?}", e);
                        break;
                    }
                }
            }
        }
    }
}

View File

@@ -0,0 +1,241 @@
use actix::prelude::*;
use actix_tls::accept::rustls_0_22::TlsStream;
use futures::SinkExt;
use thiserror::Error;
use tokio::net::TcpStream;
use tokio_util::codec::{Framed, LengthDelimitedCodec};
use tracing::{error, info};
use uuid::Uuid;
use libbonknet::*;
use crate::dataconnmanager::{DataConnManager, StartDataBridge};
type TransportStream = Framed<TlsStream<TcpStream>, LengthDelimitedCodec>;
/*
The idea is that the database must hold one row for every data connection being set up.
Each "row" must contain:
- The ServerConnID the server will have to use (pkey)
- The ClientConnID the client will have to use (pkey)
- The server's socket, once connected
- The client's socket, once connected
When a row holds both the ServerSocket and the ClientSocket, pop the row
and use the sockets to launch a thread/actor that pipes the data between them.
Row lookup must work through either the ServerConnID or the ClientConnID, since I plan for them to differ.
So the ideal is not to use a keyed collection but to go straight with a Vector!
*/
/// Errors produced by PendingDataConnManager handlers.
#[derive(Error, Debug)]
pub enum PendingDataConnError {
    /// Catch-all failure; no finer-grained variants exist yet.
    #[error("Generic Failure")]
    GenericFailure,
}

/// Create a pending record for a server/client connection-id pair.
#[derive(Message)]
#[rtype(result = "Result<(),PendingDataConnError>")]
pub struct NewPendingConn {
    pub server_conn_id: Uuid,
    pub client_conn_id: Uuid,
}

/// Selects which side of a pending connection a message refers to.
#[derive(Debug)]
enum RegisterKind {
    Server,
    Client
}

/// Attach a transport stream to one side of an existing pending record.
/// Built only through the `server`/`client` constructors below.
#[derive(Message)]
#[rtype(result = "Result<(),PendingDataConnError>")]
pub struct RegisterStream {
    kind: RegisterKind,
    conn_id: Uuid,
    transport: TransportStream,
}

impl RegisterStream {
    /// Registration request for the server side of a pending connection.
    pub fn server(conn_id: Uuid, transport: TransportStream) -> Self {
        RegisterStream { kind: RegisterKind::Server, conn_id, transport }
    }
    /// Registration request for the client side of a pending connection.
    pub fn client(conn_id: Uuid, transport: TransportStream) -> Self {
        RegisterStream { kind: RegisterKind::Client, conn_id, transport }
    }
}

/// One side of a pending connection: its id and, once attached, its transport.
struct SideRecord {
    conn_id: Uuid,
    transport: Option<TransportStream>,
}

/// A pending data connection: one server side and one client side.
struct Record {
    server: SideRecord,
    client: SideRecord,
}

impl Record {
    /// Fresh record with both transports still unattached.
    fn new(server_conn_id: Uuid, client_conn_id: Uuid) -> Self {
        let server = SideRecord { conn_id: server_conn_id, transport: None };
        let client = SideRecord { conn_id: client_conn_id, transport: None };
        Record { server, client }
    }
}
// TODO: every 2 minutes verify the Records that have at least one stream invalidated and drop them
/// Book-keeping actor for data connections that are still waiting for one or
/// both peers to attach their transport stream.
pub struct PendingDataConnManager {
    /// Pending records, searched linearly by either side's conn id.
    db: Vec<Record>,
    /// Actor that runs the actual bridge once both transports are present.
    dcm_addr: Addr<DataConnManager>,
}

impl PendingDataConnManager {
    /// Empty table, forwarding completed pairs to `dcm_addr`.
    pub fn new(dcm_addr: Addr<DataConnManager>) -> Self {
        Self { db: Vec::new(), dcm_addr }
    }

    /// Finds the record whose side selected by `kind` carries `conn_id` and
    /// returns a mutable borrow of that side, or `None` when absent.
    fn retrieve_siderecord(&mut self, kind: &RegisterKind, conn_id: &Uuid) -> Option<&mut SideRecord> {
        use RegisterKind::*;
        let record = self.db.iter_mut().find(|r| match kind {
            Server => r.server.conn_id == *conn_id,
            Client => r.client.conn_id == *conn_id,
        })?;
        Some(match kind {
            Server => &mut record.server,
            Client => &mut record.client,
        })
    }
}
impl Actor for PendingDataConnManager {
    // Plain single-threaded actix context.
    type Context = Context<Self>;
}
impl Handler<NewPendingConn> for PendingDataConnManager {
type Result = Result<(), PendingDataConnError>;
fn handle(&mut self, msg: NewPendingConn, _ctx: &mut Self::Context) -> Self::Result {
let server_conn_id = msg.server_conn_id;
let client_conn_id = msg.client_conn_id;
let new_record = Record::new(server_conn_id, client_conn_id);
if self.db.iter().any(|x| {
x.server.conn_id == new_record.server.conn_id || x.client.conn_id == new_record.client.conn_id
}) {
error!("Conflicting UUIDs: server {} - client {}", server_conn_id, client_conn_id);
Err(PendingDataConnError::GenericFailure)
} else {
self.db.push(new_record);
Ok(())
}
}
}
/*
Run all the normal checks synchronously.
When the Accepted must be sent, notify yourself with an AcceptStream message that owns the stream,
but this time as a ResponseActFuture.
If there is a failure, move the transport into a ctx::spawn that sends a FAILED.
If everything is OK, run ALL the checks AGAIN, and if they pass, register the stream in the Actor.
To handle spawning a connection, the only answer is to model the connection spawn as a
Message Handler as well. When a stream finishes sending its Accepted, it appears in the Record.
When the second stream arrives and finishes its Accepted, it too is registered in the Record, so
we are in the spawn condition, because both transports are present in the Record.
So if at the end of a check-and-register both streams are present, move them both out, drop
the record and send the two transports to a third actor, on another Arbiter, that runs the
tokio::io::copy and manages the open connections.
*/
/// Internal self-notification: re-check the record addressed by `conn_id` and
/// start the data bridge if both transports are now attached.
#[derive(Message)]
#[rtype(result = "Result<(),PendingDataConnError>")]
struct TryStartDataStream {
    kind: RegisterKind,
    conn_id: Uuid,
}
impl Handler<TryStartDataStream> for PendingDataConnManager {
    type Result = Result<(), PendingDataConnError>;

    /// Checks whether the record addressed by `msg` now owns both transports;
    /// if so, pops it from the table and hands the pair to the DataConnManager.
    fn handle(&mut self, msg: TryStartDataStream, _ctx: &mut Self::Context) -> Self::Result {
        use RegisterKind::*;
        let found = self.db.iter().position(|r| match msg.kind {
            Server => r.server.conn_id == msg.conn_id,
            Client => r.client.conn_id == msg.conn_id,
        });
        match found {
            None => Err(PendingDataConnError::GenericFailure),
            Some(pos) => {
                let record = &self.db[pos];
                if record.client.transport.is_some() && record.server.transport.is_some() {
                    info!("Requesting Data Bridge for client_conn_id {}", record.client.conn_id);
                    // Pop the row to take full ownership of both transports before
                    // shipping them to the bridge-spawning actor.
                    let owned = self.db.remove(pos); // Safety: pos comes from position() above
                    self.dcm_addr.do_send(StartDataBridge {
                        client_conn_id: owned.client.conn_id,
                        server_transport: owned.server.transport.unwrap(), // Safety: is_some checked above
                        client_transport: owned.client.transport.unwrap(), // Safety: is_some checked above
                    });
                }
                Ok(())
            }
        }
    }
}
impl Handler<RegisterStream> for PendingDataConnManager {
    type Result = ResponseActFuture<Self, Result<(), PendingDataConnError>>;

    /// Attaches a transport to one side of a pending record. The accept
    /// message is sent asynchronously; only after it succeeds is the transport
    /// actually stored, followed by a TryStartDataStream self-notification.
    fn handle(&mut self, msg: RegisterStream, _ctx: &mut Self::Context) -> Self::Result {
        // First synchronous check: the record must exist and the slot be free.
        let side_record = match self.retrieve_siderecord(&msg.kind, &msg.conn_id) {
            None => {
                error!("Found no connection with {:?} conn_id {:?}", msg.kind, msg.conn_id);
                return Box::pin(fut::err(PendingDataConnError::GenericFailure));
            }
            Some(item) => item,
        };
        if side_record.transport.is_some() {
            // TODO: It can be good to check if the connection is still open, if not, drop and use the new one.
            error!("Connection already with a socket!");
            Box::pin(fut::err(PendingDataConnError::GenericFailure))
        } else {
            // This Fut will send the Accepted and only then, register the transport stream
            // in the Manager. If during the registration there are all the transport in places,
            // you can start the datapiping
            // (2021-edition precise capture: the async block takes msg.transport,
            // the map closure takes msg.kind / msg.conn_id.)
            Box::pin(async move {
                let mut transport = msg.transport;
                let reply = ToPeerDataStream::OkDataStreamRequestAccepted;
                let res = transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await;
                (transport, res)
            }.into_actor(self).map(move |(transport, res), a, c| {
                match res {
                    Ok(_) => {
                        // TODO: to not do the check twice I can put a "lock variable" inside the record
                        // to prevent the put in Accept of another stream while we are waiting to send the
                        // accept message
                        // Second check: the table may have changed while the accept was in flight.
                        let side_record = match a.retrieve_siderecord(&msg.kind, &msg.conn_id) {
                            None => {
                                error!("Found no connection with {:?} conn_id {:?}", msg.kind, msg.conn_id);
                                return Err(PendingDataConnError::GenericFailure);
                            }
                            Some(item) => item,
                        };
                        if side_record.transport.is_some() {
                            // TODO: It can be good to check if the connection is still open, if not, drop and use the new one.
                            error!("Connection already with a socket!");
                            return Err(PendingDataConnError::GenericFailure);
                        }
                        side_record.transport = Some(transport);
                        // Both sides may now be attached: try to launch the bridge.
                        c.notify(TryStartDataStream { kind: msg.kind, conn_id: msg.conn_id });
                        Ok(())
                    }
                    Err(e) => {
                        error!("Error during OkDataStreamRequestAccepted sending: {:?}", e);
                        Err(PendingDataConnError::GenericFailure)
                    }
                }
            }))
        }
    }
}

View File

@@ -0,0 +1,93 @@
use std::collections::HashMap;
use actix::prelude::*;
use thiserror::Error;
// TODO: Probably it's better to remove the pub from inside the structs and impl a new() funct
/// Errors produced by ServerCertDB handlers.
#[derive(Error, Debug)]
pub enum DBError {
    /// The certificate is already bound to the given name.
    #[error("Certificate is already registered with name {0}")]
    CertAlreadyRegistered(String),
    // #[error("Generic Failure")]
    // GenericFailure,
}

/// Query: is any certificate currently registered under `name`?
#[derive(Message)]
#[rtype(result = "bool")]
pub struct IsNameRegistered {
    pub name: String,
}

/// Bind `cert` to `name`; fails if the certificate is already registered.
#[derive(Message)]
#[rtype(result = "Result<(), DBError>")]
pub struct RegisterServer {
    pub cert: Vec<u8>,
    pub name: String,
}

/// Remove any binding for `cert`.
#[derive(Message)]
#[rtype(result = "Option<String>")] // None if nothing to unregister, Some if unregistered
pub struct UnregisterServer {
    pub cert: Vec<u8>,
}

/// Look up the name bound to `cert`, if any.
#[derive(Message)]
#[rtype(result = "Option<String>")]
pub struct FetchName {
    pub cert: Vec<u8>,
}
// TODO: Move into Sqlite DB with unique check on col1 and col2!!!! Right now name is not unique
/// In-memory registry mapping a server certificate (raw bytes) to its name.
pub struct ServerCertDB {
    db: HashMap<Vec<u8>, String>, // Cert to Name
}

impl ServerCertDB {
    /// Empty registry.
    pub fn new() -> Self {
        ServerCertDB { db: HashMap::new() }
    }
}

impl Actor for ServerCertDB {
    // Plain single-threaded actix context.
    type Context = Context<Self>;
}
impl Handler<RegisterServer> for ServerCertDB {
    type Result = Result<(), DBError>;

    /// Binds `msg.cert` to `msg.name`. Fails with `CertAlreadyRegistered`
    /// (carrying the currently bound name) if the certificate is present.
    fn handle(&mut self, msg: RegisterServer, _ctx: &mut Self::Context) -> Self::Result {
        use std::collections::hash_map::Entry;
        // Entry API: one hash lookup instead of a get() followed by insert().
        match self.db.entry(msg.cert) {
            Entry::Vacant(slot) => {
                slot.insert(msg.name);
                Ok(())
            }
            Entry::Occupied(slot) => Err(DBError::CertAlreadyRegistered(slot.get().clone())),
        }
    }
}
impl Handler<IsNameRegistered> for ServerCertDB {
    type Result = bool;

    /// True when any registered certificate currently maps to `msg.name`.
    fn handle(&mut self, msg: IsNameRegistered, _ctx: &mut Self::Context) -> Self::Result {
        self.db.values().any(|registered| registered == &msg.name)
    }
}
impl Handler<FetchName> for ServerCertDB {
    type Result = Option<String>;

    /// Returns the name bound to `msg.cert`, cloned out of the table.
    fn handle(&mut self, msg: FetchName, _ctx: &mut Self::Context) -> Self::Result {
        self.db.get(&msg.cert).cloned()
    }
}
impl Handler<UnregisterServer> for ServerCertDB {
    type Result = Option<String>;

    /// Removes the binding for `msg.cert`, returning the old name if present.
    fn handle(&mut self, msg: UnregisterServer, _ctx: &mut Self::Context) -> Self::Result {
        self.db.remove(&msg.cert)
    }
}

View File

@@ -0,0 +1,290 @@
use std::collections::HashMap;
use std::io::Error;
use std::sync::Arc;
use std::time::{Duration, Instant};
use actix::prelude::*;
use rand::random;
use thiserror::Error;
use futures::SinkExt;
use tokio::sync::{Mutex, oneshot};
use tokio_util::bytes::{Bytes, BytesMut};
use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
use tracing::{debug, error, info};
use crate::{TransportStream, TransportStreamRx, TransportStreamTx};
use uuid::Uuid;
use libbonknet::servermsg::*;
use crate::pendingdataconndb::*;
/// Errors produced when sending a message through a ServerTransporter.
#[derive(Error, Debug)]
pub enum SendMsgError {
    /// Catch-all failure; no finer-grained variants exist yet.
    #[error("Generic Failure")]
    GenericFailure,
}

/// Request/response round-trip towards the remote server; resolves with the
/// server's reply body once it arrives.
#[derive(Message)]
#[rtype(result = "Result<FromServerReplyBody, SendMsgError>")]
struct SendMsg {
    msg: ToServerMessageBody,
}

/// Actor owning one subscribed server's control connection.
struct ServerTransporter {
    // Read half; taken out of the Option and registered as a stream in started().
    rx: Option<TransportStreamRx>,
    // Write half, shared between the spawned send futures.
    tx: Arc<Mutex<TransportStreamTx>>,
    // Handle of the inactivity-timeout task; rearmed on every Pong.
    timeout: Option<SpawnHandle>,
    // Pending request ids awaiting a reply from the remote server.
    reply_channels: HashMap<u64, oneshot::Sender<FromServerReplyBody>>,
}
impl ServerTransporter {
    /// Splits the framed TLS stream into independent read/write halves so the
    /// actor can consume the read side as a stream while sends go through a
    /// shared, lockable write side.
    fn new(transport: TransportStream) -> Self {
        let internal = transport.into_inner();
        let (srx, stx) = tokio::io::split(internal);
        let codec = LengthDelimitedCodec::new();
        let rx = FramedRead::new(srx, codec.clone());
        // Last use of `codec`: move it instead of cloning a second time.
        let tx = FramedWrite::new(stx, codec);
        ServerTransporter {
            rx: Some(rx),
            tx: Arc::new(Mutex::new(tx)),
            timeout: None,
            reply_channels: HashMap::new(),
        }
    }

    /// Stops the actor; used as the inactivity-timeout action.
    fn actor_close(&mut self, ctx: &mut Context<Self>) {
        ctx.stop();
    }
}
impl Actor for ServerTransporter {
    type Context = Context<Self>;

    /// On start: wire up the read stream, arm the 120 s inactivity timeout,
    /// and schedule a Ping towards the remote server every 60 s.
    fn started(&mut self, ctx: &mut Self::Context) {
        // Register Read Stream
        let rx = self.rx.take().expect("Rx Stream not found");
        ctx.add_stream(rx);
        // Register Timeout task
        self.timeout = Some(ctx.run_interval(Duration::from_secs(120), |s, c| {
            s.actor_close(c)
        }));
        // Register Send Ping Task
        ctx.run_interval(Duration::from_secs(60), |s, c| {
            let msg = ToServerMessage::Ping;
            let payload: Bytes = rmp_serde::to_vec(&msg).unwrap().into();
            let arc_tx = Arc::clone(&s.tx);
            // The send runs as a spawned future; a failed send stops the actor.
            c.spawn(async move {
                arc_tx.lock().await.send(payload).await
            }.into_actor(s).map(|res, _a, ctx| {
                match res {
                    Ok(_) => {
                        info!("Ping sent!");
                    }
                    Err(_) => {
                        ctx.stop();
                    }
                }
            }));
        });
    }
}
impl Handler<SendMsg> for ServerTransporter {
type Result = ResponseFuture<Result<FromServerReplyBody, SendMsgError>>;
fn handle(&mut self, msg: SendMsg, _ctx: &mut Self::Context) -> Self::Result {
let (reply_channel_tx, reply_channel_rx) = oneshot::channel();
let mut reply_id: u64;
if self.reply_channels.len() == u64::MAX as usize {
return Box::pin(fut::ready(Err(SendMsgError::GenericFailure)));
}
loop {
reply_id = random();
if !self.reply_channels.contains_key(&reply_id) {
break;
}
}
self.reply_channels.insert(reply_id, reply_channel_tx);
let msg = ToServerMessage::Msg {
reply_id,
body: msg.msg,
};
let payload: Bytes = rmp_serde::to_vec(&msg).unwrap().into();
let arc_tx = self.tx.clone();
Box::pin(async move {
arc_tx.lock().await.send(payload).await.map_err(|_e| SendMsgError::GenericFailure)?;
info!("msg sent");
let r = reply_channel_rx.await.unwrap();
info!("reply received");
Ok(r)
})
}
}
impl StreamHandler<Result<BytesMut, Error>> for ServerTransporter {
fn handle(&mut self, item: Result<BytesMut, Error>, ctx: &mut Self::Context) {
match item {
Ok(buf) => {
use libbonknet::servermsg::FromServerReply::*;
let msg: FromServerReply = rmp_serde::from_slice(&buf).unwrap();
match msg {
Pong => {
info!("Pong received");
if let Some(spawn_handle) = self.timeout {
ctx.cancel_future(spawn_handle);
} else {
error!("There were no spawn handle configured!");
}
self.timeout = Some(ctx.run_interval(Duration::from_secs(120), |s, c| {
s.actor_close(c)
}));
}
Msg { reply_id, body } => match self.reply_channels.remove(&reply_id) {
None => {}
Some(reply_tx) => {
if let Err(_e) = reply_tx.send(body) {
error!("Oneshot channel {} got invalidated! No reply sent.", reply_id);
}
}
}
}
}
Err(e) => {
error!("ERROR {:?}", e);
}
}
}
}
/// Error returned by `ServerManager` handlers.
#[derive(Error, Debug)]
pub enum ServerManagerError {
    // Catch-all: duplicate name, unknown server, or a failed request.
    #[error("Generic Failure")]
    GenericFailure,
}
/// Register a newly connected server: spawn a `ServerTransporter` for
/// `transport` and store its address under `name`.
#[derive(Message)]
#[rtype(result = "Result<(),ServerManagerError>")]
pub struct StartTransporter {
    pub name: String,
    pub transport: TransportStream,
}
/// Ask for the names of all currently registered servers.
#[derive(Message)]
#[rtype(result = "Vec<String>")]
pub struct GetServerList {}
/// Request a data connection to the server called `name`; on success resolves
/// with the client-side connection id to hand back to the requesting client.
#[derive(Message)]
#[rtype(result = "Result<Uuid,ServerManagerError>")] // TODO: Return Client ID with struct to give it a name
pub struct RequestServer {
    pub name: String
}
/// Registry of live server connections, keyed by server name.
pub struct ServerManager {
    entries: Arc<Mutex<HashMap<String, Addr<ServerTransporter>>>>,
    // Name -> Addr to Actor
    // Used to register (server_conn_id, client_conn_id) pairs before asking
    // a server to open a data stream.
    pdcdb_addr: Addr<PendingDataConnManager>,
}
impl ServerManager {
    /// Create a manager with an empty registry, wired to the
    /// pending-data-connection database actor.
    pub fn new(pdcdb_addr: Addr<PendingDataConnManager>) -> Self {
        let entries = Arc::new(Mutex::new(HashMap::new()));
        Self { entries, pdcdb_addr }
    }
}
impl Actor for ServerManager {
    type Context = Context<Self>;

    /// Starts a periodic sweep that drops entries whose `ServerTransporter`
    /// actor is no longer alive.
    // TODO: This is a "Pull" from entries, but we can have the entries that when killed tell
    // the Manager making it more reactive!
    fn started(&mut self, ctx: &mut Self::Context) {
        ctx.run_interval(Duration::from_secs(10), |s, c| {
            let start = Instant::now();
            let entries = Arc::clone(&s.entries);
            c.spawn(async move {
                let mut entries_mg = entries.lock().await;
                // retain() drops dead transporters in a single pass instead of
                // collecting their keys first and removing them afterwards.
                entries_mg.retain(|name, st_addr| {
                    if st_addr.connected() {
                        true
                    } else {
                        info!("Closed ServerTransporter {} for actor death", name);
                        false
                    }
                });
                debug!("Cleaned ServerManager in {:?}", start.elapsed());
            }.into_actor(s));
        });
    }
}
impl Handler<StartTransporter> for ServerManager {
type Result = ResponseFuture<Result<(), ServerManagerError>>;
fn handle(&mut self, msg: StartTransporter, _ctx: &mut Self::Context) -> Self::Result {
let entries = Arc::clone(&self.entries);
Box::pin(async move {
let mut entries_mg = entries.lock().await;
if entries_mg.contains_key(&msg.name) {
error!("A server called {} is already connected!", msg.name);
return Err(ServerManagerError::GenericFailure);
}
let st_addr = ServerTransporter::new(msg.transport).start();
entries_mg.insert(msg.name, st_addr);
Ok(())
})
}
}
impl Handler<GetServerList> for ServerManager {
    type Result = ResponseFuture<Vec<String>>;

    /// Snapshot the names of every currently registered server.
    fn handle(&mut self, _msg: GetServerList, _ctx: &mut Self::Context) -> Self::Result {
        let entries = Arc::clone(&self.entries);
        Box::pin(async move {
            let registry = entries.lock().await;
            let names: Vec<String> = registry.keys().cloned().collect();
            names
        })
    }
}
impl Handler<RequestServer> for ServerManager {
type Result = ResponseFuture<Result<Uuid, ServerManagerError>>;
fn handle(&mut self, msg: RequestServer, _ctx: &mut Self::Context) -> Self::Result {
let name = msg.name.clone();
let entries = Arc::clone(&self.entries);
let pdcdb_addr = self.pdcdb_addr.clone();
Box::pin(async move {
let lock = entries.lock().await;
let sh_addr = match lock.get(&name) {
None => {
error!("Requested server {} that isn't registered", name);
return Err(ServerManagerError::GenericFailure);
}
Some(item) => item,
};
let server_conn_id = Uuid::new_v4();
let client_conn_id = Uuid::new_v4();
match pdcdb_addr.send(NewPendingConn { server_conn_id, client_conn_id }).await.unwrap() {
Ok(_) => {
let msg = ToServerMessageBody::Request { conn_id: server_conn_id };
match sh_addr.send(SendMsg { msg }).await.unwrap() {
Ok(reply) => match reply {
FromServerReplyBody::RequestAccepted => {
Ok(client_conn_id)
}
FromServerReplyBody::RequestFailed => {
error!("Request Failed!");
Err(ServerManagerError::GenericFailure)
}
FromServerReplyBody::Pong => unreachable!(),
}
Err(e) => {
panic!("ERROR: {:?}", e);
}
}
}
Err(_e) => Err(ServerManagerError::GenericFailure),
}
})
}
}

View File

@@ -7,10 +7,13 @@ edition = "2021"
[dependencies] [dependencies]
libbonknet = { path = "../libbonknet" } libbonknet = { path = "../libbonknet" }
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full", "tracing"] }
tokio-rustls = "0.25.0"
tokio-util = { version = "0.7.10", features = ["codec"] } tokio-util = { version = "0.7.10", features = ["codec"] }
futures = "0.3" futures = "0.3"
rcgen = "0.12.0"
tokio-rustls = "0.25.0"
rustls-pemfile = "2.0.0" rustls-pemfile = "2.0.0"
serde = { version = "1.0", features = ["derive"] }
rmp-serde = "1.1.2" rmp-serde = "1.1.2"
tracing = "0.1"
tracing-subscriber = "0.3"
uuid = { version = "1.7.0", features = ["serde"] }

View File

@@ -1,74 +1,176 @@
use std::io::{Error, ErrorKind};
use std::sync::Arc; use std::sync::Arc;
use futures::SinkExt; use std::time::Duration;
use futures::{SinkExt, StreamExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio_rustls::rustls::{ClientConfig, RootCertStore}; use tokio_rustls::rustls::{ClientConfig, RootCertStore};
use tokio_rustls::rustls::pki_types::{ServerName}; use tokio_rustls::rustls::pki_types::ServerName;
use tokio_rustls::TlsConnector; use tokio_rustls::TlsConnector;
use tokio_util::codec::{Framed, LengthDelimitedCodec}; use tokio_util::codec::{Framed, LengthDelimitedCodec};
use serde::{Serialize, Deserialize}; use libbonknet::*;
use libbonknet::{load_cert, load_prkey}; use libbonknet::clientmsg::*;
use uuid::Uuid;
use tracing::{error, info};
#[derive(Debug, Serialize, Deserialize)]
enum ClientMessage { async fn datastream(tlsconfig: ClientConfig, conn_id: Uuid) -> std::io::Result<()> {
Response { status_code: u32, msg: Option<String> }, let connector = TlsConnector::from(Arc::new(tlsconfig.clone()));
Announce { name: String }, let dnsname = ServerName::try_from("localhost").unwrap();
Required { id: String }, let stream = TcpStream::connect("localhost:2541").await?;
NotRequired { id: String }, let stream = connector.connect(dnsname, stream).await?;
let mut transport = Framed::new(stream, LengthDelimitedCodec::new());
let msg = FromClientCommand::UpgradeToDataStream(conn_id);
transport.send(rmp_serde::to_vec(&msg).unwrap().into()).await.unwrap();
match transport.next().await {
None => panic!("None in the transport"),
Some(item) => match item {
Ok(buf) => {
use ToPeerDataStream::*;
let msg: ToPeerDataStream = rmp_serde::from_slice(&buf).unwrap();
match msg {
OkDataStreamRequestAccepted => {
info!("Data Stream Accepted. Waiting for Open...");
}
Refused => {
error!("Refused");
return Err(Error::new(ErrorKind::ConnectionRefused, "Refused"));
}
other => {
error!("Unexpected response: {:?}", other);
return Err(Error::new(ErrorKind::ConnectionRefused, "Unexpected response"));
}
}
}
Err(e) => {
error!("Error: {:?}", e);
return Err(e);
}
}
}
match transport.next().await {
None => panic!("None in the transport"),
Some(item) => match item {
Ok(buf) => {
use ToPeerDataStream::*;
let msg: ToPeerDataStream = rmp_serde::from_slice(&buf).unwrap();
match msg {
OkDataStreamOpen => {
info!("Data Stream Open!. Connecting Streams.");
}
Revoked => {
error!("Data Stream Revoked!");
return Err(Error::new(ErrorKind::ConnectionAborted, "Revoked"));
}
Refused => {
error!("Refused");
return Err(Error::new(ErrorKind::ConnectionRefused, "Refused"));
}
other => {
error!("Unexpected response: {:?}", other);
return Err(Error::new(ErrorKind::ConnectionRefused, "Unexpected response"));
}
}
}
Err(e) => {
error!("Error: {:?}", e);
return Err(e);
}
}
}
let (mut rx, mut tx) = tokio::io::split(transport.into_inner());
let mut stdout = tokio::io::stdout();
let mut stdin = tokio::io::stdin();
let stdout_task = async move {
match tokio::io::copy(&mut rx, &mut stdout).await {
Ok(bytes_copied) => info!("{bytes_copied}"),
Err(e) => error!("Error during copy: {e}"),
}
};
let stdin_task = async move {
match tokio::io::copy(&mut stdin, &mut tx).await {
Ok(bytes_copied) => info!("{bytes_copied}"),
Err(e) => error!("Error during copy: {e}"),
}
};
tokio::join!(stdout_task, stdin_task);
Ok(())
} }
// TODO: This is an old examples
#[tokio::main] #[tokio::main]
async fn main() -> std::io::Result<()> { async fn main() -> std::io::Result<()> {
let client_name = "Polnareffland1"; // Tracing Subscriber
let subscriber = tracing_subscriber::FmtSubscriber::new();
tracing::subscriber::set_global_default(subscriber).unwrap();
// Root certs to verify the server is the right one // Root certs to verify the server is the right one
let mut server_root_cert_store = RootCertStore::empty(); let mut broker_root_cert_store = RootCertStore::empty();
let server_root_cert_der = load_cert("server_root_cert.pem").unwrap(); let broker_root_cert_der = load_cert("certs/broker_root_cert.pem").unwrap();
server_root_cert_store.add(server_root_cert_der).unwrap(); broker_root_cert_store.add(broker_root_cert_der).unwrap();
// Auth Cert to send the server who am I // Public CA for Clients
let root_client_cert = load_cert("client_root_cert.pem").unwrap(); let root_client_cert = load_cert("certs/client_root_cert.pem").unwrap();
let client_cert = load_cert("client_cert.pem").unwrap(); // My Client Certificate for authentication
let client_prkey = load_prkey("client_key.pem").unwrap(); let client_cert = load_cert("certs/client_cert.pem").unwrap();
let client_cert_prkey = load_prkey("certs/client_key.pem").unwrap();
// Load TLS Config // Load TLS Config
let tlsconfig = ClientConfig::builder() let tlsconfig = ClientConfig::builder()
.with_root_certificates(server_root_cert_store) .with_root_certificates(broker_root_cert_store.clone())
// .with_no_client_auth(); .with_client_auth_cert(vec![client_cert, root_client_cert], client_cert_prkey.into())
.with_client_auth_cert(vec![client_cert, root_client_cert], client_prkey.into())
.unwrap(); .unwrap();
let connector = TlsConnector::from(Arc::new(tlsconfig)); let connector = TlsConnector::from(Arc::new(tlsconfig.clone()));
let dnsname = ServerName::try_from("localhost").unwrap(); let dnsname = ServerName::try_from("localhost").unwrap();
let stream = TcpStream::connect("localhost:2541").await?; let stream = TcpStream::connect("localhost:2541").await?;
let stream = connector.connect(dnsname, stream).await?; let stream = connector.connect(dnsname, stream).await?;
let mut transport = Framed::new(stream, LengthDelimitedCodec::new()); let mut transport = Framed::new(stream, LengthDelimitedCodec::new());
let msg = FromClientCommand::ServerList;
let msg1 = ClientMessage::Announce { name: client_name.into() }; transport.send(rmp_serde::to_vec(&msg).unwrap().into()).await.unwrap();
transport.send(rmp_serde::to_vec(&msg1).unwrap().into()).await.unwrap(); match transport.next().await {
for i in 0..10 { None => panic!("None in the transport"),
let msg = ClientMessage::Response { status_code: 100+i, msg: Some(format!("yay {}", i)) }; Some(item) => match item {
transport.send(rmp_serde::to_vec(&msg).unwrap().into()).await.unwrap(); Ok(buf) => {
tokio::time::sleep(std::time::Duration::from_secs(1)).await; use libbonknet::clientmsg::ToClientResponse;
use libbonknet::clientmsg::ToClientResponse::*;
let msg: ToClientResponse = rmp_serde::from_slice(&buf).unwrap();
match msg {
OkServerList { data } => info!("{}", data.join("\n")),
GenericError => error!("Generic Error during remote command execution"),
others => {
panic!("Unexpected Message type: {:?}", others);
}
}
}
Err(e) => {
error!("Error: {:?}", e);
}
}
}
tokio::time::sleep(Duration::from_secs(5)).await;
let msg = FromClientCommand::RequestServer { name: "cicciopizza".into() };
transport.send(rmp_serde::to_vec(&msg).unwrap().into()).await.unwrap();
match transport.next().await {
None => panic!("None in the transport"),
Some(item) => match item {
Ok(buf) => {
use libbonknet::clientmsg::ToClientResponse;
use libbonknet::clientmsg::ToClientResponse::*;
let msg: ToClientResponse = rmp_serde::from_slice(&buf).unwrap();
match msg {
OkRequest { conn_id } => {
info!("Received Client Connection ID: {:?}", conn_id);
datastream(tlsconfig, conn_id).await.unwrap();
}
GenericError => error!("Generic Error during remote command execution"),
others => {
panic!("Unexpected Message type: {:?}", others);
}
}
}
Err(e) => {
error!("Error: {:?}", e);
}
}
} }
// transport.for_each(|item| async move {
// let a: ClientMessage = rmp_serde::from_slice(&item.unwrap()).unwrap();
// println!("{:?}", a);
// }).await;
// let mut buf = vec![0;1024];
// let (mut rd,mut tx) = split(stream);
//
//
// tokio::spawn(async move {
// let mut stdout = tokio::io::stdout();
// tokio::io::copy(&mut rd, &mut stdout).await.unwrap();
// });
//
// let mut reader = tokio::io::BufReader::new(tokio::io::stdin()).lines();
//
// while let Some(line) = reader.next_line().await.unwrap() {
// tx.write_all(line.as_bytes()).await.unwrap();
// }
Ok(()) Ok(())
} }

View File

@@ -15,4 +15,5 @@ tokio-rustls = "0.25.0"
rustls-pemfile = "2.0.0" rustls-pemfile = "2.0.0"
rmp-serde = "1.1.2" rmp-serde = "1.1.2"
tracing = "0.1" tracing = "0.1"
tracing-subscriber = "0.3" tracing-subscriber = "0.3"
uuid = { version = "1.7.0", features = ["serde"] }

View File

@@ -1,14 +1,90 @@
use std::io::{Error, ErrorKind};
use std::sync::Arc; use std::sync::Arc;
use futures::{StreamExt, SinkExt}; use futures::{SinkExt, StreamExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio_rustls::rustls::{ClientConfig, RootCertStore}; use tokio_rustls::rustls::{ClientConfig, RootCertStore};
use tokio_rustls::rustls::pki_types::{ServerName, CertificateDer, PrivatePkcs8KeyDer}; use tokio_rustls::rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer, ServerName};
use tokio_rustls::TlsConnector; use tokio_rustls::TlsConnector;
use tokio_util::codec::{Framed, LengthDelimitedCodec}; use tokio_util::codec::{Framed, LengthDelimitedCodec};
use libbonknet::*; use libbonknet::*;
use tracing::{info, error}; use libbonknet::servermsg::*;
use uuid::Uuid;
use tracing::{error, info};
/// Open the broker data stream for `conn_id` and then echo every byte the
/// peer sends straight back to it.
async fn datastream(tlsconfig: Arc<ClientConfig>, conn_id: Uuid) -> std::io::Result<()> {
    use ToPeerDataStream::*;

    // TLS-connect to the broker and frame the stream with a length-delimited codec.
    let connector = TlsConnector::from(tlsconfig);
    let dnsname = ServerName::try_from("localhost").unwrap();
    let tcp = TcpStream::connect("localhost:2541").await?;
    let tls = connector.connect(dnsname, tcp).await?;
    let mut transport = Framed::new(tls, LengthDelimitedCodec::new());

    // Ask the broker to turn this connection into the data stream for `conn_id`.
    let open_req = FromServerConnTypeMessage::OpenDataStream(conn_id);
    transport.send(rmp_serde::to_vec(&open_req).unwrap().into()).await.unwrap();

    // Phase 1: the broker must accept the data-stream request.
    let first: ToPeerDataStream = match transport.next().await {
        None => panic!("None in the transport"),
        Some(Err(e)) => {
            error!("Error: {:?}", e);
            return Err(e);
        }
        Some(Ok(buf)) => rmp_serde::from_slice(&buf).unwrap(),
    };
    match first {
        OkDataStreamRequestAccepted => {
            info!("Data Stream Accepted. Waiting for Open...");
        }
        Refused => {
            error!("Refused");
            return Err(Error::new(ErrorKind::ConnectionRefused, "Refused"));
        }
        other => {
            error!("Unexpected response: {:?}", other);
            return Err(Error::new(ErrorKind::ConnectionRefused, "Unexpected response"));
        }
    }

    // Phase 2: wait for the broker to open the stream (or revoke it).
    let second: ToPeerDataStream = match transport.next().await {
        None => panic!("None in the transport"),
        Some(Err(e)) => {
            error!("Error: {:?}", e);
            return Err(e);
        }
        Some(Ok(buf)) => rmp_serde::from_slice(&buf).unwrap(),
    };
    match second {
        OkDataStreamOpen => {
            info!("Data Stream Open!. Connecting Streams.");
        }
        Revoked => {
            error!("Data Stream Revoked!");
            return Err(Error::new(ErrorKind::ConnectionAborted, "Revoked"));
        }
        Refused => {
            error!("Refused");
            return Err(Error::new(ErrorKind::ConnectionRefused, "Refused"));
        }
        other => {
            error!("Unexpected response: {:?}", other);
            return Err(Error::new(ErrorKind::ConnectionRefused, "Unexpected response"));
        }
    }

    // Echo loop: pipe the raw stream back into itself until it closes.
    let (mut rx, mut tx) = tokio::io::split(transport.into_inner());
    match tokio::io::copy(&mut rx, &mut tx).await {
        Ok(bytes_copied) => info!("{bytes_copied}"),
        Err(e) => error!("Error during copy: {e}"),
    }
    Ok(())
}
#[tokio::main] #[tokio::main]
async fn main() -> std::io::Result<()> { async fn main() -> std::io::Result<()> {
// Tracing Subscriber // Tracing Subscriber
@@ -44,11 +120,12 @@ async fn main() -> std::io::Result<()> {
let mut myserver_prkey: Option<PrivatePkcs8KeyDer> = None; let mut myserver_prkey: Option<PrivatePkcs8KeyDer> = None;
match transport.next().await { match transport.next().await {
None => { None => {
info!("None in the transport.next() ???"); panic!("None in the transport");
} }
Some(item) => match item { Some(item) => match item {
Ok(buf) => { Ok(buf) => {
use ToGuestServerMessage::*; use libbonknet::servermsg::{okannounce_to_cert, ToGuestServerMessage};
use libbonknet::servermsg::ToGuestServerMessage::*;
let msg: ToGuestServerMessage = rmp_serde::from_slice(&buf).unwrap(); let msg: ToGuestServerMessage = rmp_serde::from_slice(&buf).unwrap();
info!("{:?}", msg); info!("{:?}", msg);
match msg { match msg {
@@ -69,41 +146,159 @@ async fn main() -> std::io::Result<()> {
} }
} }
} }
transport.close().await.unwrap();
if let (Some(server_cert), Some(server_prkey)) = (myserver_cert, myserver_prkey) { if let (Some(server_cert), Some(server_prkey)) = (myserver_cert, myserver_prkey) {
let tlsconfig = ClientConfig::builder() let tlsconfig = Arc::new(ClientConfig::builder()
.with_root_certificates(broker_root_cert_store) .with_root_certificates(broker_root_cert_store)
.with_client_auth_cert(vec![server_cert, root_server_cert], server_prkey.into()) .with_client_auth_cert(vec![server_cert, root_server_cert], server_prkey.into())
.unwrap(); .unwrap());
let connector = TlsConnector::from(Arc::new(tlsconfig)); let connector = TlsConnector::from(Arc::clone(&tlsconfig));
let dnsname = ServerName::try_from("localhost").unwrap(); let dnsname = ServerName::try_from("localhost").unwrap();
let stream = TcpStream::connect("localhost:2541").await?; let stream = TcpStream::connect("localhost:2541").await?;
let stream = connector.connect(dnsname, stream).await?; let stream = connector.connect(dnsname, stream).await?;
let transport = Framed::new(stream, LengthDelimitedCodec::new()); let mut transport = Framed::new(stream, LengthDelimitedCodec::new());
transport.for_each(|item| async move { let msg = FromServerConnTypeMessage::SendCommand;
match item { transport.send(rmp_serde::to_vec(&msg).unwrap().into()).await.unwrap();
match transport.next().await {
None => {
panic!("None in the transport");
}
Some(item) => match item {
Ok(buf) => { Ok(buf) => {
use ToServerMessage::*; use libbonknet::servermsg::ToServerConnTypeReply;
let msg: ToServerMessage = rmp_serde::from_slice(&buf).unwrap(); use libbonknet::servermsg::ToServerConnTypeReply::*;
let msg: ToServerConnTypeReply = rmp_serde::from_slice(&buf).unwrap();
info!("{:?}", msg);
match msg { match msg {
Required { id } => { OkSendCommand => {
info!("I'm required with Connection ID {}", id); info!("Stream set in SendCommand mode");
} }
YouAre(name) => match name { GenericFailure => {
YouAreValues::Registered { name } => { panic!("Generic Failure during SendCommand");
info!("I am {}", name); }
} others => {
YouAreValues::NotRegistered => { panic!("Unexpected Message type: {:?}", others);
info!("I'm not registered");
}
} }
} }
} }
Err(e) => { Err(e) => {
error!("Error: {:?}", e); info!("Disconnection: {:?}", e);
} }
} }
}).await; }
// Begin WhoAmI
let msg = FromServerCommandMessage::WhoAmI;
transport.send(rmp_serde::to_vec(&msg).unwrap().into()).await.unwrap();
match transport.next().await {
None => {
panic!("None in the transport");
}
Some(item) => match item {
Ok(buf) => {
use libbonknet::servermsg::ToServerCommandReply;
use libbonknet::servermsg::ToServerCommandReply::*;
let msg: ToServerCommandReply = rmp_serde::from_slice(&buf).unwrap();
info!("{:?}", msg);
match msg {
YouAre { name } => {
info!("I am {}", name);
}
GenericFailure => {
panic!("Generic failure during WhoAmI");
}
_ => {
panic!("Unexpected reply");
}
}
}
Err(e) => {
info!("Disconnection: {:?}", e);
}
}
}
transport.close().await.expect("Error during transport stream close");
// Start Subscribe Stream
let connector = TlsConnector::from(Arc::clone(&tlsconfig));
let dnsname = ServerName::try_from("localhost").unwrap();
let stream = TcpStream::connect("localhost:2541").await?;
let stream = connector.connect(dnsname, stream).await?;
let mut transport = Framed::new(stream, LengthDelimitedCodec::new());
let msg = FromServerConnTypeMessage::Subscribe;
transport.send(rmp_serde::to_vec(&msg).unwrap().into()).await.unwrap();
match transport.next().await {
None => {
panic!("None in the transport");
}
Some(item) => match item {
Ok(buf) => {
use libbonknet::servermsg::ToServerConnTypeReply;
use libbonknet::servermsg::ToServerConnTypeReply::*;
let msg: ToServerConnTypeReply = rmp_serde::from_slice(&buf).unwrap();
info!("{:?}", msg);
match msg {
OkSubscribe => {
info!("Stream set in Subscribe mode");
}
GenericFailure => {
panic!("Generic Failure during SendCommand");
}
others => {
panic!("Unexpected Message type: {:?}", others);
}
}
}
Err(e) => {
info!("Disconnection: {:?}", e);
}
}
}
// Subscribe consume
loop {
match transport.next().await {
None => {
info!("Empty Buffer");
}
Some(item) => {
let mut out: Option<FromServerReply> = None;
match item {
Ok(buf) => {
use libbonknet::servermsg::ToServerMessage;
use libbonknet::servermsg::ToServerMessage::*;
let msg: ToServerMessage = rmp_serde::from_slice(&buf).unwrap();
match msg {
Msg { reply_id, body } => {
use libbonknet::servermsg::FromServerReplyBody;
use libbonknet::servermsg::ToServerMessageBody::*;
match body {
Request { conn_id } => {
info!("I'm required with Connection ID {}", conn_id);
out = Some(FromServerReply::Msg {
reply_id,
body: FromServerReplyBody::RequestAccepted,
});
// TODO: SPAWN DATASTREAM
tokio::spawn(datastream(tlsconfig.clone(), conn_id));
}
}
}
Ping => {
info!("Ping!");
out = Some(FromServerReply::Pong);
}
}
}
Err(e) => {
error!("Error: {:?}", e);
}
}
if let Some(msg) = out {
transport.send(rmp_serde::to_vec(&msg).unwrap().into()).await.unwrap();
}
}
}
}
} }
Ok(()) Ok(())
} }

View File

@@ -9,3 +9,4 @@ edition = "2021"
tokio-rustls = "0.25.0" tokio-rustls = "0.25.0"
rustls-pemfile = "2.0.0" rustls-pemfile = "2.0.0"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
uuid = { version = "1.7.0", features = ["serde"] }

View File

@@ -0,0 +1,23 @@
pub use crate::ToPeerDataStream;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
// Client things
/// Commands a client sends to the broker on its control connection.
#[derive(Debug, Serialize, Deserialize)]
pub enum FromClientCommand {
    // Ask the broker to set up a data connection towards the named server.
    RequestServer { name: String },
    // List the names of the registered servers.
    ServerList,
    // Turn this connection into the data stream identified by the Uuid.
    UpgradeToDataStream(Uuid),
}
/// Broker replies sent back on the client control connection.
#[derive(Debug, Serialize, Deserialize)]
pub enum ToClientResponse {
    // Request accepted; `conn_id` identifies the client's data stream.
    OkRequest { conn_id: Uuid },
    // Reply to `ServerList`.
    OkServerList { data: Vec<String> },
    // You are now a DataStream, wait the Open message
    OkDataStreamRequestAccepted,
    // The stream is open, you can pipe in-out the content you want!
    OkDataStreamOpen,
    // Something went wrong while executing the command.
    GenericError,
}

View File

@@ -1,11 +1,14 @@
pub mod servermsg;
pub mod clientmsg;
use std::io::{BufReader, Error, ErrorKind}; use std::io::{BufReader, Error, ErrorKind};
use rustls_pemfile::{read_one, Item}; use rustls_pemfile::{Item, read_one};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer}; use tokio_rustls::rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
pub fn load_cert(filename: &str) -> std::io::Result<CertificateDer> { pub fn load_cert(filename: &str) -> std::io::Result<CertificateDer> {
let cert_file = std::fs::File::open(filename).unwrap(); let cert_file = std::fs::File::open(filename).unwrap();
let mut buf = std::io::BufReader::new(cert_file); let mut buf = BufReader::new(cert_file);
if let Item::X509Certificate(cert) = read_one(&mut buf).unwrap().unwrap() { if let Item::X509Certificate(cert) = read_one(&mut buf).unwrap().unwrap() {
Ok(cert) Ok(cert)
} else { } else {
@@ -26,52 +29,12 @@ pub fn load_prkey(filename: &str) -> std::io::Result<PrivatePkcs8KeyDer> {
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub enum RequiredReplyValues { pub enum ToPeerDataStream {
Ok, // You are now a DataStream, wait the Open message
GenericFailure { status_code: u32, msg: Option<String> }, OkDataStreamRequestAccepted,
} // The stream is open, you can pipe in-out the content you want!
OkDataStreamOpen,
#[derive(Debug, Serialize, Deserialize)] Refused,
pub enum FromServerMessage { Revoked,
RequiredReply(RequiredReplyValues), GenericError,
ChangeName { name: String },
WhoAmI,
}
#[derive(Debug, Serialize, Deserialize)]
pub enum YouAreValues {
Registered { name: String },
NotRegistered,
}
#[derive(Debug, Serialize, Deserialize)]
pub enum ToServerMessage {
Required { id: String },
YouAre(YouAreValues),
}
#[derive(Debug, Serialize, Deserialize)]
pub enum FromGuestServerMessage {
Announce { name: String }
}
#[derive(Debug, Serialize, Deserialize)]
pub enum ToGuestServerMessage {
OkAnnounce {server_cert: Vec<u8>, server_prkey: Vec<u8>},
FailedNameAlreadyOccupied,
}
pub fn okannounce_to_cert<'a>(server_cert: Vec<u8>, server_prkey: Vec<u8>) -> (CertificateDer<'a>, PrivatePkcs8KeyDer<'a>) {
let server_cert = CertificateDer::from(server_cert);
let server_prkey = PrivatePkcs8KeyDer::from(server_prkey);
(server_cert, server_prkey)
}
impl ToGuestServerMessage {
pub fn make_okannounce(server_cert: CertificateDer, server_prkey: PrivatePkcs8KeyDer) -> Self {
ToGuestServerMessage::OkAnnounce{
server_cert: server_cert.to_vec(),
server_prkey: server_prkey.secret_pkcs8_der().to_vec()
}
}
} }

View File

@@ -0,0 +1,93 @@
pub use crate::ToPeerDataStream;
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// First message a server sends on a fresh connection, declaring what kind
/// of stream this connection will be.
#[derive(Debug, Serialize, Deserialize)]
pub enum FromServerConnTypeMessage {
    // Command channel: server sends requests, broker replies.
    SendCommand,
    // Subscription channel: broker pushes messages to the server.
    Subscribe,
    // Data stream for the pending connection identified by the Uuid.
    OpenDataStream(Uuid),
}
/// Broker's reply to a `FromServerConnTypeMessage`.
#[derive(Debug, Serialize, Deserialize)]
pub enum ToServerConnTypeReply {
    OkSendCommand,
    OkSubscribe,
    // You are now a DataStream, wait the Open message
    OkDataStreamRequestAccepted,
    // The stream is open, you can pipe in-out the content you want!
    OkDataStreamOpen,
    GenericFailure,
}
/// Commands a server issues on a `SendCommand` connection.
#[derive(Debug, Serialize, Deserialize)]
pub enum FromServerCommandMessage {
    ChangeName { name: String },
    WhoAmI,
}
/// Broker's reply to a `FromServerCommandMessage`.
#[derive(Debug, Serialize, Deserialize)]
pub enum ToServerCommandReply {
    NameChanged,
    NameNotAvailable,
    // Reply to WhoAmI with the server's registered name.
    YouAre { name: String },
    GenericFailure,
}
/// Payload of a broker-to-server request pushed on the Subscribe channel.
#[derive(Debug, Serialize, Deserialize)]
pub enum ToServerMessageBody {
    // Ask the server to open the data stream identified by `conn_id`.
    Request { conn_id: Uuid },
}
/// Frame pushed by the broker on the Subscribe channel: either a keep-alive
/// ping or a request tagged with a `reply_id` to correlate the answer.
#[derive(Debug, Serialize, Deserialize)]
pub enum ToServerMessage {
    Ping,
    Msg {
        reply_id: u64,
        body: ToServerMessageBody,
    },
}
/// Payload of a server's answer to a `ToServerMessageBody`.
#[derive(Debug, Serialize, Deserialize)]
pub enum FromServerReplyBody {
    RequestAccepted,
    RequestFailed,
    Pong,
}
/// Frame a server sends back on the Subscribe channel: a ping answer or a
/// reply correlated to a previous request via `reply_id`.
#[derive(Debug, Serialize, Deserialize)]
pub enum FromServerReply {
    Pong,
    Msg {
        reply_id: u64,
        body: FromServerReplyBody
    }
}
/// Message from a not-yet-registered ("guest") server.
#[derive(Debug, Serialize, Deserialize)]
pub enum FromGuestServerMessage {
    // Claim `name` and request credentials from the broker.
    Announce { name: String }
}
/// Rehydrate the raw DER bytes carried by an `OkAnnounce` into typed rustls
/// key material (certificate + PKCS#8 private key).
pub fn okannounce_to_cert<'a>(server_cert: Vec<u8>, server_prkey: Vec<u8>) -> (CertificateDer<'a>, PrivatePkcs8KeyDer<'a>) {
    (
        CertificateDer::from(server_cert),
        PrivatePkcs8KeyDer::from(server_prkey),
    )
}
/// Broker's answer to a guest server's `Announce`.
#[derive(Debug, Serialize, Deserialize)]
pub enum ToGuestServerMessage {
    // Name accepted: the issued certificate and private key as raw DER bytes.
    OkAnnounce { server_cert: Vec<u8>, server_prkey: Vec<u8> },
    FailedNameAlreadyOccupied,
}
impl ToGuestServerMessage {
pub fn make_okannounce(server_cert: CertificateDer, server_prkey: PrivatePkcs8KeyDer) -> Self {
ToGuestServerMessage::OkAnnounce {
server_cert: server_cert.to_vec(),
server_prkey: server_prkey.secret_pkcs8_der().to_vec()
}
}
}