sui_core/traffic_controller/
mod.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4pub mod metrics;
5pub mod nodefw_client;
6pub mod nodefw_test_server;
7pub mod policies;
8
9use dashmap::DashMap;
10use fs::File;
11use mysten_common::fatal;
12use prometheus::IntGauge;
13use std::fs;
14use std::net::{IpAddr, Ipv4Addr, SocketAddr};
15use std::ops::Add;
16use std::sync::Arc;
17use sui_types::error::{SuiError, SuiErrorKind};
18
19use self::metrics::TrafficControllerMetrics;
20use crate::traffic_controller::nodefw_client::{BlockAddress, BlockAddresses, NodeFWClient};
21use crate::traffic_controller::policies::{
22    Policy, PolicyResponse, TrafficControlPolicy, TrafficTally,
23};
24use mysten_metrics::spawn_monitored_task;
25use parking_lot::Mutex as ParkingLotMutex;
26use rand::Rng;
27use std::fmt::Debug;
28use std::time::{Duration, Instant, SystemTime};
29use sui_types::traffic_control::{
30    PolicyConfig, PolicyType, RemoteFirewallConfig, TrafficControlReconfigParams, Weight,
31};
32use tokio::sync::mpsc::error::TrySendError;
33use tokio::sync::{Mutex, RwLock, mpsc};
34use tracing::{debug, error, info, trace, warn};
35
/// Minimum number of seconds that must elapse between refreshes of the
/// "highest rate" metrics and debug logs emitted by `run_tally_loop`.
pub const METRICS_INTERVAL_SECS: u64 = 2;
/// Dead man's switch timeout (in seconds) that `run_tally_loop` falls back to
/// when no remote firewall config supplies `drain_timeout_secs`.
pub const DEFAULT_DRAIN_TIMEOUT_SECS: u64 = 300;
38
/// Shared concurrent map from a blocked IP address to the time at which the
/// block expires.
type Blocklist = Arc<DashMap<IpAddr, SystemTime>>;
40
/// The two blocklists consulted by `TrafficController::check`: one for the
/// `client` address and one for the `proxied_client` address.
#[derive(Clone)]
pub struct Blocklists {
    /// Blocked client IPs, each mapped to its block expiration time.
    clients: Blocklist,
    /// Blocked proxied-client IPs, each mapped to its block expiration time.
    proxied_clients: Blocklist,
}
46
/// Access-control mode of a `TrafficController`.
#[derive(Clone)]
pub enum Acl {
    /// Requests are allowed unless the client or proxied client currently has
    /// an unexpired entry in the corresponding blocklist.
    Blocklists(Blocklists),
    /// If this variant is set, then we do no tallying or running
    /// of background tasks, and instead simply block all IPs not
    /// in the allowlist on calls to `check`. The allowlist should
    /// only be populated once at initialization.
    Allowlist(Vec<IpAddr>),
}
56
/// Front door for traffic control: accepts tallies of observed traffic and
/// answers whether a given client should currently be allowed.
#[derive(Clone)]
pub struct TrafficController {
    /// Sender side of the tally channel. `None` until `spawn` opens it, and it
    /// stays `None` when the controller is initialized with an allowlist.
    tally_channel: Arc<ParkingLotMutex<Option<mpsc::Sender<TrafficTally>>>>,
    /// Either the blocklists maintained by the background tasks, or a fixed
    /// allowlist.
    acl: Acl,
    /// Prometheus metrics updated by the controller and its background tasks.
    metrics: Arc<TrafficControllerMetrics>,
    /// Policy applied to spam tallies; `None` in allowlist mode.
    spam_policy: Option<Arc<Mutex<TrafficControlPolicy>>>,
    /// Policy applied to error tallies; `None` in allowlist mode.
    error_policy: Option<Arc<Mutex<TrafficControlPolicy>>>,
    /// Current policy configuration; its `dry_run` flag is read on every
    /// `check` and may be changed through `admin_reconfigure`.
    policy_config: Arc<RwLock<PolicyConfig>>,
    /// Optional remote firewall configuration (drain file path, delegation
    /// settings) handed to the tally loop.
    fw_config: Option<RemoteFirewallConfig>,
}
67
68impl Debug for TrafficController {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        // NOTE: we do not want to print the contents of the blocklists to logs
71        // given that (1) it contains all requests IPs, and (2) it could be quite
72        // large. Instead, we print lengths of the blocklists. Further, we prefer
73        // to get length from the metrics rather than from the blocklists themselves
74        // to avoid unneccesarily aquiring the read lock.
75        f.debug_struct("TrafficController")
76            .field(
77                "connection_ip_blocklist_len",
78                &self.metrics.connection_ip_blocklist_len.get(),
79            )
80            .field(
81                "proxy_ip_blocklist_len",
82                &self.metrics.proxy_ip_blocklist_len.get(),
83            )
84            .finish()
85    }
86}
87
88impl TrafficController {
89    pub async fn init(
90        policy_config: PolicyConfig,
91        metrics: Arc<TrafficControllerMetrics>,
92        fw_config: Option<RemoteFirewallConfig>,
93    ) -> Self {
94        metrics.dry_run_enabled.set(policy_config.dry_run as i64);
95        match policy_config.allow_list.clone() {
96            Some(allow_list) => {
97                let allowlist = allow_list
98                    .into_iter()
99                    .map(|ip_str| {
100                        parse_ip(&ip_str).unwrap_or_else(|| {
101                            fatal!("Failed to parse allowlist IP address: {:?}", ip_str)
102                        })
103                    })
104                    .collect();
105                Self {
106                    tally_channel: Arc::new(ParkingLotMutex::new(None)),
107                    acl: Acl::Allowlist(allowlist),
108                    metrics,
109                    policy_config: Arc::new(RwLock::new(policy_config)),
110                    fw_config,
111                    spam_policy: None,
112                    error_policy: None,
113                }
114            }
115            None => {
116                let spam_policy = Arc::new(Mutex::new(
117                    TrafficControlPolicy::from_spam_config(policy_config.clone()).await,
118                ));
119                let error_policy = Arc::new(Mutex::new(
120                    TrafficControlPolicy::from_error_config(policy_config.clone()).await,
121                ));
122                let this = Self {
123                    tally_channel: Arc::new(ParkingLotMutex::new(None)),
124                    acl: Acl::Blocklists(Blocklists {
125                        clients: Arc::new(DashMap::new()),
126                        proxied_clients: Arc::new(DashMap::new()),
127                    }),
128                    metrics,
129                    policy_config: Arc::new(RwLock::new(policy_config)),
130                    fw_config,
131                    spam_policy: Some(spam_policy),
132                    error_policy: Some(error_policy),
133                };
134                this.spawn().await;
135                this
136            }
137        }
138    }
139
140    pub async fn init_for_test(
141        policy_config: PolicyConfig,
142        fw_config: Option<RemoteFirewallConfig>,
143    ) -> Self {
144        let metrics = Arc::new(TrafficControllerMetrics::new(&prometheus::Registry::new()));
145        Self::init(policy_config, metrics, fw_config).await
146    }
147
148    async fn spawn(&self) {
149        let policy_config = { self.policy_config.read().await.clone() };
150        Self::set_policy_config_metrics(&policy_config, self.metrics.clone());
151        let (tx, rx) = mpsc::channel(policy_config.channel_capacity);
152        // Memoized drainfile existence state. This is passed into delegation
153        // funtions to prevent them from continuing to populate blocklists
154        // if drain is set, as otherwise it will grow without bounds
155        // without the firewall running to periodically clear it.
156        let mem_drainfile_present = self
157            .fw_config
158            .as_ref()
159            .map(|config| config.drain_path.exists())
160            .unwrap_or(false);
161        self.metrics
162            .deadmans_switch_enabled
163            .set(mem_drainfile_present as i64);
164        let blocklists = match self.acl.clone() {
165            Acl::Blocklists(blocklists) => blocklists,
166            Acl::Allowlist(_) => fatal!("Allowlist ACL should not exist on spawn"),
167        };
168        let tally_loop_blocklists = blocklists.clone();
169        let clear_loop_blocklists = blocklists.clone();
170        let tally_loop_metrics = self.metrics.clone();
171        let clear_loop_metrics = self.metrics.clone();
172        let tally_loop_policy_config = policy_config.clone();
173        let tally_loop_fw_config = self.fw_config.clone();
174
175        let spam_policy = self
176            .spam_policy
177            .clone()
178            .expect("spam policy should exist on spawn");
179        let error_policy = self
180            .error_policy
181            .clone()
182            .expect("error policy should exist on spawn");
183        let spam_policy_clone = spam_policy.clone();
184        let error_policy_clone = error_policy.clone();
185
186        spawn_monitored_task!(run_tally_loop(
187            rx,
188            tally_loop_policy_config,
189            spam_policy_clone,
190            error_policy_clone,
191            tally_loop_fw_config,
192            tally_loop_blocklists,
193            tally_loop_metrics,
194            mem_drainfile_present,
195        ));
196        spawn_monitored_task!(run_clear_blocklists_loop(
197            clear_loop_blocklists,
198            clear_loop_metrics,
199        ));
200        self.open_tally_channel(tx);
201    }
202
203    pub async fn get_current_state(&self) -> TrafficControlReconfigParams {
204        let mut result = TrafficControlReconfigParams {
205            error_threshold: None,
206            spam_threshold: None,
207            dry_run: None,
208        };
209
210        if let Some(error_policy) = self.error_policy.as_ref()
211            && let TrafficControlPolicy::FreqThreshold(ref policy) = *error_policy.lock().await
212        {
213            result.error_threshold = Some(policy.client_threshold);
214        }
215
216        if let Some(spam_policy) = self.spam_policy.as_ref()
217            && let TrafficControlPolicy::FreqThreshold(ref policy) = *spam_policy.lock().await
218        {
219            result.spam_threshold = Some(policy.client_threshold);
220        }
221
222        result.dry_run = Some(self.policy_config.read().await.dry_run);
223        result
224    }
225
226    pub async fn admin_reconfigure(
227        &self,
228        params: TrafficControlReconfigParams,
229    ) -> Result<TrafficControlReconfigParams, SuiError> {
230        let TrafficControlReconfigParams {
231            error_threshold,
232            spam_threshold,
233            dry_run,
234        } = params;
235        if let Some(error_threshold) = error_threshold {
236            self.metrics
237                .error_client_threshold
238                .set(error_threshold as i64);
239            Self::update_policy_threshold(
240                self.error_policy.as_ref().unwrap(),
241                error_threshold,
242                dry_run,
243            )
244            .await?;
245        }
246        if let Some(spam_threshold) = spam_threshold {
247            self.metrics
248                .spam_client_threshold
249                .set(spam_threshold as i64);
250            Self::update_policy_threshold(
251                self.spam_policy.as_ref().unwrap(),
252                spam_threshold,
253                dry_run,
254            )
255            .await?;
256        }
257        if let Some(dry_run) = dry_run {
258            self.metrics.dry_run_enabled.set(dry_run as i64);
259            self.policy_config.write().await.dry_run = dry_run;
260        }
261
262        Ok(self.get_current_state().await)
263    }
264
265    async fn update_policy_threshold(
266        policy: &Arc<Mutex<TrafficControlPolicy>>,
267        threshold: u64,
268        dry_run: Option<bool>,
269    ) -> Result<(), SuiError> {
270        match *policy.lock().await {
271            TrafficControlPolicy::FreqThreshold(ref mut policy) => {
272                policy.client_threshold = threshold;
273                if let Some(dry_run) = dry_run {
274                    policy.config.dry_run = dry_run;
275                }
276                Ok(())
277            }
278            TrafficControlPolicy::TestNConnIP(ref mut policy) => {
279                policy.threshold = threshold;
280                if let Some(dry_run) = dry_run {
281                    policy.config.dry_run = dry_run;
282                }
283                Ok(())
284            }
285            _ => Err(SuiErrorKind::InvalidAdminRequest(
286                "Unsupported prior policy type during traffic control reconfiguration".to_string(),
287            )
288            .into()),
289        }
290    }
291
292    fn open_tally_channel(&self, tx: mpsc::Sender<TrafficTally>) {
293        self.tally_channel.lock().replace(tx);
294    }
295
296    fn set_policy_config_metrics(
297        policy_config: &PolicyConfig,
298        metrics: Arc<TrafficControllerMetrics>,
299    ) {
300        if let PolicyType::FreqThreshold(config) = &policy_config.spam_policy_type {
301            metrics
302                .spam_client_threshold
303                .set(config.client_threshold as i64);
304            metrics
305                .spam_proxied_client_threshold
306                .set(config.proxied_client_threshold as i64);
307        }
308        if let PolicyType::FreqThreshold(config) = &policy_config.error_policy_type {
309            metrics
310                .error_client_threshold
311                .set(config.client_threshold as i64);
312            metrics
313                .error_proxied_client_threshold
314                .set(config.proxied_client_threshold as i64);
315        }
316    }
317
318    pub fn tally(&self, tally: TrafficTally) {
319        if let Some(channel) = self.tally_channel.lock().as_ref() {
320            // Use try_send rather than send mainly to avoid creating backpressure
321            // on the caller if the channel is full, which may slow down the critical
322            // path. Dropping the tally on the floor should be ok, as in this case
323            // we are effectively sampling traffic, which we would need to do anyway
324            // if we are overloaded
325            match channel.try_send(tally) {
326                Err(TrySendError::Full(_)) => {
327                    warn!("TrafficController tally channel full, dropping tally");
328                    self.metrics.tally_channel_overflow.inc();
329                    // TODO: once we've verified this doesn't happen under normal
330                    // conditions, we can consider dropping the request itself given
331                    // that clearly the system is overloaded
332                }
333                Err(TrySendError::Closed(_)) => {
334                    warn!("TrafficController tally channel closed unexpectedly");
335                }
336                Ok(_) => {}
337            }
338        } else {
339            warn!("TrafficController not yet accepting tally requests.");
340        }
341    }
342
343    /// Handle check with dry-run mode considered
344    pub async fn check(&self, client: &Option<IpAddr>, proxied_client: &Option<IpAddr>) -> bool {
345        let policy_config = { self.policy_config.read().await.clone() };
346
347        let allowed = match &self.acl {
348            Acl::Allowlist(allowlist) => client.is_none() || allowlist.contains(&client.unwrap()),
349            Acl::Blocklists(blocklists) => {
350                self.check_blocklists(blocklists, client, proxied_client)
351                    .await
352            }
353        };
354
355        match (allowed, policy_config.dry_run) {
356            // request allowed
357            (true, _) => true,
358            // request blocked while in dry-run mode
359            (false, true) => {
360                debug!("Dry run mode: Blocked request from client {:?}", client);
361                self.record_blocked_request(client, true);
362                true
363            }
364            // request blocked
365            (false, false) => {
366                debug!("Blocked request from client {:?}", client);
367                self.record_blocked_request(client, false);
368                false
369            }
370        }
371    }
372
373    fn record_blocked_request(&self, client: &Option<IpAddr>, dry_run: bool) {
374        let dry_run_str = if dry_run { "true" } else { "false" };
375
376        // Hash IP to bucket to limit cardinality (100 buckets: bucket_0 through bucket_99)
377        let ip_label = if let Some(ip) = client {
378            let bucket = ip
379                .to_string()
380                .bytes()
381                .fold(0u8, |acc, b| acc.wrapping_add(b))
382                % 100;
383            let bucket_label = format!("bucket_{}", bucket);
384            trace!("IP {} maps to {}", ip, bucket_label);
385            bucket_label
386        } else {
387            "unknown".to_string()
388        };
389
390        self.metrics
391            .requests_blocked_at_protocol
392            .with_label_values(&[dry_run_str, &ip_label])
393            .inc();
394    }
395
396    /// Returns true if the connection is in blocklist, false otherwise
397    async fn check_blocklists(
398        &self,
399        blocklists: &Blocklists,
400        client: &Option<IpAddr>,
401        proxied_client: &Option<IpAddr>,
402    ) -> bool {
403        let client_check = self.check_and_clear_blocklist(
404            client,
405            blocklists.clients.clone(),
406            &self.metrics.connection_ip_blocklist_len,
407        );
408        let proxied_client_check = self.check_and_clear_blocklist(
409            proxied_client,
410            blocklists.proxied_clients.clone(),
411            &self.metrics.proxy_ip_blocklist_len,
412        );
413        let (client_check, proxied_client_check) =
414            futures::future::join(client_check, proxied_client_check).await;
415        client_check && proxied_client_check
416    }
417
418    async fn check_and_clear_blocklist(
419        &self,
420        client: &Option<IpAddr>,
421        blocklist: Blocklist,
422        blocklist_len_gauge: &IntGauge,
423    ) -> bool {
424        let client = match client {
425            Some(client) => client,
426            None => return true,
427        };
428        let now = SystemTime::now();
429        // the below two blocks cannot be nested, otherwise we will deadlock
430        // due to aquiring the lock on get, then holding across the remove
431        let (should_block, should_remove) = {
432            match blocklist.get(client) {
433                Some(expiration) if now >= *expiration => (false, true),
434                None => (false, false),
435                _ => (true, false),
436            }
437        };
438        if should_remove && blocklist.remove(client).is_some() {
439            blocklist_len_gauge.dec();
440        }
441        !should_block
442    }
443}
444
445/// Although we clear IPs from the blocklist lazily when they are checked,
446/// it's possible that over time we may accumulate a large number of stale
447/// IPs in the blocklist for clients that are added, then once blocked,
448/// never checked again. This function runs periodically to clear out any
449/// such stale IPs. This also ensures that the blocklist length metric
450/// accurately reflects TTL.
451async fn run_clear_blocklists_loop(blocklists: Blocklists, metrics: Arc<TrafficControllerMetrics>) {
452    loop {
453        tokio::time::sleep(Duration::from_secs(3)).await;
454        let now = SystemTime::now();
455        blocklists.clients.retain(|_, expiration| now < *expiration);
456        blocklists
457            .proxied_clients
458            .retain(|_, expiration| now < *expiration);
459        metrics
460            .connection_ip_blocklist_len
461            .set(blocklists.clients.len() as i64);
462        metrics
463            .proxy_ip_blocklist_len
464            .set(blocklists.proxied_clients.len() as i64);
465    }
466}
467
468async fn run_tally_loop(
469    mut receiver: mpsc::Receiver<TrafficTally>,
470    policy_config: PolicyConfig,
471    spam_policy: Arc<Mutex<TrafficControlPolicy>>,
472    error_policy: Arc<Mutex<TrafficControlPolicy>>,
473    fw_config: Option<RemoteFirewallConfig>,
474    blocklists: Blocklists,
475    metrics: Arc<TrafficControllerMetrics>,
476    mut mem_drainfile_present: bool,
477) {
478    let spam_blocklists = Arc::new(blocklists.clone());
479    let error_blocklists = Arc::new(blocklists);
480    let node_fw_client = fw_config
481        .as_ref()
482        .map(|fw_config| NodeFWClient::new(fw_config.remote_fw_url.clone()));
483
484    let timeout = fw_config
485        .as_ref()
486        .map(|fw_config| fw_config.drain_timeout_secs)
487        .unwrap_or(DEFAULT_DRAIN_TIMEOUT_SECS);
488    let mut metric_timer = Instant::now();
489
490    loop {
491        tokio::select! {
492            received = receiver.recv() => {
493                match received {
494                    Some(tally) => {
495                        // Track tallies by method
496                        let method = tally.method.as_deref().unwrap_or("unknown");
497                        metrics.tallies
498                            .with_label_values(&[method])
499                            .inc();
500
501                        // TODO: spawn a task to handle tallying concurrently
502                        if let Err(err) = handle_spam_tally(
503                            spam_policy.clone(),
504                            &policy_config,
505                            &node_fw_client,
506                            &fw_config,
507                            tally.clone(),
508                            spam_blocklists.clone(),
509                            metrics.clone(),
510                            mem_drainfile_present,
511                        )
512                        .await {
513                            warn!("Error handling spam tally: {}", err);
514                        }
515                        if let Err(err) = handle_error_tally(
516                            error_policy.clone(),
517                            &policy_config,
518                            &node_fw_client,
519                            &fw_config,
520                            tally,
521                            error_blocklists.clone(),
522                            metrics.clone(),
523                            mem_drainfile_present,
524                        )
525                        .await {
526                            warn!("Error handling error tally: {}", err);
527                        }
528                    }
529                    None => {
530                        info!("TrafficController tally channel closed by all senders");
531                        return;
532                    }
533                }
534            }
535            // Dead man's switch - if we suspect something is sinking all traffic to node, disable nodefw
536            _ = tokio::time::sleep(tokio::time::Duration::from_secs(timeout)) => {
537                if let Some(fw_config) = &fw_config {
538                    error!("No traffic tallies received in {} seconds.", timeout);
539                    if mem_drainfile_present {
540                        continue;
541                    }
542                    if !fw_config.drain_path.exists() {
543                        mem_drainfile_present = true;
544                        warn!("Draining Node firewall.");
545                        File::create(&fw_config.drain_path)
546                            .expect("Failed to touch nodefw drain file");
547                        metrics.deadmans_switch_enabled.set(1);
548                    }
549                }
550            }
551        }
552
553        // every N seconds, we update metrics and logging that would be too
554        // spammy to be handled while processing each tally
555        if metric_timer.elapsed() > Duration::from_secs(METRICS_INTERVAL_SECS) {
556            if let TrafficControlPolicy::FreqThreshold(ref spam_policy) = *spam_policy.lock().await
557            {
558                if let Some(highest_direct_rate) = spam_policy.highest_direct_rate() {
559                    metrics
560                        .highest_direct_spam_rate
561                        .set(highest_direct_rate.0 as i64);
562                    debug!("Recent highest direct spam rate: {:?}", highest_direct_rate);
563                }
564                if let Some(highest_proxied_rate) = spam_policy.highest_proxied_rate() {
565                    metrics
566                        .highest_proxied_spam_rate
567                        .set(highest_proxied_rate.0 as i64);
568                    debug!(
569                        "Recent highest proxied spam rate: {:?}",
570                        highest_proxied_rate
571                    );
572                }
573            }
574            if let TrafficControlPolicy::FreqThreshold(ref error_policy) =
575                *error_policy.lock().await
576            {
577                if let Some(highest_direct_rate) = error_policy.highest_direct_rate() {
578                    metrics
579                        .highest_direct_error_rate
580                        .set(highest_direct_rate.0 as i64);
581                    debug!(
582                        "Recent highest direct error rate: {:?}",
583                        highest_direct_rate
584                    );
585                }
586                if let Some(highest_proxied_rate) = error_policy.highest_proxied_rate() {
587                    metrics
588                        .highest_proxied_error_rate
589                        .set(highest_proxied_rate.0 as i64);
590                    debug!(
591                        "Recent highest proxied error rate: {:?}",
592                        highest_proxied_rate
593                    );
594                }
595            }
596            metric_timer = Instant::now();
597        }
598    }
599}
600
601async fn handle_error_tally(
602    policy: Arc<Mutex<TrafficControlPolicy>>,
603    policy_config: &PolicyConfig,
604    nodefw_client: &Option<NodeFWClient>,
605    fw_config: &Option<RemoteFirewallConfig>,
606    tally: TrafficTally,
607    blocklists: Arc<Blocklists>,
608    metrics: Arc<TrafficControllerMetrics>,
609    mem_drainfile_present: bool,
610) -> Result<(), reqwest::Error> {
611    let Some((error_weight, error_type)) = tally.clone().error_info else {
612        return Ok(());
613    };
614    if !error_weight.is_sampled() {
615        return Ok(());
616    }
617    trace!(
618        "Handling error_type {:?} from client {:?}",
619        error_type, tally.direct,
620    );
621    metrics
622        .tally_error_types
623        .with_label_values(&[error_type.as_str()])
624        .inc();
625    let resp = policy.lock().await.handle_tally(tally);
626    metrics.error_tally_handled.inc();
627    if let Some(fw_config) = fw_config
628        && fw_config.delegate_error_blocking
629        && !mem_drainfile_present
630    {
631        let client = nodefw_client
632            .as_ref()
633            .expect("Expected NodeFWClient for blocklist delegation");
634        return delegate_policy_response(
635            resp,
636            policy_config,
637            client,
638            fw_config.destination_port,
639            metrics.clone(),
640        )
641        .await;
642    }
643    handle_policy_response(resp, policy_config, blocklists, metrics).await;
644    Ok(())
645}
646
647async fn handle_spam_tally(
648    policy: Arc<Mutex<TrafficControlPolicy>>,
649    policy_config: &PolicyConfig,
650    nodefw_client: &Option<NodeFWClient>,
651    fw_config: &Option<RemoteFirewallConfig>,
652    tally: TrafficTally,
653    blocklists: Arc<Blocklists>,
654    metrics: Arc<TrafficControllerMetrics>,
655    mem_drainfile_present: bool,
656) -> Result<(), reqwest::Error> {
657    if !(tally.spam_weight.is_sampled() && policy_config.spam_sample_rate.is_sampled()) {
658        return Ok(());
659    }
660    let resp = policy.lock().await.handle_tally(tally.clone());
661    metrics.tally_handled.inc();
662    if let Some(fw_config) = fw_config
663        && fw_config.delegate_spam_blocking
664        && !mem_drainfile_present
665    {
666        let client = nodefw_client
667            .as_ref()
668            .expect("Expected NodeFWClient for blocklist delegation");
669        return delegate_policy_response(
670            resp,
671            policy_config,
672            client,
673            fw_config.destination_port,
674            metrics.clone(),
675        )
676        .await;
677    }
678    handle_policy_response(resp, policy_config, blocklists, metrics).await;
679    Ok(())
680}
681
682async fn handle_policy_response(
683    response: PolicyResponse,
684    policy_config: &PolicyConfig,
685    blocklists: Arc<Blocklists>,
686    metrics: Arc<TrafficControllerMetrics>,
687) {
688    let PolicyResponse {
689        block_client,
690        block_proxied_client,
691    } = response;
692    let PolicyConfig {
693        connection_blocklist_ttl_sec,
694        proxy_blocklist_ttl_sec,
695        ..
696    } = policy_config;
697    if let Some(client) = block_client
698        && blocklists
699            .clients
700            .insert(
701                client,
702                SystemTime::now() + Duration::from_secs(*connection_blocklist_ttl_sec),
703            )
704            .is_none()
705    {
706        // Only increment the metric if the client was not already blocked
707        debug!("Adding client {:?} to blocklist", client);
708        metrics.connection_ip_blocklist_len.inc();
709    }
710    if let Some(client) = block_proxied_client
711        && blocklists
712            .proxied_clients
713            .insert(
714                client,
715                SystemTime::now() + Duration::from_secs(*proxy_blocklist_ttl_sec),
716            )
717            .is_none()
718    {
719        // Only increment the metric if the client was not already blocked
720        debug!("Adding proxied client {:?} to blocklist", client);
721        metrics.proxy_ip_blocklist_len.inc();
722    }
723}
724
725async fn delegate_policy_response(
726    response: PolicyResponse,
727    policy_config: &PolicyConfig,
728    node_fw_client: &NodeFWClient,
729    destination_port: u16,
730    metrics: Arc<TrafficControllerMetrics>,
731) -> Result<(), reqwest::Error> {
732    let PolicyResponse {
733        block_client,
734        block_proxied_client,
735    } = response;
736    let PolicyConfig {
737        connection_blocklist_ttl_sec,
738        proxy_blocklist_ttl_sec,
739        ..
740    } = policy_config;
741    let mut addresses = vec![];
742    if let Some(client_id) = block_client {
743        debug!("Delegating client blocking to firewall");
744        addresses.push(BlockAddress {
745            source_address: client_id.to_string(),
746            destination_port,
747            ttl: *connection_blocklist_ttl_sec,
748        });
749    }
750    if let Some(ip) = block_proxied_client {
751        debug!("Delegating proxied client blocking to firewall");
752        addresses.push(BlockAddress {
753            source_address: ip.to_string(),
754            destination_port,
755            ttl: *proxy_blocklist_ttl_sec,
756        });
757    }
758    if addresses.is_empty() {
759        Ok(())
760    } else {
761        metrics
762            .blocks_delegated_to_firewall
763            .inc_by(addresses.len() as u64);
764        match node_fw_client
765            .block_addresses(BlockAddresses { addresses })
766            .await
767        {
768            Ok(()) => Ok(()),
769            Err(err) => {
770                metrics.firewall_delegation_request_fail.inc();
771                Err(err)
772            }
773        }
774    }
775}
776
/// Aggregated results of a traffic simulation run.
///
/// The default value has all counters at zero, a zero `total_time_blocked`,
/// and no recorded time-to-first-block.
#[derive(Debug, Clone, Default)]
pub struct TrafficSimMetrics {
    /// Number of requests issued.
    pub num_requests: u64,
    /// Number of requests that were blocked.
    pub num_blocked: u64,
    /// Time to first block; summed when two metrics are added.
    pub time_to_first_block: Option<Duration>,
    /// Time to first block; the minimum is kept when two metrics are added.
    pub abs_time_to_first_block: Option<Duration>,
    /// Total time spent blocked.
    pub total_time_blocked: Duration,
    /// Number of additions to the blocklist.
    pub num_blocklist_adds: u64,
}

impl Add for TrafficSimMetrics {
    type Output = Self;

    /// Combines two sets of metrics: counters and `total_time_blocked` are
    /// summed, `time_to_first_block` is summed when both sides have a value,
    /// and `abs_time_to_first_block` takes the smaller of the two values.
    /// When only one side has a value for either optional field, that value
    /// is kept.
    fn add(self, other: Self) -> Self {
        let time_to_first_block = match (self.time_to_first_block, other.time_to_first_block) {
            (Some(a), Some(b)) => Some(a + b),
            (a, b) => a.or(b),
        };
        let abs_time_to_first_block = match (
            self.abs_time_to_first_block,
            other.abs_time_to_first_block,
        ) {
            (Some(a), Some(b)) => Some(a.min(b)),
            (a, b) => a.or(b),
        };
        Self {
            num_requests: self.num_requests + other.num_requests,
            num_blocked: self.num_blocked + other.num_blocked,
            time_to_first_block,
            abs_time_to_first_block,
            total_time_blocked: self.total_time_blocked + other.total_time_blocked,
            num_blocklist_adds: self.num_blocklist_adds + other.num_blocklist_adds,
        }
    }
}
827
/// Harness type for the naive traffic simulation; the entry point is the
/// associated function `TrafficSim::run`.
///
/// NOTE(review): the simulation functions visible in this file are associated
/// functions (no `self`) and `run` initializes its own controller, so nothing
/// here reads `traffic_controller` — confirm whether other callers rely on it.
pub struct TrafficSim {
    /// Traffic controller associated with this simulation instance.
    pub traffic_controller: TrafficController,
}
831
832impl TrafficSim {
833    pub async fn run(
834        policy: PolicyConfig,
835        num_clients: u8,
836        per_client_tps: usize,
837        duration: Duration,
838        report: bool,
839    ) -> TrafficSimMetrics {
840        assert!(
841            per_client_tps <= 10_000,
842            "per_client_tps must be less than 10,000. For higher values, increase num_clients"
843        );
844        assert!(num_clients < 20, "num_clients must be greater than 0");
845        assert!(num_clients > 0);
846        assert!(per_client_tps > 0);
847        assert!(duration.as_secs() > 0);
848
849        let controller = TrafficController::init_for_test(policy.clone(), None).await;
850        let tasks = (0..num_clients).map(|task_num| {
851            tokio::spawn(Self::run_single_client(
852                controller.clone(),
853                duration,
854                task_num,
855                per_client_tps,
856            ))
857        });
858
859        let status_task = if report {
860            Some(tokio::spawn(async move {
861                println!(
862                    "Running naive traffic simulation for {} seconds",
863                    duration.as_secs()
864                );
865                println!("Policy: {:#?}", policy);
866                println!("Num clients: {}", num_clients);
867                println!("TPS per client: {}", per_client_tps);
868                println!(
869                    "Target total TPS: {}",
870                    per_client_tps * num_clients as usize
871                );
872                println!("\n");
873                for _ in 0..duration.as_secs() {
874                    print!(".");
875                    tokio::time::sleep(Duration::from_secs(1)).await;
876                }
877                println!();
878            }))
879        } else {
880            None
881        };
882
883        let metrics = futures::future::join_all(tasks).await.into_iter().fold(
884            TrafficSimMetrics::default(),
885            |acc, run_client_ret| {
886                if let Ok(metrics) = run_client_ret {
887                    acc + metrics
888                } else {
889                    error!(
890                        "Error running traffic sim client: {:?}",
891                        run_client_ret.err()
892                    );
893                    acc
894                }
895            },
896        );
897
898        if report {
899            status_task.unwrap().await.unwrap();
900            Self::report_metrics(metrics.clone(), duration, per_client_tps, num_clients);
901        }
902        metrics
903    }
904
905    async fn run_single_client(
906        controller: TrafficController,
907        duration: Duration,
908        task_num: u8,
909        per_client_tps: usize,
910    ) -> TrafficSimMetrics {
911        // Do an initial sleep for a random amount of time to smooth
912        // out the traffic. This shouldn't be strictly necessary and
913        // we can remove if we want more determinism
914        let sleep_time = Duration::from_micros(rand::thread_rng().gen_range(0..100));
915        tokio::time::sleep(sleep_time).await;
916
917        // collectors
918        let mut num_requests = 0;
919        let mut num_blocked = 0;
920        let mut time_to_first_block = None;
921        let mut total_time_blocked = Duration::from_micros(0);
922        let mut num_blocklist_adds = 0;
923        // state variables
924        let mut currently_blocked = false;
925        let mut time_blocked_start = Instant::now();
926        let start = Instant::now();
927
928        while start.elapsed() < duration {
929            let client = Some(IpAddr::V4(Ipv4Addr::new(127, 0, 0, task_num)));
930            let allowed = controller.check(&client, &None).await;
931            if allowed {
932                if currently_blocked {
933                    total_time_blocked += time_blocked_start.elapsed();
934                    currently_blocked = false;
935                }
936                controller.tally(TrafficTally::new(
937                    client,
938                    // TODO add proxy IP for testing
939                    None,
940                    // TODO add weight adjustments
941                    None,
942                    Weight::one(),
943                ));
944            } else {
945                if !currently_blocked {
946                    time_blocked_start = Instant::now();
947                    currently_blocked = true;
948                    num_blocklist_adds += 1;
949                    if time_to_first_block.is_none() {
950                        time_to_first_block = Some(start.elapsed());
951                    }
952                }
953                num_blocked += 1;
954            }
955            num_requests += 1;
956            tokio::time::sleep(Duration::from_micros(1_000_000 / per_client_tps as u64)).await;
957        }
958        TrafficSimMetrics {
959            num_requests,
960            num_blocked,
961            time_to_first_block,
962            abs_time_to_first_block: time_to_first_block,
963            total_time_blocked,
964            num_blocklist_adds,
965        }
966    }
967
968    fn report_metrics(
969        metrics: TrafficSimMetrics,
970        duration: Duration,
971        per_client_tps: usize,
972        num_clients: u8,
973    ) {
974        println!("TrafficSim metrics:");
975        println!("-------------------");
976        // The below two should be near equal
977        println!(
978            "Num expected requests: {}",
979            per_client_tps * (num_clients as usize) * duration.as_secs() as usize
980        );
981        println!("Num actual requests: {}", metrics.num_requests);
982        // This reflects the number of requests that were blocked, but note that once a client
983        // is added to the blocklist, all subsequent requests from that client are blocked
984        // until ttl is expired.
985        println!("Num blocked requests: {}", metrics.num_blocked);
986        // This metric on the other hand reflects the number of times a client was added to the blocklist
987        // and thus can be compared with the expectation based on the policy block threshold and ttl
988        println!(
989            "Num times added to blocklist: {}",
990            metrics.num_blocklist_adds
991        );
992        // This averages the duration for the first request to be blocked across all clients,
993        // which is useful for understanding if the policy is rate limiting based on expectation
994        let avg_first_block_time = metrics
995            .time_to_first_block
996            .map(|ttf| ttf / num_clients as u32);
997        println!("Average time to first block: {:?}", avg_first_block_time);
998        // This is the time it took for the first request to be blocked across all clients,
999        // and is instead more useful for understanding false positives in terms of rate and magnitude.
1000        println!(
1001            "Abolute time to first block (across all clients): {:?}",
1002            metrics.abs_time_to_first_block
1003        );
1004        // Useful for ensuring that TTL is respected
1005        let avg_time_blocked = if metrics.num_blocklist_adds > 0 {
1006            metrics.total_time_blocked.as_millis() as u64 / metrics.num_blocklist_adds
1007        } else {
1008            0
1009        };
1010        println!(
1011            "Average time blocked (ttl): {:?}",
1012            Duration::from_millis(avg_time_blocked)
1013        );
1014    }
1015}
1016
1017pub fn parse_ip(ip: &str) -> Option<IpAddr> {
1018    ip.parse::<IpAddr>().ok().or_else(|| {
1019        ip.parse::<SocketAddr>()
1020            .ok()
1021            .map(|socket_addr| socket_addr.ip())
1022            .or_else(|| {
1023                error!("Failed to parse value of {:?} to ip address or socket.", ip,);
1024                None
1025            })
1026    })
1027}