diff --git a/Cargo.lock b/Cargo.lock index 76a17ef..08f5b83 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -220,10 +220,13 @@ dependencies = [ "actix-tls", "futures", "libbonknet", + "rand", "rcgen", "rmp-serde", "rustls", + "serde", "thiserror", + "tokio", "tokio-util", "tracing", "tracing-subscriber", @@ -702,6 +705,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + [[package]] name = "proc-macro2" version = "1.0.76" @@ -720,6 +729,36 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + [[package]] name = "rcgen" version = "0.12.1" diff --git a/bonknet_broker/Cargo.toml b/bonknet_broker/Cargo.toml index 25c9e48..9ac6de5 100644 --- a/bonknet_broker/Cargo.toml +++ b/bonknet_broker/Cargo.toml @@ -13,14 +13,17 @@ actix-rt = "2.9.0" actix-server = "2.3.0" actix-service = "2.0.2" actix-tls = { version = "3.3.0", features = ["rustls-0_22"] } +tokio = { version = "1", features = ["io-util", "sync", "time"] } rustls = "0.22.2" tracing = "0.1" tracing-subscriber = "0.3" futures = "0.3" thiserror = "1.0.56" tokio-util = { version = "0.7.10", features = ["codec"] } +serde = "1" rmp-serde = "1.1.2" rcgen = { version = "0.12.1", features = ["x509-parser"] } +rand = "0.8.5" [[bin]] name = "init_certs" diff --git a/bonknet_broker/src/main.rs b/bonknet_broker/src/main.rs index fc62b3f..01e091c 100644 --- a/bonknet_broker/src/main.rs +++ b/bonknet_broker/src/main.rs @@ -1,6 +1,11 @@ +mod servercertdb; + +use servercertdb::*; use actix::prelude::*; use std::collections::HashMap; -use std::sync::Arc; +use std::sync::{Arc}; +use std::time::{Instant, Duration}; +use actix::fut::wrap_future; use libbonknet::*; use rustls::{RootCertStore, ServerConfig}; use rustls::server::WebPkiClientVerifier; @@ -9,12 +14,18 @@ use actix_server::Server; use actix_rt::net::TcpStream; use actix_service::{ServiceFactoryExt as _}; use futures::{StreamExt, SinkExt}; -use thiserror::Error; -use tokio_util::codec::{Framed, LengthDelimitedCodec}; +use rand::random; +use tokio_util::codec::{Framed, FramedRead, FramedWrite, LengthDelimitedCodec}; use tracing::{info, error}; use rcgen::{Certificate, CertificateParams, DnType, KeyPair}; +use tokio::io::{ReadHalf, WriteHalf}; +use tokio_util::bytes::{Bytes, BytesMut}; +use tokio::io::Error; +use tokio::sync::{oneshot, Mutex}; type TransportStream = Framed, LengthDelimitedCodec>; +type TransportStreamTx = FramedWrite>, LengthDelimitedCodec>; +type TransportStreamRx = FramedRead>, LengthDelimitedCodec>; struct ServerCert { cert: Vec, @@ -35,71 +46,128 @@ fn generate_server_cert(root_cert: &Certificate, name: &str) -> ServerCert { ServerCert { cert, prkey } } -#[derive(Error, Debug)] -enum DBError { - #[error("Certificate is already registered with name {0}")] - CertAlreadyRegistered(String), - // #[error("Generic Failure")] - // GenericFailure, + +#[derive(MessageResponse)] +enum SendMsgResult { + Accepted, + Rejected, } #[derive(Message)] -#[rtype(result = "bool")] -struct IsNameRegistered { - name: String, +#[rtype(result = "SendMsgResult")] +struct SendMsg { + msg: ToServerMessageBody, + reply_channel: oneshot::Sender } #[derive(Message)] -#[rtype(result = "Result<(), DBError>")] -struct RegisterServer { - cert: Vec, - name: String, +#[rtype(result = "()")] +struct SendPing; + +struct ServerTransporter { + rx: Option, + tx: Arc>, + last_transmission: Instant, + reply_channels: HashMap>, } -#[derive(Message)] -#[rtype(result = "Option")] -struct FetchName { - cert: Vec, -} - -// TODO: Move into Sqlite DB with unique check on col1 and col2!!!! Right now name is not unique -struct ServerCertDB { - db: HashMap, String>, // Cert to Name -} - -impl Actor for ServerCertDB { - type Context = Context; -} - -impl Handler for ServerCertDB { - type Result = Result<(), DBError>; - - fn handle(&mut self, msg: RegisterServer, _ctx: &mut Self::Context) -> Self::Result { - match self.db.get(&msg.cert) { - None => { - self.db.insert(msg.cert, msg.name); - Ok(()) - } - Some(name) => { - Err(DBError::CertAlreadyRegistered(name.clone())) - } +impl ServerTransporter { + fn new(transport: TransportStream) -> Self { + let internal = transport.into_inner(); + let (srx, stx) = tokio::io::split(internal); + let codec = LengthDelimitedCodec::new(); + let rx = FramedRead::new(srx, codec.clone()); + let tx = FramedWrite::new(stx, codec.clone()); + ServerTransporter { + rx: Some(rx), + tx: Arc::new(Mutex::new(tx)), + last_transmission: Instant::now(), + reply_channels: HashMap::new(), } } } -impl Handler for ServerCertDB { - type Result = bool; +impl Actor for ServerTransporter { + type Context = Context; - fn handle(&mut self, msg: IsNameRegistered, _ctx: &mut Self::Context) -> Self::Result { - self.db.values().any(|x| *x == msg.name) + fn started(&mut self, ctx: &mut Self::Context) { + let rx = self.rx.take().expect("Rx Stream not found"); + ctx.add_stream(rx); + ctx.run_interval(Duration::from_secs(60), |_s, c| { + c.notify(SendPing); + }); } } -impl Handler for ServerCertDB { - type Result = Option; +impl Handler for ServerTransporter { + type Result = (); - fn handle(&mut self, msg: FetchName, _ctx: &mut Self::Context) -> Self::Result { - self.db.get(&msg.cert).map(|s| s.to_owned()) + fn handle(&mut self, _msg: SendPing, ctx: &mut Self::Context) -> Self::Result { + let msg = ToServerMessage::Ping; + let payload: Bytes = rmp_serde::to_vec(&msg).unwrap().into(); + let arc_tx = self.tx.clone(); + ctx.spawn(wrap_future::<_, Self>(async move { + arc_tx.lock().await.send(payload).await + }).map(|res, _a, _ctx| { + info!("Ping sent result: {:?}", res); + })); + } +} + +impl Handler for ServerTransporter { + type Result = SendMsgResult; + + fn handle(&mut self, msg: SendMsg, ctx: &mut Self::Context) -> Self::Result { + let mut reply_id: u64; + if self.reply_channels.len() == u64::MAX as usize { + return SendMsgResult::Rejected; + } + loop { + reply_id = random(); + if !self.reply_channels.contains_key(&reply_id) { + break; + } + } + self.reply_channels.insert(reply_id, msg.reply_channel); + let msg = ToServerMessage::Msg { + reply_id, + body: msg.msg, + }; + let payload: Bytes = rmp_serde::to_vec(&msg).unwrap().into(); + let arc_tx = self.tx.clone(); + ctx.spawn(async move { + arc_tx.lock().await.send(payload).await + }.into_actor(self).map(|res, _a, _ctx| { + info!("ToServerMsg sent result: {:?}", res); + })); + SendMsgResult::Accepted + } +} + +impl StreamHandler> for ServerTransporter { + fn handle(&mut self, item: Result, _ctx: &mut Self::Context) { + match item { + Ok(buf) => { + use FromServerReply::*; + let msg: FromServerReply = rmp_serde::from_slice(&buf).unwrap(); + match msg { + Pong => { + self.last_transmission = Instant::now(); + } + Msg { reply_id, body } => match self.reply_channels.remove(&reply_id) { + None => {} + Some(reply_tx) => { + if let Err(_e) = reply_tx.send(body) { + error!("Oneshot channel {} got invalidated! No reply sent.", reply_id); + } + } + } + } + } + Err(e) => { + error!("ERROR {:?}", e); + } + } } } @@ -189,7 +257,12 @@ async fn main() { server_command_handler(transport, peer_cert_bytes, &server_db_addr).await; } Subscribe => { - info!("Subscribe Stream") + info!("Subscribe Stream"); + let reply = ToServerConnTypeReply::OkSubscribe; + transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap(); + // TODO: If I pass transport away and the service returns, what happen to the connection? + // in theory it will remain open but better check + server_subscribe_handler(transport).await; } } } @@ -220,6 +293,21 @@ async fn main() { .unwrap(); } +async fn server_subscribe_handler(transport: TransportStream) { + let h = ServerTransporter::new(transport).start(); + info!("Actor spawned"); + tokio::time::sleep(Duration::from_secs(5)).await; + info!("5 seconds elapsed, sending msg"); + let (tx, rx) = oneshot::channel(); + h.send(SendMsg { + msg: ToServerMessageBody::Required { id: "session_id".to_string() }, + reply_channel: tx, + }).await.unwrap(); + if let Ok(item) = rx.await { + info!("Response: {:?}", item); + } +} + async fn server_command_handler(mut transport: TransportStream, peer_cert_bytes: Vec, server_db_addr: &Addr) { loop { match transport.next().await { diff --git a/bonknet_broker/src/servercertdb.rs b/bonknet_broker/src/servercertdb.rs new file mode 100644 index 0000000..7951ad1 --- /dev/null +++ b/bonknet_broker/src/servercertdb.rs @@ -0,0 +1,73 @@ +use std::collections::HashMap; +use actix::{Actor, Context, Handler, Message}; +use thiserror::Error; + +// TODO: Probably it's better to remove the pub from inside the structs and impl a new() funct + +#[derive(Error, Debug)] +pub enum DBError { + #[error("Certificate is already registered with name {0}")] + CertAlreadyRegistered(String), + // #[error("Generic Failure")] + // GenericFailure, +} + +#[derive(Message)] +#[rtype(result = "bool")] +pub struct IsNameRegistered { + pub name: String, +} + +#[derive(Message)] +#[rtype(result = "Result<(), DBError>")] +pub struct RegisterServer { + pub cert: Vec, + pub name: String, +} + +#[derive(Message)] +#[rtype(result = "Option")] +pub struct FetchName { + pub cert: Vec, +} + +// TODO: Move into Sqlite DB with unique check on col1 and col2!!!! Right now name is not unique +pub struct ServerCertDB { + pub db: HashMap, String>, // Cert to Name +} + +impl Actor for ServerCertDB { + type Context = Context; +} + +impl Handler for ServerCertDB { + type Result = Result<(), DBError>; + + fn handle(&mut self, msg: RegisterServer, _ctx: &mut Self::Context) -> Self::Result { + match self.db.get(&msg.cert) { + None => { + self.db.insert(msg.cert, msg.name); + Ok(()) + } + Some(name) => { + Err(DBError::CertAlreadyRegistered(name.clone())) + } + } + } +} + +impl Handler for ServerCertDB { + type Result = bool; + + fn handle(&mut self, msg: IsNameRegistered, _ctx: &mut Self::Context) -> Self::Result { + self.db.values().any(|x| *x == msg.name) + } +} + +impl Handler for ServerCertDB { + type Result = Option; + + fn handle(&mut self, msg: FetchName, _ctx: &mut Self::Context) -> Self::Result { + self.db.get(&msg.cert).map(|s| s.to_owned()) + } +} \ No newline at end of file diff --git a/bonknet_server/src/bin/server.rs b/bonknet_server/src/bin/server.rs index 3574a69..949319b 100644 --- a/bonknet_server/src/bin/server.rs +++ b/bonknet_server/src/bin/server.rs @@ -175,22 +175,46 @@ async fn main() -> std::io::Result<()> { } } // Subscribe consume - transport.for_each(|item| async move { - match item { - Ok(buf) => { - use ToServerMessage::*; - let msg: ToServerMessage = rmp_serde::from_slice(&buf).unwrap(); - match msg { - Required { id } => { - info!("I'm required with Connection ID {}", id); + loop { + match transport.next().await { + None => { + info!("Empty Buffer"); + } + Some(item) => { + let mut out: Option = None; + match item { + Ok(buf) => { + use ToServerMessage::*; + let msg: ToServerMessage = rmp_serde::from_slice(&buf).unwrap(); + match msg { + Msg { reply_id, body } => { + use ToServerMessageBody::*; + match body { + Required { id } => { + info!("I'm required with Connection ID {}", id); + out = Some(FromServerReply::Msg { + reply_id, + body: FromServerReplyBody::RequiredAccepted, + }) + } + } + } + Ping => { + info!("Ping!"); + out = Some(FromServerReply::Pong); + } + } + } + Err(e) => { + error!("Error: {:?}", e); } } - } - Err(e) => { - error!("Error: {:?}", e); + if let Some(msg) = out { + transport.send(rmp_serde::to_vec(&msg).unwrap().into()).await.unwrap(); + } } } - }).await; + } } Ok(()) } \ No newline at end of file diff --git a/libbonknet/src/lib.rs b/libbonknet/src/lib.rs index c6f067f..9e4e869 100644 --- a/libbonknet/src/lib.rs +++ b/libbonknet/src/lib.rs @@ -25,12 +25,6 @@ pub fn load_prkey(filename: &str) -> std::io::Result { } } -#[derive(Debug, Serialize, Deserialize)] -pub enum RequiredReplyValues { - Ok, - GenericFailure { status_code: u32, msg: Option }, -} - #[derive(Debug, Serialize, Deserialize)] pub enum FromServerConnTypeMessage { SendCommand, @@ -65,10 +59,35 @@ pub enum YouAreValues { } #[derive(Debug, Serialize, Deserialize)] -pub enum ToServerMessage { +pub enum ToServerMessageBody { Required { id: String }, } +#[derive(Debug, Serialize, Deserialize)] +pub enum ToServerMessage { + Ping, + Msg { + reply_id: u64, + body: ToServerMessageBody, + }, +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum FromServerReplyBody { + RequiredAccepted, + RequiredFailed, + Pong, +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum FromServerReply { + Pong, + Msg { + reply_id: u64, + body: FromServerReplyBody + } +} + #[derive(Debug, Serialize, Deserialize)] pub enum FromGuestServerMessage { Announce { name: String }