diff --git a/cdn-broker/src/handlers/broker.rs b/cdn-broker/src/handlers/broker.rs index 71e7375..46663d4 100644 --- a/cdn-broker/src/handlers/broker.rs +++ b/cdn-broker/src/handlers/broker.rs @@ -30,6 +30,13 @@ impl Inner { ), is_outbound: bool, ) { + // Acquire a permit to authenticate with a broker. Removes the possibility for race + // conditions when doing so. + let Ok(auth_guard) = self.auth_lock.acquire().await else { + error!("needed semaphore has been closed"); + std::process::exit(-1); + }; + // Depending on which way the direction came in, we will want to authenticate with a different // flow. let broker_identifier = if is_outbound { @@ -55,6 +62,9 @@ impl Inner { self.connections .add_broker(broker_identifier.clone(), sender); + // Once we have added the broker, drop the authentication guard + drop(auth_guard); + // Send a full user sync if let Err(err) = self.full_user_sync(&broker_identifier) { error!("failed to perform full user sync: {err}"); diff --git a/cdn-broker/src/lib.rs b/cdn-broker/src/lib.rs index e62d260..791bf1c 100644 --- a/cdn-broker/src/lib.rs +++ b/cdn-broker/src/lib.rs @@ -35,7 +35,7 @@ use cdn_proto::{ use connections::Connections; use derive_builder::Builder; use local_ip_address::local_ip; -use tokio::{select, spawn}; +use tokio::{select, spawn, sync::Semaphore}; use tracing::info; /// The broker's configuration. We need this when we create a new one. @@ -90,6 +90,10 @@ struct Inner { /// against the stake table. keypair: KeyPair, + /// A lock on authentication so we don't thrash when authenticating with brokers. + /// Only lets us authenticate to one broker at a time. + auth_lock: Semaphore, + /// The connections that currently exist. We use this everywhere we need to update connection /// state or send messages. connections: Arc>, @@ -209,6 +213,7 @@ impl Broker { discovery_client, identity: identity.clone(), keypair, + auth_lock: Semaphore::const_new(1), connections: Arc::from(Connections::new(identity)), }), metrics_bind_address, diff --git a/cdn-broker/src/tasks/heartbeat.rs b/cdn-broker/src/tasks/heartbeat.rs index 19b2344..89875a0 100644 --- a/cdn-broker/src/tasks/heartbeat.rs +++ b/cdn-broker/src/tasks/heartbeat.rs @@ -2,7 +2,12 @@ use std::{collections::HashSet, sync::Arc, time::Duration}; -use cdn_proto::{connection::protocols::Protocol, def::RunDef, discovery::DiscoveryClient}; +use cdn_proto::{ + connection::protocols::Protocol, + def::RunDef, + discovery::{BrokerIdentifier, DiscoveryClient}, +}; +use rand::{rngs::StdRng, seq::SliceRandom, SeedableRng}; use tokio::{spawn, time::sleep}; use tracing::{error, warn}; @@ -29,10 +34,18 @@ impl Inner { // Check for new brokers, spawning tasks to connect to them if necessary match discovery_client.get_other_brokers().await { Ok(brokers) => { + // Calculate which brokers to connect to by taking the difference + let mut brokers_to_connect_to: Vec = brokers + .difference(&HashSet::from_iter(self.connections.all_brokers())) + .cloned() + .collect(); + + // Shuffle the list (so we don't get stuck in the authentication lock + // on a broker that is down) + brokers_to_connect_to.shuffle(&mut StdRng::from_entropy()); + // Calculate the difference, spawn tasks to connect to them - for broker in - brokers.difference(&HashSet::from_iter(self.connections.all_brokers())) - { + for broker in brokers_to_connect_to { // TODO: make this into a separate function // Extrapolate the address to connect to let to_connect_address = broker.private_advertise_address.clone();