Change name from AsyncSendMsg to SendMsg and remove old version

This commit is contained in:
2024-02-19 16:54:37 +01:00
parent 37cc133d7f
commit 69a37ae89a
2 changed files with 29 additions and 56 deletions

View File

@@ -102,6 +102,7 @@ async fn main() {
let guestserver_root_cert_der = Arc::clone(&guestserver_root_cert_der); let guestserver_root_cert_der = Arc::clone(&guestserver_root_cert_der);
let server_root_cert = Arc::clone(&server_root_cert); let server_root_cert = Arc::clone(&server_root_cert);
let scdb_addr = scdb_addr.clone(); let scdb_addr = scdb_addr.clone();
let pdcm_addr = pdcm_addr.clone();
let sm_addr = sm_addr.clone(); let sm_addr = sm_addr.clone();
// Set up TLS service factory // Set up TLS service factory
@@ -114,13 +115,15 @@ async fn main() {
let guestserver_root_cert_der = Arc::clone(&guestserver_root_cert_der); let guestserver_root_cert_der = Arc::clone(&guestserver_root_cert_der);
let server_root_cert = Arc::clone(&server_root_cert); let server_root_cert = Arc::clone(&server_root_cert);
let scdb_addr = scdb_addr.clone(); let scdb_addr = scdb_addr.clone();
let pdcm_addr = pdcm_addr.clone();
let sm_addr = sm_addr.clone(); let sm_addr = sm_addr.clone();
async move { async move {
let peer_certs = stream.get_ref().1.peer_certificates().unwrap(); let peer_certs = stream.get_ref().1.peer_certificates().unwrap();
let peer_cert_bytes = peer_certs.first().unwrap().to_vec(); let peer_cert_bytes = peer_certs.first().unwrap().to_vec();
let peer_root_cert_der = peer_certs.last().unwrap().clone(); let peer_root_cert_der = peer_certs.last().unwrap().clone();
let scdb_addr = scdb_addr.clone(); // let scdb_addr = scdb_addr.clone();
let sm_addr = sm_addr.clone(); // let pdcm_addr = pdcm_addr.clone();
// let sm_addr = sm_addr.clone();
if peer_root_cert_der == *server_root_cert_der { if peer_root_cert_der == *server_root_cert_der {
info!("Server connection"); info!("Server connection");
let mut transport = Framed::new(stream, LengthDelimitedCodec::new()); let mut transport = Framed::new(stream, LengthDelimitedCodec::new());
@@ -155,6 +158,10 @@ async fn main() {
transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap(); transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap();
server_subscribe_handler(transport, name, sm_addr).await; server_subscribe_handler(transport, name, sm_addr).await;
} }
OpenDataStream(conn_id) => {
info!("OpenDataStream with {:?}", conn_id);
// TODO: OpenDataStream
}
} }
} }
Err(e) => { Err(e) => {
@@ -171,7 +178,6 @@ async fn main() {
guestserver_handler(transport, scdb_addr, &server_root_cert).await; guestserver_handler(transport, scdb_addr, &server_root_cert).await;
} else if peer_root_cert_der == *client_root_cert_der { } else if peer_root_cert_der == *client_root_cert_der {
info!("Client connection"); info!("Client connection");
//pdcm_addr
let codec = LengthDelimitedCodec::new(); let codec = LengthDelimitedCodec::new();
let transport = Framed::new(stream, codec); let transport = Framed::new(stream, codec);
client_handler(transport, sm_addr).await; client_handler(transport, sm_addr).await;
@@ -339,6 +345,10 @@ async fn client_handler(mut transport: TransportStream, sm_addr: Addr<ServerMana
let reply = ToClientResponse::OkServerList { data }; let reply = ToClientResponse::OkServerList { data };
transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap(); transport.send(rmp_serde::to_vec(&reply).unwrap().into()).await.unwrap();
} }
FromClientCommand::UpgradeToDataStream(conn_id) => {
info!("Upgrade to DataStream with conn_id {:?}", conn_id);
// TODO: Upgrade to DataStream
}
} }
} }
Err(e) => { Err(e) => {

View File

@@ -22,22 +22,15 @@ enum SendMsgResult {
Rejected, Rejected,
} }
#[derive(Message)]
#[rtype(result = "SendMsgResult")]
struct SendMsg {
msg: ToServerMessageBody,
reply_channel: oneshot::Sender<FromServerReplyBody>
}
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum AsyncSendMsgError { pub enum SendMsgError {
#[error("Generic Failure")] #[error("Generic Failure")]
GenericFailure, GenericFailure,
} }
#[derive(Message)] #[derive(Message)]
#[rtype(result = "Result<FromServerReplyBody, AsyncSendMsgError>")] #[rtype(result = "Result<FromServerReplyBody, SendMsgError>")]
struct AsyncSendMsg { struct SendMsg {
msg: ToServerMessageBody, msg: ToServerMessageBody,
} }
@@ -48,38 +41,6 @@ struct ServerTransporter {
reply_channels: HashMap<u64, oneshot::Sender<FromServerReplyBody>>, reply_channels: HashMap<u64, oneshot::Sender<FromServerReplyBody>>,
} }
impl Handler<AsyncSendMsg> for ServerTransporter {
type Result = ResponseFuture<Result<FromServerReplyBody, AsyncSendMsgError>>;
fn handle(&mut self, msg: AsyncSendMsg, _ctx: &mut Self::Context) -> Self::Result {
let (reply_channel_tx, reply_channel_rx) = oneshot::channel();
let mut reply_id: u64;
if self.reply_channels.len() == u64::MAX as usize {
return Box::pin(fut::ready(Err(AsyncSendMsgError::GenericFailure)));
}
loop {
reply_id = random();
if !self.reply_channels.contains_key(&reply_id) {
break;
}
}
self.reply_channels.insert(reply_id, reply_channel_tx);
let msg = ToServerMessage::Msg {
reply_id,
body: msg.msg,
};
let payload: Bytes = rmp_serde::to_vec(&msg).unwrap().into();
let arc_tx = self.tx.clone();
Box::pin(async move {
arc_tx.lock().await.send(payload).await.map_err(|e| AsyncSendMsgError::GenericFailure)?;
info!("msg sent");
let r = reply_channel_rx.await.unwrap();
info!("reply received");
Ok(r)
})
}
}
impl ServerTransporter { impl ServerTransporter {
fn new(transport: TransportStream) -> Self { fn new(transport: TransportStream) -> Self {
let internal = transport.into_inner(); let internal = transport.into_inner();
@@ -133,12 +94,13 @@ impl Actor for ServerTransporter {
} }
impl Handler<SendMsg> for ServerTransporter { impl Handler<SendMsg> for ServerTransporter {
type Result = SendMsgResult; type Result = ResponseFuture<Result<FromServerReplyBody, SendMsgError>>;
fn handle(&mut self, msg: SendMsg, ctx: &mut Self::Context) -> Self::Result { fn handle(&mut self, msg: SendMsg, _ctx: &mut Self::Context) -> Self::Result {
let (reply_channel_tx, reply_channel_rx) = oneshot::channel();
let mut reply_id: u64; let mut reply_id: u64;
if self.reply_channels.len() == u64::MAX as usize { if self.reply_channels.len() == u64::MAX as usize {
return SendMsgResult::Rejected; return Box::pin(fut::ready(Err(SendMsgError::GenericFailure)));
} }
loop { loop {
reply_id = random(); reply_id = random();
@@ -146,19 +108,20 @@ impl Handler<SendMsg> for ServerTransporter {
break; break;
} }
} }
self.reply_channels.insert(reply_id, msg.reply_channel); self.reply_channels.insert(reply_id, reply_channel_tx);
let msg = ToServerMessage::Msg { let msg = ToServerMessage::Msg {
reply_id, reply_id,
body: msg.msg, body: msg.msg,
}; };
let payload: Bytes = rmp_serde::to_vec(&msg).unwrap().into(); let payload: Bytes = rmp_serde::to_vec(&msg).unwrap().into();
let arc_tx = self.tx.clone(); let arc_tx = self.tx.clone();
ctx.spawn(async move { Box::pin(async move {
arc_tx.lock().await.send(payload).await arc_tx.lock().await.send(payload).await.map_err(|e| SendMsgError::GenericFailure)?;
}.into_actor(self).map(|res, _a, _ctx| { info!("msg sent");
info!("ToServerMsg sent result: {:?}", res); let r = reply_channel_rx.await.unwrap();
})); info!("reply received");
SendMsgResult::Accepted Ok(r)
})
} }
} }
@@ -311,7 +274,7 @@ impl Handler<RequestServer> for ServerManager {
match pdcdb_addr.send(NewPendingConn { server_conn_id, client_conn_id }).await.unwrap() { match pdcdb_addr.send(NewPendingConn { server_conn_id, client_conn_id }).await.unwrap() {
Ok(_) => { Ok(_) => {
let msg = ToServerMessageBody::Request { conn_id: server_conn_id }; let msg = ToServerMessageBody::Request { conn_id: server_conn_id };
match sh_addr.send(AsyncSendMsg { msg }).await.unwrap() { match sh_addr.send(SendMsg { msg }).await.unwrap() {
Ok(reply) => match reply { Ok(reply) => match reply {
FromServerReplyBody::RequestAccepted => { FromServerReplyBody::RequestAccepted => {
Ok(client_conn_id) Ok(client_conn_id)