Implement the skeleton for the ServerManager and the spawn of the connection_ids

This commit is contained in:
2024-02-19 14:22:11 +01:00
parent f8feb9db81
commit 37cc133d7f
11 changed files with 699 additions and 1618 deletions

1404
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -24,6 +24,7 @@ serde = "1"
rmp-serde = "1.1.2"
rcgen = { version = "0.12.1", features = ["x509-parser"] }
rand = "0.8.5"
uuid = { version = "1.7.0", features = ["v4", "serde"] }
[[bin]]
name = "init_certs"

View File

@@ -1,11 +1,14 @@
mod servercertdb;
mod pendingdataconndb;
mod servermanager;
use servercertdb::*;
use pendingdataconndb::*;
use servermanager::*;
use actix::prelude::*;
use std::collections::HashMap;
use std::sync::{Arc};
use std::time::{Instant, Duration};
use actix::fut::wrap_future;
use libbonknet::*;
use rustls::{RootCertStore, ServerConfig};
use rustls::server::WebPkiClientVerifier;
@@ -16,12 +19,14 @@ use actix_service::{ServiceFactoryExt as _};
use futures::{StreamExt, SinkExt};
use rand::random;
use tokio_util::codec::{Framed, FramedRead, FramedWrite, LengthDelimitedCodec};
use tracing::{info, error};
use tracing::{info, error, warn};
use rcgen::{Certificate, CertificateParams, DnType, KeyPair};
use thiserror::Error;
use tokio::io::{ReadHalf, WriteHalf};
use tokio_util::bytes::{Bytes, BytesMut};
use tokio::io::Error;
use tokio::sync::{oneshot, Mutex};
use uuid::Uuid;
type TransportStream = Framed<TlsStream<TcpStream>, LengthDelimitedCodec>;
type TransportStreamTx = FramedWrite<WriteHalf<TlsStream<TcpStream>>, LengthDelimitedCodec>;
@@ -47,131 +52,6 @@ fn generate_server_cert(root_cert: &Certificate, name: &str) -> ServerCert {
}
#[derive(MessageResponse)]
enum SendMsgResult {
Accepted,
Rejected,
}
#[derive(Message)]
#[rtype(result = "SendMsgResult")]
struct SendMsg {
msg: ToServerMessageBody,
reply_channel: oneshot::Sender<FromServerReplyBody>
}
#[derive(Message)]
#[rtype(result = "()")]
struct SendPing;
struct ServerTransporter {
rx: Option<TransportStreamRx>,
tx: Arc<Mutex<TransportStreamTx>>,
last_transmission: Instant,
reply_channels: HashMap<u64, oneshot::Sender<FromServerReplyBody>>,
}
impl ServerTransporter {
fn new(transport: TransportStream) -> Self {
let internal = transport.into_inner();
let (srx, stx) = tokio::io::split(internal);
let codec = LengthDelimitedCodec::new();
let rx = FramedRead::new(srx, codec.clone());
let tx = FramedWrite::new(stx, codec.clone());
ServerTransporter {
rx: Some(rx),
tx: Arc::new(Mutex::new(tx)),
last_transmission: Instant::now(),
reply_channels: HashMap::new(),
}
}
}
impl Actor for ServerTransporter {
type Context = Context<Self>;
fn started(&mut self, ctx: &mut Self::Context) {
let rx = self.rx.take().expect("Rx Stream not found");
ctx.add_stream(rx);
ctx.run_interval(Duration::from_secs(60), |_s, c| {
c.notify(SendPing);
});
}
}
impl Handler<SendPing> for ServerTransporter {
type Result = ();
fn handle(&mut self, _msg: SendPing, ctx: &mut Self::Context) -> Self::Result {
let msg = ToServerMessage::Ping;
let payload: Bytes = rmp_serde::to_vec(&msg).unwrap().into();
let arc_tx = self.tx.clone();
ctx.spawn(wrap_future::<_, Self>(async move {
arc_tx.lock().await.send(payload).await
}).map(|res, _a, _ctx| {
info!("Ping sent result: {:?}", res);
}));
}
}
impl Handler<SendMsg> for ServerTransporter {
type Result = SendMsgResult;
fn handle(&mut self, msg: SendMsg, ctx: &mut Self::Context) -> Self::Result {
let mut reply_id: u64;
if self.reply_channels.len() == u64::MAX as usize {
return SendMsgResult::Rejected;
}
loop {
reply_id = random();
if !self.reply_channels.contains_key(&reply_id) {
break;
}
}
self.reply_channels.insert(reply_id, msg.reply_channel);
let msg = ToServerMessage::Msg {
reply_id,
body: msg.msg,
};
let payload: Bytes = rmp_serde::to_vec(&msg).unwrap().into();
let arc_tx = self.tx.clone();
ctx.spawn(async move {
arc_tx.lock().await.send(payload).await
}.into_actor(self).map(|res, _a, _ctx| {
info!("ToServerMsg sent result: {:?}", res);
}));
SendMsgResult::Accepted
}
}
impl StreamHandler<Result<BytesMut, Error>> for ServerTransporter {
fn handle(&mut self, item: Result<BytesMut, Error>, _ctx: &mut Self::Context) {
match item {
Ok(buf) => {
use FromServerReply::*;
let msg: FromServerReply = rmp_serde::from_slice(&buf).unwrap();
match msg {
Pong => {
self.last_transmission = Instant::now();
}
Msg { reply_id, body } => match self.reply_channels.remove(&reply_id) {
None => {}
Some(reply_tx) => {
if let Err(_e) = reply_tx.send(body) {
error!("Oneshot channel {} got invalidated! No reply sent.", reply_id);
}
}
}
}
}
Err(e) => {
error!("ERROR {:?}", e);
}
}
}
}
#[actix_rt::main]
async fn main() {
// Tracing Subscriber
@@ -211,9 +91,9 @@ async fn main() {
server_root_prkey
).unwrap()).unwrap());
let server_db_addr = ServerCertDB {
db: HashMap::new(),
}.start();
let scdb_addr = ServerCertDB::new().start();
let pdcm_addr = PendingDataConnManager::new().start();
let sm_addr = ServerManager::new(pdcm_addr.clone()).start();
Server::build()
.bind("server-command", ("localhost", 2541), move || {
@@ -221,7 +101,8 @@ async fn main() {
let client_root_cert_der = Arc::clone(&client_root_cert_der);
let guestserver_root_cert_der = Arc::clone(&guestserver_root_cert_der);
let server_root_cert = Arc::clone(&server_root_cert);
let server_db_addr = server_db_addr.clone();
let scdb_addr = scdb_addr.clone();
let sm_addr = sm_addr.clone();
// Set up TLS service factory
server_acceptor
@@ -232,11 +113,14 @@ async fn main() {
let client_root_cert_der = Arc::clone(&client_root_cert_der);
let guestserver_root_cert_der = Arc::clone(&guestserver_root_cert_der);
let server_root_cert = Arc::clone(&server_root_cert);
let server_db_addr = server_db_addr.clone();
let scdb_addr = scdb_addr.clone();
let sm_addr = sm_addr.clone();
async move {
let peer_certs = stream.get_ref().1.peer_certificates().unwrap();
let peer_cert_bytes = peer_certs.first().unwrap().to_vec();
let peer_root_cert_der = peer_certs.last().unwrap().clone();
let scdb_addr = scdb_addr.clone();
let sm_addr = sm_addr.clone();
if peer_root_cert_der == *server_root_cert_der {
info!("Server connection");
let mut transport = Framed::new(stream, LengthDelimitedCodec::new());
@@ -254,15 +138,22 @@ async fn main() {
info!("SendCommand Stream");
let reply = ToServerConnTypeReply::OkSendCommand;
transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap();
server_command_handler(transport, peer_cert_bytes, &server_db_addr).await;
server_command_handler(transport, peer_cert_bytes, scdb_addr).await;
}
Subscribe => {
info!("Subscribe Stream");
let name = match scdb_addr.send(FetchName { cert: peer_cert_bytes }).await.unwrap() {
None => {
error!("Cert has no name assigned!");
let reply = ToServerConnTypeReply::GenericFailure;
transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap();
return Ok(());
}
Some(name) => name,
};
let reply = ToServerConnTypeReply::OkSubscribe;
transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap();
// TODO: If I pass transport away and the service returns, what happen to the connection?
// in theory it will remain open but better check
server_subscribe_handler(transport).await;
server_subscribe_handler(transport, name, sm_addr).await;
}
}
}
@@ -277,9 +168,13 @@ async fn main() {
let server_root_cert = Arc::clone(&server_root_cert);
let codec = LengthDelimitedCodec::new();
let transport = Framed::new(stream, codec);
guestserver_handler(transport, &server_db_addr, &server_root_cert).await;
guestserver_handler(transport, scdb_addr, &server_root_cert).await;
} else if peer_root_cert_der == *client_root_cert_der {
info!("Client connection");
//pdcm_addr
let codec = LengthDelimitedCodec::new();
let transport = Framed::new(stream, codec);
client_handler(transport, sm_addr).await;
} else {
error!("Unknown Root Certificate");
}
@@ -293,22 +188,18 @@ async fn main() {
.unwrap();
}
async fn server_subscribe_handler(transport: TransportStream) {
let h = ServerTransporter::new(transport).start();
info!("Actor spawned");
tokio::time::sleep(Duration::from_secs(5)).await;
info!("5 seconds elapsed, sending msg");
let (tx, rx) = oneshot::channel();
h.send(SendMsg {
msg: ToServerMessageBody::Required { id: "session_id".to_string() },
reply_channel: tx,
}).await.unwrap();
if let Ok(item) = rx.await {
info!("Response: {:?}", item);
async fn server_subscribe_handler(transport: TransportStream, name: String, sm_addr: Addr<ServerManager>) {
match sm_addr.send(StartTransporter { name, transport }).await.unwrap() {
Ok(_) => {
info!("Stream sent to the manager");
}
Err(e) => {
error!("Error from manager: {:?}", e);
}
}
}
async fn server_command_handler(mut transport: TransportStream, peer_cert_bytes: Vec<u8>, server_db_addr: &Addr<ServerCertDB>) {
async fn server_command_handler(mut transport: TransportStream, peer_cert_bytes: Vec<u8>, server_db_addr: Addr<ServerCertDB>) {
loop {
match transport.next().await {
None => {
@@ -323,7 +214,25 @@ async fn server_command_handler(mut transport: TransportStream, peer_cert_bytes:
match msg {
ChangeName { name } => {
info!("Changing name to {}", name);
// TODO
match server_db_addr.send(UnregisterServer { cert: peer_cert_bytes.clone() }).await.unwrap() {
None => {
info!("Nothing to unregister");
}
Some(old_name) => {
warn!("Unregistered from old name {}", old_name);
}
}
let reply = match server_db_addr.send(RegisterServer { cert: peer_cert_bytes.clone(), name }).await.unwrap() {
Ok(_) => {
info!("Registered!");
ToServerCommandReply::NameChanged
}
Err(e) => {
error!("Error registering: {:?}", e);
ToServerCommandReply::GenericFailure
}
};
transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap();
}
WhoAmI => {
info!("Asked who I am");
@@ -351,7 +260,7 @@ async fn server_command_handler(mut transport: TransportStream, peer_cert_bytes:
}
// TODO: Considera creare un context dove vengono contenute tutte le chiavi e gli address da dare a tutti gli handler
async fn guestserver_handler(mut transport: TransportStream, server_db_addr: &Addr<ServerCertDB>, server_root_cert: &Arc<Certificate>) {
async fn guestserver_handler(mut transport: TransportStream, server_db_addr: Addr<ServerCertDB>, server_root_cert: &Arc<Certificate>) {
loop {
match transport.next().await {
None => {
@@ -396,3 +305,48 @@ async fn guestserver_handler(mut transport: TransportStream, server_db_addr: &Ad
}
}
}
async fn client_handler(mut transport: TransportStream, sm_addr: Addr<ServerManager>) {
loop {
match transport.next().await {
None => {
info!("Transport returned None");
break;
}
Some(item) => {
match item {
Ok(buf) => {
let msg: FromClientCommand = rmp_serde::from_slice(&buf).unwrap();
info!("{:?}", msg);
match msg {
FromClientCommand::RequestServer { name } => {
info!("REQUESTED SERVER {}", name);
let data = sm_addr.send(RequestServer { name }).await.unwrap();
match data {
Ok(client_conn_id) => {
let reply = ToClientResponse::OkRequest { conn_id: client_conn_id };
transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap();
info!("Sent OkRequest");
}
Err(e) => {
error!("Error! {:?}", e);
}
}
}
FromClientCommand::ServerList => {
info!("Requested ServerList");
let data = sm_addr.send(GetServerList {}).await.unwrap();
let reply = ToClientResponse::OkServerList { data };
transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap();
}
}
}
Err(e) => {
info!("Disconnection: {:?}", e);
break;
}
}
}
}
}
}

View File

@@ -0,0 +1,149 @@
use actix::prelude::*;
use actix_tls::accept::rustls_0_22::TlsStream;
use thiserror::Error;
use tokio::io::{ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use tokio_util::codec::{Framed, FramedRead, FramedWrite, LengthDelimitedCodec};
use tracing::{error, info};
use uuid::Uuid;
type TransportStream = Framed<TlsStream<TcpStream>, LengthDelimitedCodec>;
/*
L'idea e' che il database deve avere una riga per ogni connessione dati in nascita.
In ogni "riga" deve essere presente:
- Il ServerConnID che il server dovra usare (pkey)
- Il ClientConnID che il client dovra usare (pkey)
- Se gia connesso, il Socket del Server
- Se gia connesso il Socket del Client
Quando in una riga sono presenti sia ServerSocket che ClientSocket allora poppa la riga
e usa i socket per lanciare un thread/actor che faccia il piping dei dati
Ricerca riga deve avvenire sia tramite ServerConnID che ClientConnID se essi sono diversi come pianifico
Quindi l'ideale e' non usare una collection ma andare direttamente di Vector!
*/
#[derive(Error, Debug)]
pub enum PendingDataConnError {
#[error("Generic Failure")]
GenericFailure,
}
#[derive(Message)]
#[rtype(result = "Result<(),PendingDataConnError>")]
pub struct NewPendingConn {
pub server_conn_id: Uuid,
pub client_conn_id: Uuid,
}
#[derive(Debug)]
enum RegisterKind {
Server,
Client
}
#[derive(Message)]
#[rtype(result = "Result<(),PendingDataConnError>")]
pub struct RegisterStream {
kind: RegisterKind,
conn_id: Uuid,
transport: TransportStream,
}
impl RegisterStream {
pub fn server(conn_id: Uuid, transport: TransportStream) -> Self {
RegisterStream { kind: RegisterKind::Server, conn_id, transport }
}
pub fn client(conn_id: Uuid, transport: TransportStream) -> Self {
RegisterStream { kind: RegisterKind::Client, conn_id, transport }
}
}
struct SideRecord {
conn_id: Uuid,
transport: Option<TransportStream>,
}
struct Record {
server: SideRecord,
client: SideRecord,
}
impl Record {
fn new(server_conn_id: Uuid, client_conn_id: Uuid) -> Self {
let server = SideRecord { conn_id: client_conn_id, transport: None };
let client = SideRecord { conn_id: server_conn_id, transport: None };
Record { server, client }
}
}
// TODO: every 2 minutes verify the Records that have at least one stream invalidated and drop them
pub struct PendingDataConnManager {
db: Vec<Record>,
}
impl PendingDataConnManager {
pub fn new() -> Self {
PendingDataConnManager { db: Vec::new() }
}
}
impl Actor for PendingDataConnManager {
type Context = Context<Self>;
}
impl Handler<NewPendingConn> for PendingDataConnManager {
type Result = Result<(), PendingDataConnError>;
fn handle(&mut self, msg: NewPendingConn, _ctx: &mut Self::Context) -> Self::Result {
let server_conn_id = msg.server_conn_id;
let client_conn_id = msg.client_conn_id;
let new_record = Record::new(server_conn_id, client_conn_id);
if self.db.iter().any(|x| {
x.server.conn_id == new_record.server.conn_id || x.client.conn_id == new_record.client.conn_id
}) {
error!("Conflicting UUIDs: server {} - client {}", server_conn_id, client_conn_id);
Err(PendingDataConnError::GenericFailure)
} else {
self.db.push(new_record);
Ok(())
}
}
}
impl Handler<RegisterStream> for PendingDataConnManager {
type Result = Result<(), PendingDataConnError>;
fn handle(&mut self, msg: RegisterStream, _ctx: &mut Self::Context) -> Self::Result {
use RegisterKind::*;
let conn_id = msg.conn_id;
let record = match match msg.kind {
Server => self.db.iter_mut().find(|x| x.server.conn_id == conn_id),
Client => self.db.iter_mut().find(|x| x.client.conn_id == conn_id),
} {
None => {
error!("Found no connection with {:?} conn_id {:?}", msg.kind, conn_id);
return Err(PendingDataConnError::GenericFailure);
},
Some(item) => item,
};
let side_record = match msg.kind {
Server => &mut record.server,
Client => &mut record.client,
};
if side_record.transport.is_some() {
// TODO: It can be good to check if the connection is still open, if not, drop and use the new one.
error!("Connection already with a socket!");
return Err(PendingDataConnError::GenericFailure);
} else {
side_record.transport = Some(msg.transport);
}
if record.client.transport.is_some() && record.server.transport.is_some() {
// TODO: Launch the "thing" that will manage the data mirroring
info!("LAUNCHING DATA MIRRORING");
}
Ok(())
}
}

View File

@@ -1,5 +1,5 @@
use std::collections::HashMap;
use actix::{Actor, Context, Handler, Message};
use actix::prelude::*;
use thiserror::Error;
// TODO: Probably it's better to remove the pub from inside the structs and impl a new() funct
@@ -25,6 +25,12 @@ pub struct RegisterServer {
pub name: String,
}
#[derive(Message)]
#[rtype(result = "Option<String>")] // None if nothing to unregister, Some if unregistered
pub struct UnregisterServer {
pub cert: Vec<u8>,
}
#[derive(Message)]
#[rtype(result = "Option<String>")]
pub struct FetchName {
@@ -33,7 +39,13 @@ pub struct FetchName {
// TODO: Move into Sqlite DB with unique check on col1 and col2!!!! Right now name is not unique
pub struct ServerCertDB {
pub db: HashMap<Vec<u8>, String>, // Cert to Name
db: HashMap<Vec<u8>, String>, // Cert to Name
}
impl ServerCertDB {
pub fn new() -> Self {
ServerCertDB { db: HashMap::new() }
}
}
impl Actor for ServerCertDB {
@@ -71,3 +83,11 @@ impl Handler<FetchName> for ServerCertDB {
self.db.get(&msg.cert).map(|s| s.to_owned())
}
}
impl Handler<UnregisterServer> for ServerCertDB {
type Result = Option<String>;
fn handle(&mut self, msg: UnregisterServer, _ctx: &mut Self::Context) -> Self::Result {
self.db.remove(&msg.cert)
}
}

View File

@@ -0,0 +1,334 @@
use std::collections::HashMap;
use std::io::Error;
use std::sync::{Arc};
use std::time::{Duration, Instant};
use actix::prelude::*;
use actix_server::Server;
use rand::random;
use thiserror::Error;
use futures::{StreamExt, SinkExt};
use tokio::sync::{Mutex, oneshot};
use tokio_util::bytes::{Bytes, BytesMut};
use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
use tracing::{debug, error, info};
use libbonknet::*;
use crate::{TransportStream, TransportStreamRx, TransportStreamTx};
use uuid::Uuid;
use crate::pendingdataconndb::*;
#[derive(MessageResponse)]
enum SendMsgResult {
Accepted,
Rejected,
}
#[derive(Message)]
#[rtype(result = "SendMsgResult")]
struct SendMsg {
msg: ToServerMessageBody,
reply_channel: oneshot::Sender<FromServerReplyBody>
}
#[derive(Error, Debug)]
pub enum AsyncSendMsgError {
#[error("Generic Failure")]
GenericFailure,
}
#[derive(Message)]
#[rtype(result = "Result<FromServerReplyBody, AsyncSendMsgError>")]
struct AsyncSendMsg {
msg: ToServerMessageBody,
}
struct ServerTransporter {
rx: Option<TransportStreamRx>,
tx: Arc<tokio::sync::Mutex<TransportStreamTx>>,
timeout: Option<SpawnHandle>,
reply_channels: HashMap<u64, oneshot::Sender<FromServerReplyBody>>,
}
impl Handler<AsyncSendMsg> for ServerTransporter {
type Result = ResponseFuture<Result<FromServerReplyBody, AsyncSendMsgError>>;
fn handle(&mut self, msg: AsyncSendMsg, _ctx: &mut Self::Context) -> Self::Result {
let (reply_channel_tx, reply_channel_rx) = oneshot::channel();
let mut reply_id: u64;
if self.reply_channels.len() == u64::MAX as usize {
return Box::pin(fut::ready(Err(AsyncSendMsgError::GenericFailure)));
}
loop {
reply_id = random();
if !self.reply_channels.contains_key(&reply_id) {
break;
}
}
self.reply_channels.insert(reply_id, reply_channel_tx);
let msg = ToServerMessage::Msg {
reply_id,
body: msg.msg,
};
let payload: Bytes = rmp_serde::to_vec(&msg).unwrap().into();
let arc_tx = self.tx.clone();
Box::pin(async move {
arc_tx.lock().await.send(payload).await.map_err(|e| AsyncSendMsgError::GenericFailure)?;
info!("msg sent");
let r = reply_channel_rx.await.unwrap();
info!("reply received");
Ok(r)
})
}
}
impl ServerTransporter {
fn new(transport: TransportStream) -> Self {
let internal = transport.into_inner();
let (srx, stx) = tokio::io::split(internal);
let codec = LengthDelimitedCodec::new();
let rx = FramedRead::new(srx, codec.clone());
let tx = FramedWrite::new(stx, codec.clone());
ServerTransporter {
rx: Some(rx),
tx: Arc::new(tokio::sync::Mutex::new(tx)),
timeout: None,
reply_channels: HashMap::new(),
}
}
fn actor_close(&mut self, ctx: &mut Context<Self>) {
ctx.stop();
}
}
impl Actor for ServerTransporter {
type Context = Context<Self>;
fn started(&mut self, ctx: &mut Self::Context) {
// Register Read Stream
let rx = self.rx.take().expect("Rx Stream not found");
ctx.add_stream(rx);
// Register Timeout task
self.timeout = Some(ctx.run_interval(Duration::from_secs(120), |s, c| {
s.actor_close(c)
}));
// Register Send Ping Task
ctx.run_interval(Duration::from_secs(60), |s, c| {
let msg = ToServerMessage::Ping;
let payload: Bytes = rmp_serde::to_vec(&msg).unwrap().into();
let arc_tx = Arc::clone(&s.tx);
c.spawn(async move {
arc_tx.lock().await.send(payload).await
}.into_actor(s).map(|res, a, ctx| {
match res {
Ok(_) => {
info!("Ping sent!");
}
Err(_) => {
ctx.stop();
}
}
}));
});
}
}
impl Handler<SendMsg> for ServerTransporter {
type Result = SendMsgResult;
fn handle(&mut self, msg: SendMsg, ctx: &mut Self::Context) -> Self::Result {
let mut reply_id: u64;
if self.reply_channels.len() == u64::MAX as usize {
return SendMsgResult::Rejected;
}
loop {
reply_id = random();
if !self.reply_channels.contains_key(&reply_id) {
break;
}
}
self.reply_channels.insert(reply_id, msg.reply_channel);
let msg = ToServerMessage::Msg {
reply_id,
body: msg.msg,
};
let payload: Bytes = rmp_serde::to_vec(&msg).unwrap().into();
let arc_tx = self.tx.clone();
ctx.spawn(async move {
arc_tx.lock().await.send(payload).await
}.into_actor(self).map(|res, _a, _ctx| {
info!("ToServerMsg sent result: {:?}", res);
}));
SendMsgResult::Accepted
}
}
impl StreamHandler<Result<BytesMut, Error>> for ServerTransporter {
fn handle(&mut self, item: Result<BytesMut, Error>, ctx: &mut Self::Context) {
match item {
Ok(buf) => {
use FromServerReply::*;
let msg: FromServerReply = rmp_serde::from_slice(&buf).unwrap();
match msg {
Pong => {
info!("Pong received");
if let Some(spawn_handle) = self.timeout {
ctx.cancel_future(spawn_handle);
} else {
error!("There were no spawn handle configured!");
}
self.timeout = Some(ctx.run_interval(Duration::from_secs(120), |s, c| {
s.actor_close(c)
}));
}
Msg { reply_id, body } => match self.reply_channels.remove(&reply_id) {
None => {}
Some(reply_tx) => {
if let Err(_e) = reply_tx.send(body) {
error!("Oneshot channel {} got invalidated! No reply sent.", reply_id);
}
}
}
}
}
Err(e) => {
error!("ERROR {:?}", e);
}
}
}
}
#[derive(Error, Debug)]
pub enum ServerManagerError {
#[error("Generic Failure")]
GenericFailure,
}
#[derive(Message)]
#[rtype(result = "Result<(),ServerManagerError>")]
pub struct StartTransporter {
pub name: String,
pub transport: TransportStream,
}
#[derive(Message)]
#[rtype(result = "Vec<String>")]
pub struct GetServerList {}
#[derive(Message)]
#[rtype(result = "Result<Uuid,ServerManagerError>")] // TODO: Return Client ID with struct to give it a name
pub struct RequestServer {
pub name: String
}
pub struct ServerManager {
entries: Arc<Mutex<HashMap<String, Addr<ServerTransporter>>>>,
// Name -> Addr to Actor
pdcdb_addr: Addr<PendingDataConnManager>,
}
impl ServerManager {
pub fn new(pdcdb_addr: Addr<PendingDataConnManager>) -> Self {
ServerManager { entries: Arc::new(Mutex::new(HashMap::new())), pdcdb_addr }
}
}
impl Actor for ServerManager {
type Context = Context<Self>;
fn started(&mut self, ctx: &mut Self::Context) {
// Remove all the ServerTransporters that are not running
// TODO: This is a "Pull" from entries, but we can have the entries that when killed tell
// the Manager making it more reactive!
ctx.run_interval(Duration::from_secs(10), |s, c| {
let start = Instant::now();
let entries = Arc::clone(&s.entries);
c.spawn(async move {
let mut entries_mg = entries.lock().await;
let mut keys_to_delete = vec![];
for (name, st_addr) in entries_mg.iter() {
if !st_addr.connected() {
keys_to_delete.push(name.clone())
}
}
for name in keys_to_delete {
entries_mg.remove(&name);
info!("Closed ServerTransporter {} for actor death", name);
}
debug!("Cleaned ServerManager in {:?}", Instant::now() - start);
}.into_actor(s));
});
}
}
impl Handler<StartTransporter> for ServerManager {
type Result = ResponseFuture<Result<(), ServerManagerError>>;
fn handle(&mut self, msg: StartTransporter, _ctx: &mut Self::Context) -> Self::Result {
let mut entries = Arc::clone(&self.entries);
Box::pin(async move {
let mut entries_mg = entries.lock().await;
if entries_mg.contains_key(&msg.name) {
error!("A server called {} is already connected!", msg.name);
return Err(ServerManagerError::GenericFailure);
}
let st_addr = ServerTransporter::new(msg.transport).start();
entries_mg.insert(msg.name, st_addr);
Ok(())
})
}
}
impl Handler<GetServerList> for ServerManager {
type Result = ResponseFuture<Vec<String>>;
fn handle(&mut self, _msg: GetServerList, _ctx: &mut Self::Context) -> Self::Result {
let entries = Arc::clone(&self.entries);
Box::pin(async move {
let entries_mg = entries.lock().await;
entries_mg.keys().cloned().collect()
})
}
}
impl Handler<RequestServer> for ServerManager {
type Result = ResponseFuture<Result<Uuid, ServerManagerError>>;
fn handle(&mut self, msg: RequestServer, _ctx: &mut Self::Context) -> Self::Result {
let name = msg.name.clone();
let entries = Arc::clone(&self.entries);
let pdcdb_addr = self.pdcdb_addr.clone();
Box::pin(async move {
let lock = entries.lock().await;
let sh_addr = match lock.get(&name) {
None => {
error!("Requested server {} that isn't registered", name);
return Err(ServerManagerError::GenericFailure);
}
Some(item) => item,
};
let server_conn_id = Uuid::new_v4();
let client_conn_id = Uuid::new_v4();
match pdcdb_addr.send(NewPendingConn { server_conn_id, client_conn_id }).await.unwrap() {
Ok(_) => {
let msg = ToServerMessageBody::Request { conn_id: server_conn_id };
match sh_addr.send(AsyncSendMsg { msg }).await.unwrap() {
Ok(reply) => match reply {
FromServerReplyBody::RequestAccepted => {
Ok(client_conn_id)
}
FromServerReplyBody::RequestFailed => {
error!("Request Failed!");
Err(ServerManagerError::GenericFailure)
}
FromServerReplyBody::Pong => unreachable!(),
}
Err(e) => {
panic!("ERROR: {:?}", e);
}
}
}
Err(_e) => Err(ServerManagerError::GenericFailure),
}
})
}
}

View File

@@ -7,10 +7,12 @@ edition = "2021"
[dependencies]
libbonknet = { path = "../libbonknet" }
tokio = { version = "1", features = ["full"] }
tokio-rustls = "0.25.0"
tokio = { version = "1", features = ["full", "tracing"] }
tokio-util = { version = "0.7.10", features = ["codec"] }
futures = "0.3"
rcgen = "0.12.0"
tokio-rustls = "0.25.0"
rustls-pemfile = "2.0.0"
serde = { version = "1.0", features = ["derive"] }
rmp-serde = "1.1.2"
tracing = "0.1"
tracing-subscriber = "0.3"

View File

@@ -1,38 +1,35 @@
use std::io::Error;
use std::sync::Arc;
use futures::SinkExt;
use std::time::Duration;
use futures::{StreamExt, SinkExt};
use tokio::net::TcpStream;
use tokio_rustls::rustls::{ClientConfig, RootCertStore};
use tokio_rustls::rustls::pki_types::{ServerName};
use tokio_rustls::rustls::pki_types::{ServerName, CertificateDer, PrivatePkcs8KeyDer};
use tokio_rustls::TlsConnector;
use tokio_util::bytes::BytesMut;
use tokio_util::codec::{Framed, LengthDelimitedCodec};
use serde::{Serialize, Deserialize};
use libbonknet::{load_cert, load_prkey};
use libbonknet::*;
use tracing::{info, error};
#[derive(Debug, Serialize, Deserialize)]
enum ClientMessage {
Response { status_code: u32, msg: Option<String> },
Announce { name: String },
Required { id: String },
NotRequired { id: String },
}
// TODO: This is an old examples
#[tokio::main]
async fn main() -> std::io::Result<()> {
let client_name = "Polnareffland1";
// Tracing Subscriber
let subscriber = tracing_subscriber::FmtSubscriber::new();
tracing::subscriber::set_global_default(subscriber).unwrap();
// Root certs to verify the server is the right one
let mut server_root_cert_store = RootCertStore::empty();
let server_root_cert_der = load_cert("server_root_cert.pem").unwrap();
server_root_cert_store.add(server_root_cert_der).unwrap();
// Auth Cert to send the server who am I
let root_client_cert = load_cert("client_root_cert.pem").unwrap();
let client_cert = load_cert("client_cert.pem").unwrap();
let client_prkey = load_prkey("client_key.pem").unwrap();
let mut broker_root_cert_store = RootCertStore::empty();
let broker_root_cert_der = load_cert("certs/broker_root_cert.pem").unwrap();
broker_root_cert_store.add(broker_root_cert_der).unwrap();
// Public CA for Clients
let root_client_cert = load_cert("certs/client_root_cert.pem").unwrap();
// My Client Certificate for authentication
let client_cert = load_cert("certs/client_cert.pem").unwrap();
let client_cert_prkey = load_prkey("certs/client_key.pem").unwrap();
// Load TLS Config
let tlsconfig = ClientConfig::builder()
.with_root_certificates(server_root_cert_store)
// .with_no_client_auth();
.with_client_auth_cert(vec![client_cert, root_client_cert], client_prkey.into())
.with_root_certificates(broker_root_cert_store.clone())
.with_client_auth_cert(vec![client_cert, root_client_cert], client_cert_prkey.into())
.unwrap();
let connector = TlsConnector::from(Arc::new(tlsconfig));
let dnsname = ServerName::try_from("localhost").unwrap();
@@ -41,34 +38,46 @@ async fn main() -> std::io::Result<()> {
let stream = connector.connect(dnsname, stream).await?;
let mut transport = Framed::new(stream, LengthDelimitedCodec::new());
let msg1 = ClientMessage::Announce { name: client_name.into() };
transport.send(rmp_serde::to_vec(&msg1).unwrap().into()).await.unwrap();
for i in 0..10 {
let msg = ClientMessage::Response { status_code: 100+i, msg: Some(format!("yay {}", i)) };
transport.send(rmp_serde::to_vec(&msg).unwrap().into()).await.unwrap();
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
let msg = FromClientCommand::ServerList;
transport.send(rmp_serde::to_vec(&msg).unwrap().into()).await.unwrap();
match transport.next().await {
None => panic!("None in the transport"),
Some(item) => match item {
Ok(buf) => {
use ToClientResponse::*;
let msg: ToClientResponse = rmp_serde::from_slice(&buf).unwrap();
match msg {
OkRequest { .. } => error!("Wrong reply!"),
OkServerList { data } => info!("{}", data.join("\n")),
GenericError => error!("Generic Error during remote command execution"),
}
}
Err(e) => {
error!("Error: {:?}", e);
}
}
}
tokio::time::sleep(Duration::from_secs(5)).await;
let msg = FromClientCommand::RequestServer { name: "cicciopizza".into() };
transport.send(rmp_serde::to_vec(&msg).unwrap().into()).await.unwrap();
match transport.next().await {
None => panic!("None in the transport"),
Some(item) => match item {
Ok(buf) => {
use ToClientResponse::*;
let msg: ToClientResponse = rmp_serde::from_slice(&buf).unwrap();
match msg {
OkRequest { conn_id } => {
info!("Received Client Connection ID: {:?}", conn_id);
}
OkServerList { .. } => error!("Wrong reply!"),
GenericError => error!("Generic Error during remote command execution"),
}
}
Err(e) => {
error!("Error: {:?}", e);
}
}
}
// transport.for_each(|item| async move {
// let a: ClientMessage = rmp_serde::from_slice(&item.unwrap()).unwrap();
// println!("{:?}", a);
// }).await;
// let mut buf = vec![0;1024];
// let (mut rd,mut tx) = split(stream);
//
//
// tokio::spawn(async move {
// let mut stdout = tokio::io::stdout();
// tokio::io::copy(&mut rd, &mut stdout).await.unwrap();
// });
//
// let mut reader = tokio::io::BufReader::new(tokio::io::stdin()).lines();
//
// while let Some(line) = reader.next_line().await.unwrap() {
// tx.write_all(line.as_bytes()).await.unwrap();
// }
Ok(())
}

View File

@@ -190,11 +190,11 @@ async fn main() -> std::io::Result<()> {
Msg { reply_id, body } => {
use ToServerMessageBody::*;
match body {
Required { id } => {
info!("I'm required with Connection ID {}", id);
Request { conn_id } => {
info!("I'm required with Connection ID {}", conn_id);
out = Some(FromServerReply::Msg {
reply_id,
body: FromServerReplyBody::RequiredAccepted,
body: FromServerReplyBody::RequestAccepted,
})
}
}

View File

@@ -9,3 +9,4 @@ edition = "2021"
tokio-rustls = "0.25.0"
rustls-pemfile = "2.0.0"
serde = { version = "1.0", features = ["derive"] }
uuid = { version = "1.7.0", features = ["serde"] }

View File

@@ -2,6 +2,7 @@ use std::io::{BufReader, Error, ErrorKind};
use rustls_pemfile::{read_one, Item};
use serde::{Deserialize, Serialize};
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
use uuid::Uuid;
pub fn load_cert(filename: &str) -> std::io::Result<CertificateDer> {
let cert_file = std::fs::File::open(filename).unwrap();
@@ -60,7 +61,7 @@ pub enum YouAreValues {
#[derive(Debug, Serialize, Deserialize)]
pub enum ToServerMessageBody {
Required { id: String },
Request { conn_id: Uuid },
}
#[derive(Debug, Serialize, Deserialize)]
@@ -74,8 +75,8 @@ pub enum ToServerMessage {
#[derive(Debug, Serialize, Deserialize)]
pub enum FromServerReplyBody {
RequiredAccepted,
RequiredFailed,
RequestAccepted,
RequestFailed,
Pong,
}
@@ -113,3 +114,17 @@ impl ToGuestServerMessage {
}
}
}
// Client things
#[derive(Debug, Serialize, Deserialize)]
pub enum FromClientCommand {
RequestServer { name: String },
ServerList,
}
#[derive(Debug, Serialize, Deserialize)]
pub enum ToClientResponse {
OkRequest { conn_id: Uuid },
OkServerList { data: Vec<String> },
GenericError,
}