diff --git a/bonknet_broker/Cargo.toml b/bonknet_broker/Cargo.toml index 470a4f8..d541621 100644 --- a/bonknet_broker/Cargo.toml +++ b/bonknet_broker/Cargo.toml @@ -27,6 +27,7 @@ rand = "0.8.5" uuid = { version = "1.7.0", features = ["v4", "serde"] } rustls-pemfile = "2.0.0" x509-parser = "0.16.0" +rusqlite = { version = "0.31.0", features = ["bundled"] } [[bin]] name = "init_certs" diff --git a/bonknet_broker/src/main.rs b/bonknet_broker/src/main.rs index 2a12136..12abe40 100644 --- a/bonknet_broker/src/main.rs +++ b/bonknet_broker/src/main.rs @@ -47,7 +47,7 @@ async fn main() { let server_ca = CACertPair::load_from_file("certs_pem/server_root_ca.pem").unwrap(); let guestserver_ca = CACertPair::load_from_file("certs_pem/guestserver_root_ca.pem").unwrap(); // Load Actors - let servercert_db = ServerCertDB::new().start(); + let servercert_db = ServerCertDB::new("certsdb.sqlite").unwrap().start(); let dataconn_manager = DataConnManager::new().start(); let pendingdataconn_manager = PendingDataConnManager::new(dataconn_manager).start(); let server_manager = ServerManager::new(pendingdataconn_manager.clone()).start(); diff --git a/bonknet_broker/src/servercertdb.rs b/bonknet_broker/src/servercertdb.rs index d45fd89..8118f67 100644 --- a/bonknet_broker/src/servercertdb.rs +++ b/bonknet_broker/src/servercertdb.rs @@ -1,13 +1,14 @@ -use std::collections::HashMap; +use std::path::Path; use actix::prelude::*; +use rusqlite::{Connection}; use thiserror::Error; #[derive(Error, Debug)] pub enum DBError { #[error("Certificate is already registered with name {0}")] CertAlreadyRegistered(String), - // #[error("Generic Failure")] - // GenericFailure, + #[error("Generic Failure")] + GenericFailure, } #[derive(Message)] @@ -35,14 +36,29 @@ pub struct FetchName { pub cert: Vec, } -// TODO: Move into Sqlite DB with unique check on col1 and col2!!!! Right now name is not unique +fn init_db(conn: Connection) -> rusqlite::Result { + conn.execute( + "CREATE TABLE IF NOT EXISTS servercert ( + cert BLOB PRIMARY KEY, + name TEXT NOT NULL + )", + (), // empty list of parameters. + )?; + Ok(conn) +} + +// TODO: Right now name is not unique. Consider making it unique and checking it for duplication pub struct ServerCertDB { - db: HashMap, String>, // Cert to Name + conn: Connection, } impl ServerCertDB { - pub fn new() -> Self { - ServerCertDB { db: HashMap::new() } + fn new_in_memory() -> rusqlite::Result { + Ok(ServerCertDB { conn: init_db(Connection::open_in_memory()?)? }) + } + + pub fn new>(path: P) -> rusqlite::Result { + Ok(ServerCertDB { conn: init_db(Connection::open(path)?)? }) } } @@ -54,13 +70,23 @@ impl Handler for ServerCertDB { type Result = Result<(), DBError>; fn handle(&mut self, msg: RegisterServer, _ctx: &mut Self::Context) -> Self::Result { - match self.db.get(&msg.cert) { - None => { - self.db.insert(msg.cert, msg.name); + match self.conn.query_row( + "SELECT name FROM servercert WHERE cert = ?1", + (&msg.cert,), + |row| row.get::<_, String>(0) + ) { + Ok(name) => { + Err(DBError::CertAlreadyRegistered(name)) + } + Err(rusqlite::Error::QueryReturnedNoRows) => { + self.conn.execute( + "INSERT INTO servercert (cert, name) VALUES (?1, ?2)", + (&msg.cert, &msg.name) + ).map_err(|_| DBError::GenericFailure)?; Ok(()) } - Some(name) => { - Err(DBError::CertAlreadyRegistered(name.clone())) + Err(_) => { + Err(DBError::GenericFailure) } } } @@ -70,7 +96,12 @@ impl Handler for ServerCertDB { type Result = bool; fn handle(&mut self, msg: IsNameRegistered, _ctx: &mut Self::Context) -> Self::Result { - self.db.values().any(|x| *x == msg.name) + let count: u64 = self.conn.query_row( + "SELECT COUNT(*) FROM servercert WHERE name = ?1", + (&msg.name,), + |row| row.get(0) + ).unwrap(); + count > 0 } } @@ -78,7 +109,21 @@ impl Handler for ServerCertDB { type Result = Option; fn handle(&mut self, msg: FetchName, _ctx: &mut Self::Context) -> Self::Result { - self.db.get(&msg.cert).map(|s| s.to_owned()) + match self.conn.query_row( + "SELECT name FROM servercert WHERE cert = ?1", + (&msg.cert,), + |row| row.get(0) + ) { + Ok(name) => { + Some(name) + } + Err(rusqlite::Error::QueryReturnedNoRows) => { + None + } + Err(e) => { + panic!("Error during FetchName: {}", e); + } + } } } @@ -86,7 +131,25 @@ impl Handler for ServerCertDB { type Result = Option; fn handle(&mut self, msg: UnregisterServer, _ctx: &mut Self::Context) -> Self::Result { - self.db.remove(&msg.cert) + match self.conn.query_row( + "SELECT name FROM servercert WHERE cert = ?1", + (&msg.cert,), + |row| row.get::<_, String>(0) + ) { + Ok(name) => { + self.conn.execute( + "DELETE FROM servercert WHERE cert = ?1", + (&msg.cert,) + ).unwrap(); + Some(name) + } + Err(rusqlite::Error::QueryReturnedNoRows) => { + None + } + Err(e) => { + panic!("Error during UnregisterServer: {}", e); + } + } } } @@ -96,13 +159,13 @@ mod tests { #[actix::test] async fn emptydb_isnameregistered() { - let servercert_db = ServerCertDB::new().start(); + let servercert_db = ServerCertDB::new_in_memory().unwrap().start(); assert!(!servercert_db.send(IsNameRegistered { name: "test".into() }).await.unwrap()); } #[actix::test] async fn emptyvec() { - let servercert_db = ServerCertDB::new().start(); + let servercert_db = ServerCertDB::new_in_memory().unwrap().start(); assert!(servercert_db.send(RegisterServer { cert: vec![], name: "test".into() }).await.unwrap().is_ok()); assert!(servercert_db.send(IsNameRegistered { name: "test".into() }).await.unwrap()); assert_eq!(servercert_db.send(FetchName { cert: vec![] }).await.unwrap().unwrap(), "test"); @@ -114,7 +177,7 @@ mod tests { #[actix::test] async fn normalcert() { - let servercert_db = ServerCertDB::new().start(); + let servercert_db = ServerCertDB::new_in_memory().unwrap().start(); let cert = vec![112, 111, 114, 99, 111, 100, 105, 111]; assert!(servercert_db.send(RegisterServer { cert: cert.clone(), name: "test2".into() }).await.unwrap().is_ok()); assert!(servercert_db.send(IsNameRegistered { name: "test2".into() }).await.unwrap()); @@ -127,7 +190,7 @@ mod tests { #[actix::test] async fn cert2_remains_after_delete_cert1() { - let servercert_db = ServerCertDB::new().start(); + let servercert_db = ServerCertDB::new_in_memory().unwrap().start(); let cert1 = vec![112, 111, 114, 99, 111, 100, 105, 111]; let cert2 = vec![67, 65, 78, 68, 69, 68, 73, 79]; assert!(servercert_db.send(RegisterServer { cert: cert1.clone(), name: "test3".into() }).await.unwrap().is_ok());