Implement the skeleton for the ServerManager and the spawn of the connection_ids

This commit is contained in:
2024-02-19 14:22:11 +01:00
parent f8feb9db81
commit 37cc133d7f
11 changed files with 699 additions and 1618 deletions

1404
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -24,6 +24,7 @@ serde = "1"
rmp-serde = "1.1.2" rmp-serde = "1.1.2"
rcgen = { version = "0.12.1", features = ["x509-parser"] } rcgen = { version = "0.12.1", features = ["x509-parser"] }
rand = "0.8.5" rand = "0.8.5"
uuid = { version = "1.7.0", features = ["v4", "serde"] }
[[bin]] [[bin]]
name = "init_certs" name = "init_certs"

View File

@@ -1,11 +1,14 @@
mod servercertdb; mod servercertdb;
mod pendingdataconndb;
mod servermanager;
use servercertdb::*; use servercertdb::*;
use pendingdataconndb::*;
use servermanager::*;
use actix::prelude::*; use actix::prelude::*;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc}; use std::sync::{Arc};
use std::time::{Instant, Duration}; use std::time::{Instant, Duration};
use actix::fut::wrap_future;
use libbonknet::*; use libbonknet::*;
use rustls::{RootCertStore, ServerConfig}; use rustls::{RootCertStore, ServerConfig};
use rustls::server::WebPkiClientVerifier; use rustls::server::WebPkiClientVerifier;
@@ -16,12 +19,14 @@ use actix_service::{ServiceFactoryExt as _};
use futures::{StreamExt, SinkExt}; use futures::{StreamExt, SinkExt};
use rand::random; use rand::random;
use tokio_util::codec::{Framed, FramedRead, FramedWrite, LengthDelimitedCodec}; use tokio_util::codec::{Framed, FramedRead, FramedWrite, LengthDelimitedCodec};
use tracing::{info, error}; use tracing::{info, error, warn};
use rcgen::{Certificate, CertificateParams, DnType, KeyPair}; use rcgen::{Certificate, CertificateParams, DnType, KeyPair};
use thiserror::Error;
use tokio::io::{ReadHalf, WriteHalf}; use tokio::io::{ReadHalf, WriteHalf};
use tokio_util::bytes::{Bytes, BytesMut}; use tokio_util::bytes::{Bytes, BytesMut};
use tokio::io::Error; use tokio::io::Error;
use tokio::sync::{oneshot, Mutex}; use tokio::sync::{oneshot, Mutex};
use uuid::Uuid;
type TransportStream = Framed<TlsStream<TcpStream>, LengthDelimitedCodec>; type TransportStream = Framed<TlsStream<TcpStream>, LengthDelimitedCodec>;
type TransportStreamTx = FramedWrite<WriteHalf<TlsStream<TcpStream>>, LengthDelimitedCodec>; type TransportStreamTx = FramedWrite<WriteHalf<TlsStream<TcpStream>>, LengthDelimitedCodec>;
@@ -47,131 +52,6 @@ fn generate_server_cert(root_cert: &Certificate, name: &str) -> ServerCert {
} }
#[derive(MessageResponse)]
enum SendMsgResult {
Accepted,
Rejected,
}
#[derive(Message)]
#[rtype(result = "SendMsgResult")]
struct SendMsg {
msg: ToServerMessageBody,
reply_channel: oneshot::Sender<FromServerReplyBody>
}
#[derive(Message)]
#[rtype(result = "()")]
struct SendPing;
struct ServerTransporter {
rx: Option<TransportStreamRx>,
tx: Arc<Mutex<TransportStreamTx>>,
last_transmission: Instant,
reply_channels: HashMap<u64, oneshot::Sender<FromServerReplyBody>>,
}
impl ServerTransporter {
fn new(transport: TransportStream) -> Self {
let internal = transport.into_inner();
let (srx, stx) = tokio::io::split(internal);
let codec = LengthDelimitedCodec::new();
let rx = FramedRead::new(srx, codec.clone());
let tx = FramedWrite::new(stx, codec.clone());
ServerTransporter {
rx: Some(rx),
tx: Arc::new(Mutex::new(tx)),
last_transmission: Instant::now(),
reply_channels: HashMap::new(),
}
}
}
impl Actor for ServerTransporter {
type Context = Context<Self>;
fn started(&mut self, ctx: &mut Self::Context) {
let rx = self.rx.take().expect("Rx Stream not found");
ctx.add_stream(rx);
ctx.run_interval(Duration::from_secs(60), |_s, c| {
c.notify(SendPing);
});
}
}
impl Handler<SendPing> for ServerTransporter {
type Result = ();
fn handle(&mut self, _msg: SendPing, ctx: &mut Self::Context) -> Self::Result {
let msg = ToServerMessage::Ping;
let payload: Bytes = rmp_serde::to_vec(&msg).unwrap().into();
let arc_tx = self.tx.clone();
ctx.spawn(wrap_future::<_, Self>(async move {
arc_tx.lock().await.send(payload).await
}).map(|res, _a, _ctx| {
info!("Ping sent result: {:?}", res);
}));
}
}
impl Handler<SendMsg> for ServerTransporter {
type Result = SendMsgResult;
fn handle(&mut self, msg: SendMsg, ctx: &mut Self::Context) -> Self::Result {
let mut reply_id: u64;
if self.reply_channels.len() == u64::MAX as usize {
return SendMsgResult::Rejected;
}
loop {
reply_id = random();
if !self.reply_channels.contains_key(&reply_id) {
break;
}
}
self.reply_channels.insert(reply_id, msg.reply_channel);
let msg = ToServerMessage::Msg {
reply_id,
body: msg.msg,
};
let payload: Bytes = rmp_serde::to_vec(&msg).unwrap().into();
let arc_tx = self.tx.clone();
ctx.spawn(async move {
arc_tx.lock().await.send(payload).await
}.into_actor(self).map(|res, _a, _ctx| {
info!("ToServerMsg sent result: {:?}", res);
}));
SendMsgResult::Accepted
}
}
impl StreamHandler<Result<BytesMut, Error>> for ServerTransporter {
fn handle(&mut self, item: Result<BytesMut, Error>, _ctx: &mut Self::Context) {
match item {
Ok(buf) => {
use FromServerReply::*;
let msg: FromServerReply = rmp_serde::from_slice(&buf).unwrap();
match msg {
Pong => {
self.last_transmission = Instant::now();
}
Msg { reply_id, body } => match self.reply_channels.remove(&reply_id) {
None => {}
Some(reply_tx) => {
if let Err(_e) = reply_tx.send(body) {
error!("Oneshot channel {} got invalidated! No reply sent.", reply_id);
}
}
}
}
}
Err(e) => {
error!("ERROR {:?}", e);
}
}
}
}
#[actix_rt::main] #[actix_rt::main]
async fn main() { async fn main() {
// Tracing Subscriber // Tracing Subscriber
@@ -211,9 +91,9 @@ async fn main() {
server_root_prkey server_root_prkey
).unwrap()).unwrap()); ).unwrap()).unwrap());
let server_db_addr = ServerCertDB { let scdb_addr = ServerCertDB::new().start();
db: HashMap::new(), let pdcm_addr = PendingDataConnManager::new().start();
}.start(); let sm_addr = ServerManager::new(pdcm_addr.clone()).start();
Server::build() Server::build()
.bind("server-command", ("localhost", 2541), move || { .bind("server-command", ("localhost", 2541), move || {
@@ -221,7 +101,8 @@ async fn main() {
let client_root_cert_der = Arc::clone(&client_root_cert_der); let client_root_cert_der = Arc::clone(&client_root_cert_der);
let guestserver_root_cert_der = Arc::clone(&guestserver_root_cert_der); let guestserver_root_cert_der = Arc::clone(&guestserver_root_cert_der);
let server_root_cert = Arc::clone(&server_root_cert); let server_root_cert = Arc::clone(&server_root_cert);
let server_db_addr = server_db_addr.clone(); let scdb_addr = scdb_addr.clone();
let sm_addr = sm_addr.clone();
// Set up TLS service factory // Set up TLS service factory
server_acceptor server_acceptor
@@ -232,11 +113,14 @@ async fn main() {
let client_root_cert_der = Arc::clone(&client_root_cert_der); let client_root_cert_der = Arc::clone(&client_root_cert_der);
let guestserver_root_cert_der = Arc::clone(&guestserver_root_cert_der); let guestserver_root_cert_der = Arc::clone(&guestserver_root_cert_der);
let server_root_cert = Arc::clone(&server_root_cert); let server_root_cert = Arc::clone(&server_root_cert);
let server_db_addr = server_db_addr.clone(); let scdb_addr = scdb_addr.clone();
let sm_addr = sm_addr.clone();
async move { async move {
let peer_certs = stream.get_ref().1.peer_certificates().unwrap(); let peer_certs = stream.get_ref().1.peer_certificates().unwrap();
let peer_cert_bytes = peer_certs.first().unwrap().to_vec(); let peer_cert_bytes = peer_certs.first().unwrap().to_vec();
let peer_root_cert_der = peer_certs.last().unwrap().clone(); let peer_root_cert_der = peer_certs.last().unwrap().clone();
let scdb_addr = scdb_addr.clone();
let sm_addr = sm_addr.clone();
if peer_root_cert_der == *server_root_cert_der { if peer_root_cert_der == *server_root_cert_der {
info!("Server connection"); info!("Server connection");
let mut transport = Framed::new(stream, LengthDelimitedCodec::new()); let mut transport = Framed::new(stream, LengthDelimitedCodec::new());
@@ -254,15 +138,22 @@ async fn main() {
info!("SendCommand Stream"); info!("SendCommand Stream");
let reply = ToServerConnTypeReply::OkSendCommand; let reply = ToServerConnTypeReply::OkSendCommand;
transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap(); transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap();
server_command_handler(transport, peer_cert_bytes, &server_db_addr).await; server_command_handler(transport, peer_cert_bytes, scdb_addr).await;
} }
Subscribe => { Subscribe => {
info!("Subscribe Stream"); info!("Subscribe Stream");
let name = match scdb_addr.send(FetchName { cert: peer_cert_bytes }).await.unwrap() {
None => {
error!("Cert has no name assigned!");
let reply = ToServerConnTypeReply::GenericFailure;
transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap();
return Ok(());
}
Some(name) => name,
};
let reply = ToServerConnTypeReply::OkSubscribe; let reply = ToServerConnTypeReply::OkSubscribe;
transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap(); transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap();
// TODO: If I pass transport away and the service returns, what happen to the connection? server_subscribe_handler(transport, name, sm_addr).await;
// in theory it will remain open but better check
server_subscribe_handler(transport).await;
} }
} }
} }
@@ -277,9 +168,13 @@ async fn main() {
let server_root_cert = Arc::clone(&server_root_cert); let server_root_cert = Arc::clone(&server_root_cert);
let codec = LengthDelimitedCodec::new(); let codec = LengthDelimitedCodec::new();
let transport = Framed::new(stream, codec); let transport = Framed::new(stream, codec);
guestserver_handler(transport, &server_db_addr, &server_root_cert).await; guestserver_handler(transport, scdb_addr, &server_root_cert).await;
} else if peer_root_cert_der == *client_root_cert_der { } else if peer_root_cert_der == *client_root_cert_der {
info!("Client connection"); info!("Client connection");
//pdcm_addr
let codec = LengthDelimitedCodec::new();
let transport = Framed::new(stream, codec);
client_handler(transport, sm_addr).await;
} else { } else {
error!("Unknown Root Certificate"); error!("Unknown Root Certificate");
} }
@@ -293,22 +188,18 @@ async fn main() {
.unwrap(); .unwrap();
} }
async fn server_subscribe_handler(transport: TransportStream) { async fn server_subscribe_handler(transport: TransportStream, name: String, sm_addr: Addr<ServerManager>) {
let h = ServerTransporter::new(transport).start(); match sm_addr.send(StartTransporter { name, transport }).await.unwrap() {
info!("Actor spawned"); Ok(_) => {
tokio::time::sleep(Duration::from_secs(5)).await; info!("Stream sent to the manager");
info!("5 seconds elapsed, sending msg"); }
let (tx, rx) = oneshot::channel(); Err(e) => {
h.send(SendMsg { error!("Error from manager: {:?}", e);
msg: ToServerMessageBody::Required { id: "session_id".to_string() }, }
reply_channel: tx,
}).await.unwrap();
if let Ok(item) = rx.await {
info!("Response: {:?}", item);
} }
} }
async fn server_command_handler(mut transport: TransportStream, peer_cert_bytes: Vec<u8>, server_db_addr: &Addr<ServerCertDB>) { async fn server_command_handler(mut transport: TransportStream, peer_cert_bytes: Vec<u8>, server_db_addr: Addr<ServerCertDB>) {
loop { loop {
match transport.next().await { match transport.next().await {
None => { None => {
@@ -323,7 +214,25 @@ async fn server_command_handler(mut transport: TransportStream, peer_cert_bytes:
match msg { match msg {
ChangeName { name } => { ChangeName { name } => {
info!("Changing name to {}", name); info!("Changing name to {}", name);
// TODO match server_db_addr.send(UnregisterServer { cert: peer_cert_bytes.clone() }).await.unwrap() {
None => {
info!("Nothing to unregister");
}
Some(old_name) => {
warn!("Unregistered from old name {}", old_name);
}
}
let reply = match server_db_addr.send(RegisterServer { cert: peer_cert_bytes.clone(), name }).await.unwrap() {
Ok(_) => {
info!("Registered!");
ToServerCommandReply::NameChanged
}
Err(e) => {
error!("Error registering: {:?}", e);
ToServerCommandReply::GenericFailure
}
};
transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap();
} }
WhoAmI => { WhoAmI => {
info!("Asked who I am"); info!("Asked who I am");
@@ -351,7 +260,7 @@ async fn server_command_handler(mut transport: TransportStream, peer_cert_bytes:
} }
// TODO: Considera creare un context dove vengono contenute tutte le chiavi e gli address da dare a tutti gli handler // TODO: Considera creare un context dove vengono contenute tutte le chiavi e gli address da dare a tutti gli handler
async fn guestserver_handler(mut transport: TransportStream, server_db_addr: &Addr<ServerCertDB>, server_root_cert: &Arc<Certificate>) { async fn guestserver_handler(mut transport: TransportStream, server_db_addr: Addr<ServerCertDB>, server_root_cert: &Arc<Certificate>) {
loop { loop {
match transport.next().await { match transport.next().await {
None => { None => {
@@ -396,3 +305,48 @@ async fn guestserver_handler(mut transport: TransportStream, server_db_addr: &Ad
} }
} }
} }
async fn client_handler(mut transport: TransportStream, sm_addr: Addr<ServerManager>) {
loop {
match transport.next().await {
None => {
info!("Transport returned None");
break;
}
Some(item) => {
match item {
Ok(buf) => {
let msg: FromClientCommand = rmp_serde::from_slice(&buf).unwrap();
info!("{:?}", msg);
match msg {
FromClientCommand::RequestServer { name } => {
info!("REQUESTED SERVER {}", name);
let data = sm_addr.send(RequestServer { name }).await.unwrap();
match data {
Ok(client_conn_id) => {
let reply = ToClientResponse::OkRequest { conn_id: client_conn_id };
transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap();
info!("Sent OkRequest");
}
Err(e) => {
error!("Error! {:?}", e);
}
}
}
FromClientCommand::ServerList => {
info!("Requested ServerList");
let data = sm_addr.send(GetServerList {}).await.unwrap();
let reply = ToClientResponse::OkServerList { data };
transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap();
}
}
}
Err(e) => {
info!("Disconnection: {:?}", e);
break;
}
}
}
}
}
}

View File

@@ -0,0 +1,149 @@
use actix::prelude::*;
use actix_tls::accept::rustls_0_22::TlsStream;
use thiserror::Error;
use tokio::io::{ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use tokio_util::codec::{Framed, FramedRead, FramedWrite, LengthDelimitedCodec};
use tracing::{error, info};
use uuid::Uuid;
type TransportStream = Framed<TlsStream<TcpStream>, LengthDelimitedCodec>;
/*
L'idea e' che il database deve avere una riga per ogni connessione dati in nascita.
In ogni "riga" deve essere presente:
- Il ServerConnID che il server dovra usare (pkey)
- Il ClientConnID che il client dovra usare (pkey)
- Se gia connesso, il Socket del Server
- Se gia connesso il Socket del Client
Quando in una riga sono presenti sia ServerSocket che ClientSocket allora poppa la riga
e usa i socket per lanciare un thread/actor che faccia il piping dei dati
Ricerca riga deve avvenire sia tramite ServerConnID che ClientConnID se essi sono diversi come pianifico
Quindi l'ideale e' non usare una collection ma andare direttamente di Vector!
*/
#[derive(Error, Debug)]
pub enum PendingDataConnError {
#[error("Generic Failure")]
GenericFailure,
}
#[derive(Message)]
#[rtype(result = "Result<(),PendingDataConnError>")]
pub struct NewPendingConn {
pub server_conn_id: Uuid,
pub client_conn_id: Uuid,
}
#[derive(Debug)]
enum RegisterKind {
Server,
Client
}
#[derive(Message)]
#[rtype(result = "Result<(),PendingDataConnError>")]
pub struct RegisterStream {
kind: RegisterKind,
conn_id: Uuid,
transport: TransportStream,
}
impl RegisterStream {
pub fn server(conn_id: Uuid, transport: TransportStream) -> Self {
RegisterStream { kind: RegisterKind::Server, conn_id, transport }
}
pub fn client(conn_id: Uuid, transport: TransportStream) -> Self {
RegisterStream { kind: RegisterKind::Client, conn_id, transport }
}
}
struct SideRecord {
conn_id: Uuid,
transport: Option<TransportStream>,
}
struct Record {
server: SideRecord,
client: SideRecord,
}
impl Record {
fn new(server_conn_id: Uuid, client_conn_id: Uuid) -> Self {
let server = SideRecord { conn_id: client_conn_id, transport: None };
let client = SideRecord { conn_id: server_conn_id, transport: None };
Record { server, client }
}
}
// TODO: every 2 minutes verify the Records that have at least one stream invalidated and drop them
pub struct PendingDataConnManager {
db: Vec<Record>,
}
impl PendingDataConnManager {
pub fn new() -> Self {
PendingDataConnManager { db: Vec::new() }
}
}
impl Actor for PendingDataConnManager {
type Context = Context<Self>;
}
impl Handler<NewPendingConn> for PendingDataConnManager {
type Result = Result<(), PendingDataConnError>;
fn handle(&mut self, msg: NewPendingConn, _ctx: &mut Self::Context) -> Self::Result {
let server_conn_id = msg.server_conn_id;
let client_conn_id = msg.client_conn_id;
let new_record = Record::new(server_conn_id, client_conn_id);
if self.db.iter().any(|x| {
x.server.conn_id == new_record.server.conn_id || x.client.conn_id == new_record.client.conn_id
}) {
error!("Conflicting UUIDs: server {} - client {}", server_conn_id, client_conn_id);
Err(PendingDataConnError::GenericFailure)
} else {
self.db.push(new_record);
Ok(())
}
}
}
impl Handler<RegisterStream> for PendingDataConnManager {
type Result = Result<(), PendingDataConnError>;
fn handle(&mut self, msg: RegisterStream, _ctx: &mut Self::Context) -> Self::Result {
use RegisterKind::*;
let conn_id = msg.conn_id;
let record = match match msg.kind {
Server => self.db.iter_mut().find(|x| x.server.conn_id == conn_id),
Client => self.db.iter_mut().find(|x| x.client.conn_id == conn_id),
} {
None => {
error!("Found no connection with {:?} conn_id {:?}", msg.kind, conn_id);
return Err(PendingDataConnError::GenericFailure);
},
Some(item) => item,
};
let side_record = match msg.kind {
Server => &mut record.server,
Client => &mut record.client,
};
if side_record.transport.is_some() {
// TODO: It can be good to check if the connection is still open, if not, drop and use the new one.
error!("Connection already with a socket!");
return Err(PendingDataConnError::GenericFailure);
} else {
side_record.transport = Some(msg.transport);
}
if record.client.transport.is_some() && record.server.transport.is_some() {
// TODO: Launch the "thing" that will manage the data mirroring
info!("LAUNCHING DATA MIRRORING");
}
Ok(())
}
}

View File

@@ -1,5 +1,5 @@
use std::collections::HashMap; use std::collections::HashMap;
use actix::{Actor, Context, Handler, Message}; use actix::prelude::*;
use thiserror::Error; use thiserror::Error;
// TODO: Probably it's better to remove the pub from inside the structs and impl a new() funct // TODO: Probably it's better to remove the pub from inside the structs and impl a new() funct
@@ -25,6 +25,12 @@ pub struct RegisterServer {
pub name: String, pub name: String,
} }
#[derive(Message)]
#[rtype(result = "Option<String>")] // None if nothing to unregister, Some if unregistered
pub struct UnregisterServer {
pub cert: Vec<u8>,
}
#[derive(Message)] #[derive(Message)]
#[rtype(result = "Option<String>")] #[rtype(result = "Option<String>")]
pub struct FetchName { pub struct FetchName {
@@ -33,7 +39,13 @@ pub struct FetchName {
// TODO: Move into Sqlite DB with unique check on col1 and col2!!!! Right now name is not unique // TODO: Move into Sqlite DB with unique check on col1 and col2!!!! Right now name is not unique
pub struct ServerCertDB { pub struct ServerCertDB {
pub db: HashMap<Vec<u8>, String>, // Cert to Name db: HashMap<Vec<u8>, String>, // Cert to Name
}
impl ServerCertDB {
pub fn new() -> Self {
ServerCertDB { db: HashMap::new() }
}
} }
impl Actor for ServerCertDB { impl Actor for ServerCertDB {
@@ -70,4 +82,12 @@ impl Handler<FetchName> for ServerCertDB {
fn handle(&mut self, msg: FetchName, _ctx: &mut Self::Context) -> Self::Result { fn handle(&mut self, msg: FetchName, _ctx: &mut Self::Context) -> Self::Result {
self.db.get(&msg.cert).map(|s| s.to_owned()) self.db.get(&msg.cert).map(|s| s.to_owned())
} }
}
impl Handler<UnregisterServer> for ServerCertDB {
type Result = Option<String>;
fn handle(&mut self, msg: UnregisterServer, _ctx: &mut Self::Context) -> Self::Result {
self.db.remove(&msg.cert)
}
} }

View File

@@ -0,0 +1,334 @@
use std::collections::HashMap;
use std::io::Error;
use std::sync::{Arc};
use std::time::{Duration, Instant};
use actix::prelude::*;
use actix_server::Server;
use rand::random;
use thiserror::Error;
use futures::{StreamExt, SinkExt};
use tokio::sync::{Mutex, oneshot};
use tokio_util::bytes::{Bytes, BytesMut};
use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
use tracing::{debug, error, info};
use libbonknet::*;
use crate::{TransportStream, TransportStreamRx, TransportStreamTx};
use uuid::Uuid;
use crate::pendingdataconndb::*;
#[derive(MessageResponse)]
enum SendMsgResult {
Accepted,
Rejected,
}
#[derive(Message)]
#[rtype(result = "SendMsgResult")]
struct SendMsg {
msg: ToServerMessageBody,
reply_channel: oneshot::Sender<FromServerReplyBody>
}
#[derive(Error, Debug)]
pub enum AsyncSendMsgError {
#[error("Generic Failure")]
GenericFailure,
}
#[derive(Message)]
#[rtype(result = "Result<FromServerReplyBody, AsyncSendMsgError>")]
struct AsyncSendMsg {
msg: ToServerMessageBody,
}
struct ServerTransporter {
rx: Option<TransportStreamRx>,
tx: Arc<tokio::sync::Mutex<TransportStreamTx>>,
timeout: Option<SpawnHandle>,
reply_channels: HashMap<u64, oneshot::Sender<FromServerReplyBody>>,
}
impl Handler<AsyncSendMsg> for ServerTransporter {
type Result = ResponseFuture<Result<FromServerReplyBody, AsyncSendMsgError>>;
fn handle(&mut self, msg: AsyncSendMsg, _ctx: &mut Self::Context) -> Self::Result {
let (reply_channel_tx, reply_channel_rx) = oneshot::channel();
let mut reply_id: u64;
if self.reply_channels.len() == u64::MAX as usize {
return Box::pin(fut::ready(Err(AsyncSendMsgError::GenericFailure)));
}
loop {
reply_id = random();
if !self.reply_channels.contains_key(&reply_id) {
break;
}
}
self.reply_channels.insert(reply_id, reply_channel_tx);
let msg = ToServerMessage::Msg {
reply_id,
body: msg.msg,
};
let payload: Bytes = rmp_serde::to_vec(&msg).unwrap().into();
let arc_tx = self.tx.clone();
Box::pin(async move {
arc_tx.lock().await.send(payload).await.map_err(|e| AsyncSendMsgError::GenericFailure)?;
info!("msg sent");
let r = reply_channel_rx.await.unwrap();
info!("reply received");
Ok(r)
})
}
}
impl ServerTransporter {
fn new(transport: TransportStream) -> Self {
let internal = transport.into_inner();
let (srx, stx) = tokio::io::split(internal);
let codec = LengthDelimitedCodec::new();
let rx = FramedRead::new(srx, codec.clone());
let tx = FramedWrite::new(stx, codec.clone());
ServerTransporter {
rx: Some(rx),
tx: Arc::new(tokio::sync::Mutex::new(tx)),
timeout: None,
reply_channels: HashMap::new(),
}
}
fn actor_close(&mut self, ctx: &mut Context<Self>) {
ctx.stop();
}
}
impl Actor for ServerTransporter {
type Context = Context<Self>;
fn started(&mut self, ctx: &mut Self::Context) {
// Register Read Stream
let rx = self.rx.take().expect("Rx Stream not found");
ctx.add_stream(rx);
// Register Timeout task
self.timeout = Some(ctx.run_interval(Duration::from_secs(120), |s, c| {
s.actor_close(c)
}));
// Register Send Ping Task
ctx.run_interval(Duration::from_secs(60), |s, c| {
let msg = ToServerMessage::Ping;
let payload: Bytes = rmp_serde::to_vec(&msg).unwrap().into();
let arc_tx = Arc::clone(&s.tx);
c.spawn(async move {
arc_tx.lock().await.send(payload).await
}.into_actor(s).map(|res, a, ctx| {
match res {
Ok(_) => {
info!("Ping sent!");
}
Err(_) => {
ctx.stop();
}
}
}));
});
}
}
impl Handler<SendMsg> for ServerTransporter {
type Result = SendMsgResult;
fn handle(&mut self, msg: SendMsg, ctx: &mut Self::Context) -> Self::Result {
let mut reply_id: u64;
if self.reply_channels.len() == u64::MAX as usize {
return SendMsgResult::Rejected;
}
loop {
reply_id = random();
if !self.reply_channels.contains_key(&reply_id) {
break;
}
}
self.reply_channels.insert(reply_id, msg.reply_channel);
let msg = ToServerMessage::Msg {
reply_id,
body: msg.msg,
};
let payload: Bytes = rmp_serde::to_vec(&msg).unwrap().into();
let arc_tx = self.tx.clone();
ctx.spawn(async move {
arc_tx.lock().await.send(payload).await
}.into_actor(self).map(|res, _a, _ctx| {
info!("ToServerMsg sent result: {:?}", res);
}));
SendMsgResult::Accepted
}
}
impl StreamHandler<Result<BytesMut, Error>> for ServerTransporter {
fn handle(&mut self, item: Result<BytesMut, Error>, ctx: &mut Self::Context) {
match item {
Ok(buf) => {
use FromServerReply::*;
let msg: FromServerReply = rmp_serde::from_slice(&buf).unwrap();
match msg {
Pong => {
info!("Pong received");
if let Some(spawn_handle) = self.timeout {
ctx.cancel_future(spawn_handle);
} else {
error!("There were no spawn handle configured!");
}
self.timeout = Some(ctx.run_interval(Duration::from_secs(120), |s, c| {
s.actor_close(c)
}));
}
Msg { reply_id, body } => match self.reply_channels.remove(&reply_id) {
None => {}
Some(reply_tx) => {
if let Err(_e) = reply_tx.send(body) {
error!("Oneshot channel {} got invalidated! No reply sent.", reply_id);
}
}
}
}
}
Err(e) => {
error!("ERROR {:?}", e);
}
}
}
}
#[derive(Error, Debug)]
pub enum ServerManagerError {
#[error("Generic Failure")]
GenericFailure,
}
#[derive(Message)]
#[rtype(result = "Result<(),ServerManagerError>")]
pub struct StartTransporter {
pub name: String,
pub transport: TransportStream,
}
#[derive(Message)]
#[rtype(result = "Vec<String>")]
pub struct GetServerList {}
#[derive(Message)]
#[rtype(result = "Result<Uuid,ServerManagerError>")] // TODO: Return Client ID with struct to give it a name
pub struct RequestServer {
pub name: String
}
pub struct ServerManager {
entries: Arc<Mutex<HashMap<String, Addr<ServerTransporter>>>>,
// Name -> Addr to Actor
pdcdb_addr: Addr<PendingDataConnManager>,
}
impl ServerManager {
pub fn new(pdcdb_addr: Addr<PendingDataConnManager>) -> Self {
ServerManager { entries: Arc::new(Mutex::new(HashMap::new())), pdcdb_addr }
}
}
impl Actor for ServerManager {
type Context = Context<Self>;
fn started(&mut self, ctx: &mut Self::Context) {
// Remove all the ServerTransporters that are not running
// TODO: This is a "Pull" from entries, but we can have the entries that when killed tell
// the Manager making it more reactive!
ctx.run_interval(Duration::from_secs(10), |s, c| {
let start = Instant::now();
let entries = Arc::clone(&s.entries);
c.spawn(async move {
let mut entries_mg = entries.lock().await;
let mut keys_to_delete = vec![];
for (name, st_addr) in entries_mg.iter() {
if !st_addr.connected() {
keys_to_delete.push(name.clone())
}
}
for name in keys_to_delete {
entries_mg.remove(&name);
info!("Closed ServerTransporter {} for actor death", name);
}
debug!("Cleaned ServerManager in {:?}", Instant::now() - start);
}.into_actor(s));
});
}
}
impl Handler<StartTransporter> for ServerManager {
type Result = ResponseFuture<Result<(), ServerManagerError>>;
fn handle(&mut self, msg: StartTransporter, _ctx: &mut Self::Context) -> Self::Result {
let mut entries = Arc::clone(&self.entries);
Box::pin(async move {
let mut entries_mg = entries.lock().await;
if entries_mg.contains_key(&msg.name) {
error!("A server called {} is already connected!", msg.name);
return Err(ServerManagerError::GenericFailure);
}
let st_addr = ServerTransporter::new(msg.transport).start();
entries_mg.insert(msg.name, st_addr);
Ok(())
})
}
}
impl Handler<GetServerList> for ServerManager {
type Result = ResponseFuture<Vec<String>>;
fn handle(&mut self, _msg: GetServerList, _ctx: &mut Self::Context) -> Self::Result {
let entries = Arc::clone(&self.entries);
Box::pin(async move {
let entries_mg = entries.lock().await;
entries_mg.keys().cloned().collect()
})
}
}
impl Handler<RequestServer> for ServerManager {
type Result = ResponseFuture<Result<Uuid, ServerManagerError>>;
fn handle(&mut self, msg: RequestServer, _ctx: &mut Self::Context) -> Self::Result {
let name = msg.name.clone();
let entries = Arc::clone(&self.entries);
let pdcdb_addr = self.pdcdb_addr.clone();
Box::pin(async move {
let lock = entries.lock().await;
let sh_addr = match lock.get(&name) {
None => {
error!("Requested server {} that isn't registered", name);
return Err(ServerManagerError::GenericFailure);
}
Some(item) => item,
};
let server_conn_id = Uuid::new_v4();
let client_conn_id = Uuid::new_v4();
match pdcdb_addr.send(NewPendingConn { server_conn_id, client_conn_id }).await.unwrap() {
Ok(_) => {
let msg = ToServerMessageBody::Request { conn_id: server_conn_id };
match sh_addr.send(AsyncSendMsg { msg }).await.unwrap() {
Ok(reply) => match reply {
FromServerReplyBody::RequestAccepted => {
Ok(client_conn_id)
}
FromServerReplyBody::RequestFailed => {
error!("Request Failed!");
Err(ServerManagerError::GenericFailure)
}
FromServerReplyBody::Pong => unreachable!(),
}
Err(e) => {
panic!("ERROR: {:?}", e);
}
}
}
Err(_e) => Err(ServerManagerError::GenericFailure),
}
})
}
}

View File

@@ -7,10 +7,12 @@ edition = "2021"
[dependencies] [dependencies]
libbonknet = { path = "../libbonknet" } libbonknet = { path = "../libbonknet" }
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full", "tracing"] }
tokio-rustls = "0.25.0"
tokio-util = { version = "0.7.10", features = ["codec"] } tokio-util = { version = "0.7.10", features = ["codec"] }
futures = "0.3" futures = "0.3"
rcgen = "0.12.0"
tokio-rustls = "0.25.0"
rustls-pemfile = "2.0.0" rustls-pemfile = "2.0.0"
serde = { version = "1.0", features = ["derive"] }
rmp-serde = "1.1.2" rmp-serde = "1.1.2"
tracing = "0.1"
tracing-subscriber = "0.3"

View File

@@ -1,38 +1,35 @@
use std::io::Error;
use std::sync::Arc; use std::sync::Arc;
use futures::SinkExt; use std::time::Duration;
use futures::{StreamExt, SinkExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio_rustls::rustls::{ClientConfig, RootCertStore}; use tokio_rustls::rustls::{ClientConfig, RootCertStore};
use tokio_rustls::rustls::pki_types::{ServerName}; use tokio_rustls::rustls::pki_types::{ServerName, CertificateDer, PrivatePkcs8KeyDer};
use tokio_rustls::TlsConnector; use tokio_rustls::TlsConnector;
use tokio_util::bytes::BytesMut;
use tokio_util::codec::{Framed, LengthDelimitedCodec}; use tokio_util::codec::{Framed, LengthDelimitedCodec};
use serde::{Serialize, Deserialize}; use libbonknet::*;
use libbonknet::{load_cert, load_prkey}; use tracing::{info, error};
#[derive(Debug, Serialize, Deserialize)]
enum ClientMessage {
Response { status_code: u32, msg: Option<String> },
Announce { name: String },
Required { id: String },
NotRequired { id: String },
}
// TODO: This is an old examples
#[tokio::main] #[tokio::main]
async fn main() -> std::io::Result<()> { async fn main() -> std::io::Result<()> {
let client_name = "Polnareffland1"; // Tracing Subscriber
let subscriber = tracing_subscriber::FmtSubscriber::new();
tracing::subscriber::set_global_default(subscriber).unwrap();
// Root certs to verify the server is the right one // Root certs to verify the server is the right one
let mut server_root_cert_store = RootCertStore::empty(); let mut broker_root_cert_store = RootCertStore::empty();
let server_root_cert_der = load_cert("server_root_cert.pem").unwrap(); let broker_root_cert_der = load_cert("certs/broker_root_cert.pem").unwrap();
server_root_cert_store.add(server_root_cert_der).unwrap(); broker_root_cert_store.add(broker_root_cert_der).unwrap();
// Auth Cert to send the server who am I // Public CA for Clients
let root_client_cert = load_cert("client_root_cert.pem").unwrap(); let root_client_cert = load_cert("certs/client_root_cert.pem").unwrap();
let client_cert = load_cert("client_cert.pem").unwrap(); // My Client Certificate for authentication
let client_prkey = load_prkey("client_key.pem").unwrap(); let client_cert = load_cert("certs/client_cert.pem").unwrap();
let client_cert_prkey = load_prkey("certs/client_key.pem").unwrap();
// Load TLS Config // Load TLS Config
let tlsconfig = ClientConfig::builder() let tlsconfig = ClientConfig::builder()
.with_root_certificates(server_root_cert_store) .with_root_certificates(broker_root_cert_store.clone())
// .with_no_client_auth(); .with_client_auth_cert(vec![client_cert, root_client_cert], client_cert_prkey.into())
.with_client_auth_cert(vec![client_cert, root_client_cert], client_prkey.into())
.unwrap(); .unwrap();
let connector = TlsConnector::from(Arc::new(tlsconfig)); let connector = TlsConnector::from(Arc::new(tlsconfig));
let dnsname = ServerName::try_from("localhost").unwrap(); let dnsname = ServerName::try_from("localhost").unwrap();
@@ -41,34 +38,46 @@ async fn main() -> std::io::Result<()> {
let stream = connector.connect(dnsname, stream).await?; let stream = connector.connect(dnsname, stream).await?;
let mut transport = Framed::new(stream, LengthDelimitedCodec::new()); let mut transport = Framed::new(stream, LengthDelimitedCodec::new());
let msg = FromClientCommand::ServerList;
let msg1 = ClientMessage::Announce { name: client_name.into() }; transport.send(rmp_serde::to_vec(&msg).unwrap().into()).await.unwrap();
transport.send(rmp_serde::to_vec(&msg1).unwrap().into()).await.unwrap(); match transport.next().await {
for i in 0..10 { None => panic!("None in the transport"),
let msg = ClientMessage::Response { status_code: 100+i, msg: Some(format!("yay {}", i)) }; Some(item) => match item {
transport.send(rmp_serde::to_vec(&msg).unwrap().into()).await.unwrap(); Ok(buf) => {
tokio::time::sleep(std::time::Duration::from_secs(1)).await; use ToClientResponse::*;
let msg: ToClientResponse = rmp_serde::from_slice(&buf).unwrap();
match msg {
OkRequest { .. } => error!("Wrong reply!"),
OkServerList { data } => info!("{}", data.join("\n")),
GenericError => error!("Generic Error during remote command execution"),
}
}
Err(e) => {
error!("Error: {:?}", e);
}
}
}
tokio::time::sleep(Duration::from_secs(5)).await;
let msg = FromClientCommand::RequestServer { name: "cicciopizza".into() };
transport.send(rmp_serde::to_vec(&msg).unwrap().into()).await.unwrap();
match transport.next().await {
None => panic!("None in the transport"),
Some(item) => match item {
Ok(buf) => {
use ToClientResponse::*;
let msg: ToClientResponse = rmp_serde::from_slice(&buf).unwrap();
match msg {
OkRequest { conn_id } => {
info!("Received Client Connection ID: {:?}", conn_id);
}
OkServerList { .. } => error!("Wrong reply!"),
GenericError => error!("Generic Error during remote command execution"),
}
}
Err(e) => {
error!("Error: {:?}", e);
}
}
} }
// transport.for_each(|item| async move {
// let a: ClientMessage = rmp_serde::from_slice(&item.unwrap()).unwrap();
// println!("{:?}", a);
// }).await;
// let mut buf = vec![0;1024];
// let (mut rd,mut tx) = split(stream);
//
//
// tokio::spawn(async move {
// let mut stdout = tokio::io::stdout();
// tokio::io::copy(&mut rd, &mut stdout).await.unwrap();
// });
//
// let mut reader = tokio::io::BufReader::new(tokio::io::stdin()).lines();
//
// while let Some(line) = reader.next_line().await.unwrap() {
// tx.write_all(line.as_bytes()).await.unwrap();
// }
Ok(()) Ok(())
} }

View File

@@ -190,11 +190,11 @@ async fn main() -> std::io::Result<()> {
Msg { reply_id, body } => { Msg { reply_id, body } => {
use ToServerMessageBody::*; use ToServerMessageBody::*;
match body { match body {
Required { id } => { Request { conn_id } => {
info!("I'm required with Connection ID {}", id); info!("I'm required with Connection ID {}", conn_id);
out = Some(FromServerReply::Msg { out = Some(FromServerReply::Msg {
reply_id, reply_id,
body: FromServerReplyBody::RequiredAccepted, body: FromServerReplyBody::RequestAccepted,
}) })
} }
} }

View File

@@ -9,3 +9,4 @@ edition = "2021"
tokio-rustls = "0.25.0" tokio-rustls = "0.25.0"
rustls-pemfile = "2.0.0" rustls-pemfile = "2.0.0"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
uuid = { version = "1.7.0", features = ["serde"] }

View File

@@ -2,6 +2,7 @@ use std::io::{BufReader, Error, ErrorKind};
use rustls_pemfile::{read_one, Item}; use rustls_pemfile::{read_one, Item};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer}; use tokio_rustls::rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
use uuid::Uuid;
pub fn load_cert(filename: &str) -> std::io::Result<CertificateDer> { pub fn load_cert(filename: &str) -> std::io::Result<CertificateDer> {
let cert_file = std::fs::File::open(filename).unwrap(); let cert_file = std::fs::File::open(filename).unwrap();
@@ -60,7 +61,7 @@ pub enum YouAreValues {
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub enum ToServerMessageBody { pub enum ToServerMessageBody {
Required { id: String }, Request { conn_id: Uuid },
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
@@ -74,8 +75,8 @@ pub enum ToServerMessage {
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub enum FromServerReplyBody { pub enum FromServerReplyBody {
RequiredAccepted, RequestAccepted,
RequiredFailed, RequestFailed,
Pong, Pong,
} }
@@ -113,3 +114,17 @@ impl ToGuestServerMessage {
} }
} }
} }
// Client things
#[derive(Debug, Serialize, Deserialize)]
pub enum FromClientCommand {
RequestServer { name: String },
ServerList,
}
#[derive(Debug, Serialize, Deserialize)]
pub enum ToClientResponse {
OkRequest { conn_id: Uuid },
OkServerList { data: Vec<String> },
GenericError,
}