From a8b812abc396cbe6022933bfa372f60a6c488886 Mon Sep 17 00:00:00 2001 From: Rob Date: Sun, 11 Feb 2024 20:24:39 -0500 Subject: [PATCH] comments, more tests --- Cargo.lock | 72 +++++ broker/src/lib.rs | 82 +++--- broker/src/map.rs | 57 +++- broker/src/state.rs | 385 +++++++++++++++++++++----- proto/Cargo.toml | 3 +- proto/src/connection/protocols/mod.rs | 7 +- 6 files changed, 498 insertions(+), 108 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c17864b..1e53ebf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -768,6 +768,12 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" +[[package]] +name = "downcast" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1" + [[package]] name = "downcast-rs" version = "1.2.0" @@ -812,6 +818,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fragile" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c2141d6d6c8512188a7891b4b01590a45f6dac67afb4f255c4124dbb86d4eaa" + [[package]] name = "futures" version = "0.3.30" @@ -1204,6 +1216,33 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "mockall" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43766c2b5203b10de348ffe19f7e54564b64f3d6018ff7648d1e2d6d3a0f0a48" +dependencies = [ + "cfg-if", + "downcast", + "fragile", + "lazy_static", + "mockall_derive", + "predicates", + "predicates-tree", +] + +[[package]] +name = "mockall_derive" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af7cbce79ec385a1d4f54baa90a76401eb15d9cab93685f62e7e9f942aa00ae2" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "syn 2.0.48", +] + [[package]] name = "neli" version = "0.6.4" @@ -1424,6 +1463,32 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +[[package]] +name = "predicates" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68b87bfd4605926cdfefc1c3b5f8fe560e3feca9d5552cf68c466d3d8236c7e8" +dependencies = [ + "anstyle", + "predicates-core", +] + +[[package]] +name = "predicates-core" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b794032607612e7abeb4db69adb4e33590fa6cf1149e95fd7cb00e634b92f174" + +[[package]] +name = "predicates-tree" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368ba315fb8c5052ab692e68a0eefec6ec57b23a36959c14496f0b0df2c0cecf" +dependencies = [ + "predicates-core", + "termtree", +] + [[package]] name = "proc-macro2" version = "1.0.78" @@ -1442,6 +1507,7 @@ dependencies = [ "capnp", "capnpc", "jf-primitives", + "mockall", "pem", "quinn", "rand", @@ -1982,6 +2048,12 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "termtree" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" + [[package]] name = "thiserror" version = "1.0.56" diff --git a/broker/src/lib.rs b/broker/src/lib.rs index aab48fe..de792e8 100644 --- a/broker/src/lib.rs +++ b/broker/src/lib.rs @@ -1,7 +1,8 @@ //! 
This file contains the implementation of the `Broker`, which routes messages //! for the Push CDN. -// TODO: massive cleanup on this file +// TODO: split out this file into multiple files. +// TODO: logging mod map; mod state; @@ -25,8 +26,7 @@ use proto::{ redis::{self, BrokerIdentifier}, verify_broker, }; -use slotmap::DefaultKey; -use state::ConnectionLookup; +use state::{ConnectionId, ConnectionLookup, Sender}; use tokio::{select, spawn, sync::RwLock, time::sleep}; use tracing::{error, info, warn}; @@ -96,20 +96,27 @@ struct Inner< pd: PhantomData<(UserProtocolType, BrokerProtocolType, UserSignatureScheme)>, } +/// This macro is a helper macro that lets us "send many messages", and remove +/// the actor from the local state if the message failed to send macro_rules! send_or_remove_many { ($connections: expr, $lookup:expr, $message: expr, $position: expr) => { + // For each connection, for connection in $connections { + // Queue a message back if connection .1 .queue_message($message.clone(), $position) .is_err() { + // If it fails, remove the connection. get_lock!($lookup, write).remove_connection(connection.0); }; } }; } +/// We use this macro to help send direct messages. It just makes the code easier +/// to look at. macro_rules! send_direct { ($lookup: expr, $key: expr, $message: expr) => {{ let connections = $lookup.read().await.get_connections_by_key(&$key).clone(); @@ -117,6 +124,8 @@ macro_rules! send_direct { }}; } +/// We use this macro to help send broadcast messages. It just makes the code easier +/// to look at. macro_rules! send_broadcast { ($lookup:expr, $topics: expr, $message: expr) => {{ let connections = $lookup @@ -128,6 +137,7 @@ macro_rules! send_broadcast { }}; } +/// This is a macro to acquire an async lock, which helps readability. macro_rules! get_lock { ($lock :expr, $type: expr) => { paste::item! { @@ -136,6 +146,7 @@ macro_rules! get_lock { }; } +// Creates and serializes a new message of the specified type with the specified data. macro_rules! new_serialized_message { ($type: ident, $data: expr) => { Arc::>::from(bail!( @@ -487,6 +498,10 @@ where } } + /// This task deals with sending connected user and topic updates to other brokers. It takes advantage of + /// `SnapshotMap`, so we can send partial or full updates to the other brokers as they need it. + /// Right now, we do it every 5 seconds, or on every user connect if the number of connections is + /// sufficiently low. async fn send_updates_task( inner: Arc< Inner, @@ -510,6 +525,8 @@ where } } + /// This task deals with setting the number of our connected users in `Redis`. It allows + /// the marshal to correctly choose the broker with the least amount of connections. async fn heartbeat_task( inner: Arc< Inner, @@ -576,8 +593,12 @@ where } } +/// This is a macro that helps us send an update to other brokers. The message type depends on +/// whether it is a user update or a topic update. The recipients is the other brokers (or broker) +/// for which we want the partial/complete update, and the position refers to the position the message +/// should go in the queue. macro_rules! send_update_to_brokers { - ($lookup:expr, $message_type: ident, $data:expr, $recipients: expr, $position: ident) => {{ + ($self:expr, $message_type: ident, $data:expr, $recipients: expr, $position: ident) => {{ // If the data is not empty, make a message of the specified type if !$data.is_empty() { // Create a `Subscribe` message, which contains the full list of topics we're subscribed to @@ -586,7 +607,7 @@ macro_rules! 
send_update_to_brokers { // For each recipient, send to the destined position in the queue send_or_remove_many!( $recipients, - $lookup, + $self.broker_connection_lookup, message, Position::$position ); @@ -606,9 +627,11 @@ where BrokerSignatureScheme::VerificationKey: Serializable, BrokerSignatureScheme::SigningKey: Serializable, { + /// This is the main loop where we deal with broker connectins. On exit, the calling function + /// should remove the broker from the map. pub async fn broker_recv_loop( self: &Arc, - connection_id: DefaultKey, + connection_id: ConnectionId, mut receiver: BrokerProtocolType::Receiver, ) -> Result<()> { while let Ok(message) = receiver.recv_message().await { @@ -657,9 +680,11 @@ where Err(Error::Connection("connection closed".to_string())) } + /// This is the main loop where we deal with user connectins. On exit, the calling function + /// should remove the user from the map. pub async fn user_recv_loop( self: &Arc, - connection_id: DefaultKey, + connection_id: ConnectionId, mut receiver: UserProtocolType::Receiver, ) { while let Ok(message) = receiver.recv_message().await { @@ -703,10 +728,13 @@ where } } + /// This function lets us send updates to brokers on demand. We need this to ensure consistency between brokers + /// (e.g. which brokers have which users connected). We send these updates out periodically, but also + /// on every user join if the number of connected users is sufficiently small. pub async fn send_updates_to_brokers( self: &Arc, - full: Vec<(DefaultKey, Arc>)>, - partial: Vec<(DefaultKey, Arc>)>, + full: Vec<(ConnectionId, Sender)>, + partial: Vec<(ConnectionId, Sender)>, ) -> Result<()> { // When a broker connects, we have to send: // 1. Our snapshot to the new broker (of what topics/users we're subscribed for) @@ -720,26 +748,14 @@ where // Send the full connected users to interested brokers first in the queue (so that it is the correct order) // TODO: clean up this function - send_update_to_brokers!( - self.broker_connection_lookup, - UsersConnected, - key_snapshot.snapshot, - &full, - Front - ); + send_update_to_brokers!(self, UsersConnected, key_snapshot.snapshot, &full, Front); // Send the full topics list to interested brokers first in the queue (so that it is the correct order) - send_update_to_brokers!( - self.broker_connection_lookup, - Subscribe, - topic_snapshot.snapshot, - &full, - Front - ); + send_update_to_brokers!(self, Subscribe, topic_snapshot.snapshot, &full, Front); // Send the insertion updates for keys, if any send_update_to_brokers!( - self.broker_connection_lookup, + self, UsersConnected, key_snapshot.insertions, &partial, @@ -748,7 +764,7 @@ where // Send the removal updates for keys, if any send_update_to_brokers!( - self.broker_connection_lookup, + self, UsersDisconnected, key_snapshot.removals, &partial, @@ -756,22 +772,10 @@ where ); // Send the insertion updates for topics, if any - send_update_to_brokers!( - self.broker_connection_lookup, - Subscribe, - topic_snapshot.insertions, - &partial, - Back - ); + send_update_to_brokers!(self, Subscribe, topic_snapshot.insertions, &partial, Back); // Send the removal updates for topics, if any - send_update_to_brokers!( - self.broker_connection_lookup, - Unsubscribe, - topic_snapshot.removals, - &partial, - Back - ); + send_update_to_brokers!(self, Unsubscribe, topic_snapshot.removals, &partial, Back); Ok(()) } diff --git a/broker/src/map.rs b/broker/src/map.rs index 35b158a..8fd6ad6 100644 --- a/broker/src/map.rs +++ b/broker/src/map.rs @@ -1,3 +1,9 @@ +//! 
This is where we define the `SnapShotMap` implementation, a struct +//! that allows us to get full and partial updates on the list of keys in a map. +//! +//! We use this in broker <-> broker communication to save bandwidth, wherein +//! we only need to send partial updates over the wire. + use std::{ cmp::Ordering, collections::HashMap, @@ -7,25 +13,43 @@ use std::{ use delegate::delegate; +/// A primitive that allows us to get full and partial updates on the list +/// of keys in a map. It's a write-ahead-log that automatically prunes. This is helpful +/// for broker <-> broker communication, where we don't want to send the whole list every time. pub struct SnapshotMap { + /// The previous snapshot, which is moved out when we calculate the difference. snapshot: Vec, + + /// The log of operations, which we sum up and run calculations over to determine + /// the actual (delta) difference. For example, if we have `Add(User(1)), Remove(User(1)) Add(User(1))`, + /// the calculated output will be just `Add(User(1))`. log: Vec>, + + /// The actual underlying `HashMap`, which contains the data. inner: HashMap, } -#[derive(Debug)] +/// Represents an action taken on the inner map. pub enum Operation { + /// An item was inserted to the map. Insert(K), + /// An item was removed from the map. Remove(K), } +/// The actual snapshot. Contains both the previous snapshot and +/// a list of insertions and removals _since_ that snapshot. pub struct SnapshotWithChanges { + /// The previous snapshot pub snapshot: Vec, + /// Key insertions since the previous snapshot pub insertions: Vec, + /// Key removals since the previous snapshot pub removals: Vec, } impl SnapshotMap { + /// Create a new `SnapshotMap`. pub fn new() -> Self { Self { log: Vec::new(), @@ -34,7 +58,11 @@ impl SnapshotMap { } } + /// Insert an item into the `SnapshotMap`, returning the + /// old value if there was one. Under the hood, it logs + /// an `Insert()` operation. pub fn insert(&mut self, key: K, val: V) -> Option { + // Insert the value, saving to return for later. let res = self.inner.insert(key.clone(), val); if res.is_none() { @@ -45,7 +73,11 @@ impl SnapshotMap { res } + /// Remove an item from the `SnapshotMap`, returning the + /// removed value if there was one. Under the hood, we add a + /// `Remove` operation to the log as well. pub fn remove(&mut self, key: &K) -> Option { + // Remove the value, saving the output for later. let res = self.inner.remove(key); if res.is_some() { @@ -56,9 +88,14 @@ impl SnapshotMap { res } + /// Calculates the difference between the last time this was called. Returns + /// a `SnapshotWithChanges`, which is both the old snapshot and a list of changes + /// since the current one. pub fn difference(&mut self) -> SnapshotWithChanges { + // Take our inner logs, replacing with nothing. let logs = std::mem::take(&mut self.log); + // Count the amount of each log for a particular key. let mut changes = HashMap::new(); for log in logs { match log { @@ -67,22 +104,29 @@ impl SnapshotMap { } } + // Check the number of insertions and removals, and prune the log based on that. + // This only works because we do not log an event for a key that was unchanged. let mut insertions = Vec::new(); let mut removals = Vec::new(); for change in changes { + // Compare to 0. match change.1.cmp(&0) { Ordering::Greater => { + // If we are bigger than zero, we ended on an insertion. So make this part of the pruned log. insertions.push(change.0); } Ordering::Less => { + // If we are less than zero, we ended on an removal. 
So make this part of the pruned log. removals.push(change.0); } + // If we are zero, do nothing Ordering::Equal => {} } } + // Replace the snapshot with the current data and return the insertions and removals. SnapshotWithChanges { snapshot: std::mem::replace(&mut self.snapshot, self.inner.keys().cloned().collect()), insertions, @@ -90,6 +134,8 @@ impl SnapshotMap { } } + // We use this to delegate `get()` and `get_mut()` methods to the lower `HashMap`, as we don't + // actually need to log those events. delegate! { to self.inner { pub fn get(&self, value: &K) -> Option<&V>; @@ -102,8 +148,12 @@ impl SnapshotMap { pub mod test { use super::SnapshotMap; + /// This test is supposed to test various cases for the difference calculation, which is the meat + /// of the `SnapshotMap`. #[test] fn test_snapshot_difference_calculation() { + // Make sure that Add(1), Remove(1), Add(1) just prunes to `Add(1)`, and + // has no removals. let mut map: SnapshotMap = SnapshotMap::new(); map.insert(1, 0); map.remove(&1); @@ -113,6 +163,8 @@ pub mod test { assert!(difference.insertions == Vec::from(vec![1])); assert!(difference.removals.is_empty()); + // Make sure that Remove(1), Remove(1), Add(1) _AFTER_ the previous operation just prunes + // away and has no removals. The snapshot should be the last value, which was 1. map.remove(&1); map.remove(&1); map.insert(1, 0); @@ -122,6 +174,7 @@ pub mod test { assert!(difference.removals.is_empty()); assert!(difference.snapshot == vec![1]); + // Insert -> Remove -> Remove, make sure the only difference is the removal. map.insert(1, 0); map.remove(&1); map.remove(&1); @@ -131,6 +184,8 @@ pub mod test { assert!(difference.removals == Vec::from(vec![1])); assert!(difference.snapshot == vec![1]); + // At the last snapshot, we removed the last item. So let's make sure there is nothing + // in this current snapshot. let difference = map.difference(); assert!(difference.insertions.is_empty()); assert!(difference.removals.is_empty()); diff --git a/broker/src/state.rs b/broker/src/state.rs index 57d6842..fd29026 100644 --- a/broker/src/state.rs +++ b/broker/src/state.rs @@ -11,8 +11,19 @@ use slotmap::{DefaultKey, DenseSlotMap}; use crate::map::{SnapshotMap, SnapshotWithChanges}; +/// This helps with readability, it just defines a sender +pub type Sender = Arc>; + +/// These also help with readability. +pub type ConnectionId = DefaultKey; +pub type UserPublicKey = Vec; + +/// This macro is basically `.entry().unwrap_or().extend()`. We need this +/// because it allows us to use our fancy `SnapshotVec` which doesn't implement +/// `entry()`. macro_rules! extend { ($lookup: expr, $key: expr, $values: expr) => {{ + // Look up a value if let Some(values) = $lookup.get_mut(&$key) { // If it exists, extend it values.extend($values); @@ -23,37 +34,55 @@ macro_rules! extend { }}; } +// This macro helps us remove a connection from a particular map. +// It makes the code easier to look at. macro_rules! 
remove_connection_from { ($connection_id:expr, $field:expr, $map: expr) => {{ // For each set, remove the connection ID from it for item in $field { // Get the set, expecting it to exist - let relevant_connection_ids = $map.get_mut(&item).expect("object to exist"); - - // Remove our connection ID - relevant_connection_ids.remove(&$connection_id); - // Remove the topic if it's empty - if relevant_connection_ids.is_empty() { - $map.remove(&item); + if let Some(connection_ids) = $map.get_mut(&item) { + // Remove our connection ID + connection_ids.remove(&$connection_id); + // Remove the topic if it's empty + if connection_ids.is_empty() { + $map.remove(&item); + } } } }}; } +// This is a light wrapper around a connection, which we use to facilitate +// removing from different parts of our state. For example, we store +// `keys`, as "which users are connected to us". When we remove a connection, +// we want to make sure the connection is not pointing to that key any more. struct Connection { - inner: Arc>, - keys: HashSet>, + // The actual connection (sender) + inner: Sender, + // A list of public keys that the sender is linked to + keys: HashSet, + // A list of topics that the sender is linked to topics: HashSet, - id: DefaultKey, + // The stable ID for the connection. + id: ConnectionId, } /// `ConnectionLookup` is what we use as a broker to "look up" where messages are supposed /// to be directed to. pub struct ConnectionLookup { - connections: DenseSlotMap>, - - key_to_connection_ids: SnapshotMap, HashSet>, - topic_to_connection_ids: SnapshotMap>, + /// This `DenseSlotMap` is where we insert the actual connection (indexed by the connection "key"). + /// Slotted maps are basically `HashMaps`, but we only care about using it to index _another_ map, + /// so the slotted map will give us a value to use on insert. Pretty cool. + connections: DenseSlotMap>, + + /// This is where we store the information on how public keys map to some connection IDs. + /// It uses our fancy `SnapshotMap`, which can return both a list of updates and a full + /// set if necessary. + key_to_connection_ids: SnapshotMap>, + /// This is where we store the information on which topics particular connections care about. + /// It also uses the `SnapshotMap`. + topic_to_connection_ids: SnapshotMap>, } impl Default for ConnectionLookup { @@ -69,29 +98,47 @@ impl Default for ConnectionLookup { } impl ConnectionLookup { + /// Returns an empty `ConnectionLookup` + pub fn new() -> Self { + Self::default() + } + + /// Get the number of connections currently in the map. We use this to + /// report to `Redis`, so the marshal knows who has the least connections. pub fn get_connection_count(&self) -> usize { self.connections.len() } - pub fn get_key_updates_since(&mut self) -> SnapshotWithChanges> { + /// This is a proxy function to the `SnapshotMap`. It lets us get the difference + /// in users connected since last time we called it. + pub fn get_key_updates_since(&mut self) -> SnapshotWithChanges { // Get the difference since last call self.key_to_connection_ids.difference() } + /// This is a proxy function to the `SnapshotMap`. It lets us get the difference + /// in topics we care about since the last time we called it. 
pub fn get_topic_updates_since(&mut self) -> SnapshotWithChanges { // Get the difference since last call self.topic_to_connection_ids.difference() } - pub fn get_all_connections(&self) -> Vec<(DefaultKey, Arc>)> { - // Iterate and collect every connection, cloning it + /// This lets us get an iterator over all of the connections, along with their unique identifier. + /// We need this to send information to _all_ brokers when a broker connects. + /// + /// It returns a `Vec<(connection id, sender)>`. + /// TODO: type alias + pub fn get_all_connections(&self) -> Vec<(ConnectionId, Sender)> { + // Iterate and collect every connection, cloning the necessary values. self.connections .values() .map(|conn| (conn.id, conn.inner.clone())) .collect() } - pub fn add_connection(&mut self, connection: Arc>) -> DefaultKey { + /// Adds a connection to the state. It returns a key (the connection ID) with which + /// we can use later when we want to reference that connection. + pub fn add_connection(&mut self, connection: Sender) -> ConnectionId { // Add the connection with no keys and no topics self.connections.insert_with_key(|id| Connection { inner: connection, @@ -101,10 +148,12 @@ impl ConnectionLookup { }) } + /// Removes a connection from the state. We insert the key we got during the insertion process, + /// and we get back the connection if there was one. pub fn remove_connection( &mut self, - connection_id: DefaultKey, - ) -> Option>> { + connection_id: ConnectionId, + ) -> Option> { // Remove a possible connection by its ID let possible_connection = self.connections.remove(connection_id); @@ -125,9 +174,11 @@ impl ConnectionLookup { possible_connection.map(|conn| conn.inner) } + /// Subscribes a connection to a list of topics, given both the + /// connection ID and the list of topics. pub fn subscribe_connection_id_to_topics( &mut self, - connection_id: DefaultKey, + connection_id: ConnectionId, topics: Vec, ) { // Get the connection, if it exists @@ -145,10 +196,35 @@ impl ConnectionLookup { } } + /// Unsubscribe a connection from all topics in the given list. + pub fn unsubscribe_connection_id_from_topics( + &mut self, + connection_id: ConnectionId, + topics: Vec, + ) { + // Get the connection, if it exists + let possible_connection = self.connections.get_mut(connection_id); + + // If the connection exists: + if let Some(connection) = possible_connection { + // Remove the connection from the topic + // Remove the topic if it exists + remove_connection_from!(connection_id, &topics, self.topic_to_connection_ids); + + // Remove the topic from the connection + for topic in topics { + connection.topics.remove(&topic); + } + } + } + + /// Subscribes a connection to a list of user keys, given both the + /// connection ID and the list of keys. We use this on the broker to bulk subscribe + /// for a bunch of keys, and on the user side to subscribe to one only. 
pub fn subscribe_connection_id_to_keys( &mut self, - connection_id: DefaultKey, - keys: Vec>, + connection_id: ConnectionId, + keys: Vec, ) { // Get the connection, if it exists let possible_connection = self.connections.get_mut(connection_id); @@ -165,10 +241,34 @@ impl ConnectionLookup { } } + /// Unsubscribe a connection ID from all keys in the list./ + pub fn unsubscribe_connection_id_from_keys( + &mut self, + connection_id: ConnectionId, + keys: Vec, + ) { + // Get the connection, if it exists + let possible_connection = self.connections.get_mut(connection_id); + + // If the connection exists: + if let Some(connection) = possible_connection { + // Remove the connection from the topic + // Remove the topic if it exists + remove_connection_from!(connection_id, &keys, self.key_to_connection_ids); + + // Remove the topic from the connection + for key in keys { + connection.keys.remove(&key); + } + } + } + + /// Aggregate connections over a list of user keys. This is used to send messages + /// to users with the corresponding key (generally direct). pub fn get_connections_by_key( &self, - key: &Vec, - ) -> Vec<(DefaultKey, Arc>)> { + key: &UserPublicKey, + ) -> Vec<(ConnectionId, Sender)> { // We return this at the end let mut connections = Vec::new(); @@ -192,10 +292,12 @@ impl ConnectionLookup { connections } + /// Aggregate connections over a list of topics. This is used to send messages + /// to users who are subscribed to a particular topic. pub fn get_connections_by_topic( &self, topics: Vec, - ) -> Vec<(DefaultKey, Arc>)> { + ) -> Vec<(ConnectionId, Sender)> { // We return this at the end let mut connections = Vec::new(); @@ -208,6 +310,7 @@ impl ConnectionLookup { .unwrap_or(&HashSet::new()) { // Get the connection, clone its inner, add it to the running vec + // TODO: remove these expects connections.push(( *connection_id, self.connections @@ -219,52 +322,202 @@ impl ConnectionLookup { } } - // Return + // Return the connections connections } +} - pub fn unsubscribe_connection_id_from_topics( - &mut self, - connection_id: DefaultKey, - topics: Vec, - ) { - // Get the connection, if it exists - let possible_connection = self.connections.get_mut(connection_id); - - // If the connection exists: - if let Some(connection) = possible_connection { - // Remove the connection from the topic - // Remove the topic if it exists - remove_connection_from!(connection_id, &topics, self.topic_to_connection_ids); +#[cfg(test)] +pub mod test { + use std::time::Duration; + + use proto::connection::protocols::{MockProtocol, MockSender}; + + use super::*; + + /// A helper macro for mocking a connection. We use this because I didn't want to + /// have a lot of extra code to fake implement the connection trait. + macro_rules! mock_connection { + () => { + Arc::new(BatchedSender::from( + MockSender::new(), + Duration::from_secs(1), + 1200, + )) + }; + } - // Remove the topic from the connection - for topic in topics { - connection.topics.remove(&topic); - } - } + /// Here is where we test `insert` and `remove` operations for our state map. + /// We also test subscriptions to both keys and topics. + /// TODO: I want to add a lot more tests to this. 
+ #[tokio::test] + async fn test_insert_remove() { + // Mock map + let mut lookup = ConnectionLookup::::new(); + let connection = mock_connection!(); + + // Count check + assert!(lookup.get_connection_count() == 0); + let id1 = lookup.add_connection(connection.clone()); + assert!(lookup.get_connection_count() == 1); + let id2 = lookup.add_connection(connection.clone()); + assert!(lookup.get_connection_count() == 2); + + // Remove check + lookup.remove_connection(id1); + lookup.remove_connection(id2); + + assert!(lookup.get_connection_count() == 0); } - pub fn unsubscribe_connection_id_from_keys( - &mut self, - connection_id: DefaultKey, - keys: Vec>, - ) { - // Get the connection, if it exists - let possible_connection = self.connections.get_mut(connection_id); - - // If the connection exists: - if let Some(connection) = possible_connection { - // Remove the connection from the topic - // Remove the topic if it exists - remove_connection_from!(connection_id, &keys, self.key_to_connection_ids); + /// Here is where we test subscriptions/unsubscriptions + #[tokio::test] + async fn test_subscribe_unsubscribe_key() { + // Mock map + let mut lookup = ConnectionLookup::::new(); + + let connection = mock_connection!(); + let id1 = lookup.add_connection(connection.clone()); + let id2 = lookup.add_connection(connection.clone()); + + // Key subscription check + lookup.subscribe_connection_id_to_keys(id1, vec![vec![0], vec![1]]); + let connections: Vec = lookup + .get_connections_by_key(&vec![0]) + .iter() + .map(|conn| conn.0) + .collect(); + assert!(connections == vec![id1]); + + let connections: Vec = lookup + .get_connections_by_key(&vec![1]) + .iter() + .map(|conn| conn.0) + .collect(); + assert!(connections == vec![id1]); + + // Subscribe key2 + lookup.subscribe_connection_id_to_keys(id2, vec![vec![1]]); + let connections: Vec = lookup + .get_connections_by_key(&vec![1]) + .iter() + .map(|conn| conn.0) + .collect(); + assert!(connections.contains(&id1)); + assert!(connections.contains(&id2)); + assert!(connections.len() == 2); + + // Check that we're not subscribed to the one we didn't subscribe to + let connections: Vec = lookup + .get_connections_by_key(&vec![0]) + .iter() + .map(|conn| conn.0) + .collect(); + assert!(connections == vec![id1]); + + // Unsubscribe key1. Should just be id2. 
+ lookup.unsubscribe_connection_id_from_keys(id1, vec![vec![1]]); + let connections: Vec = lookup + .get_connections_by_key(&vec![1]) + .iter() + .map(|conn| conn.0) + .collect(); + assert!(connections == vec![id2]); + + // Remove id2, should be auto-unsubscribed + // TODO: macro this + lookup.remove_connection(id2); + let connections: Vec = lookup + .get_connections_by_key(&vec![1]) + .iter() + .map(|conn| conn.0) + .collect(); + + assert!(connections == vec![]); + + lookup.unsubscribe_connection_id_from_keys(id1, vec![vec![0]]); + let connections: Vec = lookup + .get_connections_by_key(&vec![0]) + .iter() + .map(|conn| conn.0) + .collect(); + + assert!(connections == vec![]); + } - // Remove the topic from the connection - for key in keys { - connection.keys.remove(&key); - } - } + /// Here is where we test subscriptions/unsubscriptions + #[tokio::test] + async fn test_subscribe_unsubscribe_topic() { + // Mock map + let mut lookup = ConnectionLookup::::new(); + + let connection = mock_connection!(); + let id1 = lookup.add_connection(connection.clone()); + let id2 = lookup.add_connection(connection.clone()); + + // Key subscription check + lookup.subscribe_connection_id_to_topics(id1, vec![Topic::Global, Topic::DA]); + let connections: Vec = lookup + .get_connections_by_topic(vec![Topic::Global]) + .iter() + .map(|conn| conn.0) + .collect(); + assert!(connections == vec![id1]); + + let connections: Vec = lookup + .get_connections_by_topic(vec![Topic::DA]) + .iter() + .map(|conn| conn.0) + .collect(); + assert!(connections == vec![id1]); + + // Subscribe key2 + lookup.subscribe_connection_id_to_topics(id2, vec![Topic::DA]); + let connections: Vec = lookup + .get_connections_by_topic(vec![Topic::DA]) + .iter() + .map(|conn| conn.0) + .collect(); + // TODO: write assert for this + assert!(connections.contains(&id1)); + assert!(connections.contains(&id2)); + assert!(connections.len() == 2); + + // Check that we're not subscribed to the one we didn't subscribe to + let connections: Vec = lookup + .get_connections_by_topic(vec![Topic::Global]) + .iter() + .map(|conn| conn.0) + .collect(); + assert!(connections == vec![id1]); + + // Unsubscribe key1. Should just be id2. 
+ lookup.unsubscribe_connection_id_from_topics(id1, vec![Topic::DA]); + let connections: Vec = lookup + .get_connections_by_topic(vec![Topic::DA]) + .iter() + .map(|conn| conn.0) + .collect(); + assert!(connections == vec![id2]); + + // Remove id2, should be auto-unsubscribed + // TODO: macro this + lookup.remove_connection(id2); + let connections: Vec = lookup + .get_connections_by_topic(vec![Topic::DA]) + .iter() + .map(|conn| conn.0) + .collect(); + + assert!(connections == vec![]); + + lookup.unsubscribe_connection_id_from_topics(id1, vec![Topic::Global]); + let connections: Vec = lookup + .get_connections_by_topic(vec![Topic::Global]) + .iter() + .map(|conn| conn.0) + .collect(); + + assert!(connections == vec![]); } } - -#[cfg(test)] -pub mod test {} diff --git a/proto/Cargo.toml b/proto/Cargo.toml index 96af9b7..e4d45c8 100644 --- a/proto/Cargo.toml +++ b/proto/Cargo.toml @@ -25,4 +25,5 @@ async-trait = "0.1.77" rustls = "0.21.10" rcgen = "0.12.0" pem = "3.0.3" -redis = { version = "0.24.0", features = ["tokio-comp", "connection-manager"] } \ No newline at end of file +redis = { version = "0.24.0", features = ["tokio-comp", "connection-manager"] } +mockall = "0.12.1" \ No newline at end of file diff --git a/proto/src/connection/protocols/mod.rs b/proto/src/connection/protocols/mod.rs index 075050e..82652dd 100644 --- a/proto/src/connection/protocols/mod.rs +++ b/proto/src/connection/protocols/mod.rs @@ -3,6 +3,7 @@ use std::{collections::VecDeque, net::SocketAddr, sync::Arc}; use async_trait::async_trait; +use mockall::automock; use crate::{error::Result, message::Message}; pub mod quic; @@ -13,6 +14,7 @@ pub mod tcp; const _: [(); 0 - (!(usize::BITS >= u64::BITS)) as usize] = []; /// The `Protocol` trait lets us be generic over a connection type (Tcp, Quic, etc). +#[automock(type Sender=MockSender; type Receiver=MockReceiver; type Listener=MockListener;)] #[async_trait] pub trait Protocol: Send + Sync + 'static { // TODO: make these generic over reader/writer @@ -39,6 +41,7 @@ pub trait Protocol: Send + Sync + 'static { ) -> Result; } +#[automock] #[async_trait] pub trait Sender { /// Send a message over the connection. @@ -61,6 +64,7 @@ pub trait Sender { async fn finish(&mut self) -> Result<()>; } +#[automock] #[async_trait] pub trait Receiver { /// Receives a single message over the stream and deserializes @@ -79,8 +83,9 @@ pub trait Receiver { async fn recv_message_raw(&mut self) -> Result>; } +#[automock] #[async_trait] -pub trait Listener { +pub trait Listener { /// Accept a connection from the local, bound socket. /// Returns a connection or an error if we encountered one. ///
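// A minimal illustrative sketch (not part of the diff above) of how the
// mockall-generated `MockSender` introduced by the `#[automock]` attribute on
// the `Sender` trait might be driven directly in a test, without wrapping it
// in `BatchedSender` the way the broker state tests do. It assumes only what
// this patch shows: `MockSender` is exported from
// `proto::connection::protocols`, the trait has
// `async fn finish(&mut self) -> Result<()>`, and the trait has no extra
// generic parameters. The module name, test name, and the availability of a
// tokio test runtime here are assumptions, not part of the change.
#[cfg(test)]
mod mock_sender_sketch {
    use proto::connection::protocols::{MockSender, Sender};

    #[tokio::test]
    async fn finish_returns_ok_when_stubbed() {
        // Build the mock and stub `finish` to succeed exactly once.
        let mut sender = MockSender::new();
        sender.expect_finish().times(1).returning(|| Ok(()));

        // Call through the `Sender` trait; the stubbed value is returned.
        assert!(sender.finish().await.is_ok());
    }
}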