sui_core/traffic_controller/
mod.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4pub mod metrics;
5pub mod nodefw_client;
6pub mod nodefw_test_server;
7pub mod policies;
8
9use dashmap::DashMap;
10use fs::File;
11use mysten_common::fatal;
12use prometheus::IntGauge;
13use std::fs;
14use std::net::{IpAddr, Ipv4Addr, SocketAddr};
15use std::ops::Add;
16use std::sync::Arc;
17use sui_types::error::{SuiError, SuiErrorKind};
18
19use self::metrics::TrafficControllerMetrics;
20use crate::traffic_controller::nodefw_client::{BlockAddress, BlockAddresses, NodeFWClient};
21use crate::traffic_controller::policies::{
22    Policy, PolicyResponse, TrafficControlPolicy, TrafficTally,
23};
24use mysten_metrics::spawn_monitored_task;
25use parking_lot::Mutex as ParkingLotMutex;
26use rand::Rng;
27use std::fmt::Debug;
28use std::time::{Duration, Instant, SystemTime};
29use sui_types::traffic_control::{
30    PolicyConfig, PolicyType, RemoteFirewallConfig, TrafficControlReconfigParams, Weight,
31};
32use tokio::sync::mpsc::error::TrySendError;
33use tokio::sync::{Mutex, RwLock, mpsc};
34use tracing::{debug, error, info, trace, warn};
35
/// Minimum number of seconds between publications of the highest observed
/// spam/error rates by the tally loop.
pub const METRICS_INTERVAL_SECS: u64 = 2;
/// Dead man's switch timeout (seconds without any tally) used when no remote
/// firewall config supplies its own `drain_timeout_secs`.
pub const DEFAULT_DRAIN_TIMEOUT_SECS: u64 = 5 * 60;
38
39type Blocklist = Arc<DashMap<IpAddr, SystemTime>>;
40
41#[derive(Clone)]
42pub struct Blocklists {
43    clients: Blocklist,
44    proxied_clients: Blocklist,
45}
46
47#[derive(Clone)]
48pub enum Acl {
49    Blocklists(Blocklists),
50    /// If this variant is set, then we do no tallying or running
51    /// of background tasks, and instead simply block all IPs not
52    /// in the allowlist on calls to `check`. The allowlist should
53    /// only be populated once at initialization.
54    Allowlist(Vec<IpAddr>),
55}
56
57#[derive(Clone)]
58pub struct TrafficController {
59    tally_channel: Arc<ParkingLotMutex<Option<mpsc::Sender<TrafficTally>>>>,
60    acl: Acl,
61    metrics: Arc<TrafficControllerMetrics>,
62    spam_policy: Option<Arc<Mutex<TrafficControlPolicy>>>,
63    error_policy: Option<Arc<Mutex<TrafficControlPolicy>>>,
64    policy_config: Arc<RwLock<PolicyConfig>>,
65    fw_config: Option<RemoteFirewallConfig>,
66}
67
68impl Debug for TrafficController {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        // NOTE: we do not want to print the contents of the blocklists to logs
71        // given that (1) it contains all requests IPs, and (2) it could be quite
72        // large. Instead, we print lengths of the blocklists. Further, we prefer
73        // to get length from the metrics rather than from the blocklists themselves
74        // to avoid unneccesarily aquiring the read lock.
75        f.debug_struct("TrafficController")
76            .field(
77                "connection_ip_blocklist_len",
78                &self.metrics.connection_ip_blocklist_len.get(),
79            )
80            .field(
81                "proxy_ip_blocklist_len",
82                &self.metrics.proxy_ip_blocklist_len.get(),
83            )
84            .finish()
85    }
86}
87
88impl TrafficController {
89    pub async fn init(
90        policy_config: PolicyConfig,
91        metrics: Arc<TrafficControllerMetrics>,
92        fw_config: Option<RemoteFirewallConfig>,
93    ) -> Self {
94        metrics.dry_run_enabled.set(policy_config.dry_run as i64);
95        match policy_config.allow_list.clone() {
96            Some(allow_list) => {
97                let allowlist = allow_list
98                    .into_iter()
99                    .map(|ip_str| {
100                        parse_ip(&ip_str).unwrap_or_else(|| {
101                            fatal!("Failed to parse allowlist IP address: {:?}", ip_str)
102                        })
103                    })
104                    .collect();
105                Self {
106                    tally_channel: Arc::new(ParkingLotMutex::new(None)),
107                    acl: Acl::Allowlist(allowlist),
108                    metrics,
109                    policy_config: Arc::new(RwLock::new(policy_config)),
110                    fw_config,
111                    spam_policy: None,
112                    error_policy: None,
113                }
114            }
115            None => {
116                let spam_policy = Arc::new(Mutex::new(
117                    TrafficControlPolicy::from_spam_config(policy_config.clone()).await,
118                ));
119                let error_policy = Arc::new(Mutex::new(
120                    TrafficControlPolicy::from_error_config(policy_config.clone()).await,
121                ));
122                let this = Self {
123                    tally_channel: Arc::new(ParkingLotMutex::new(None)),
124                    acl: Acl::Blocklists(Blocklists {
125                        clients: Arc::new(DashMap::new()),
126                        proxied_clients: Arc::new(DashMap::new()),
127                    }),
128                    metrics,
129                    policy_config: Arc::new(RwLock::new(policy_config)),
130                    fw_config,
131                    spam_policy: Some(spam_policy),
132                    error_policy: Some(error_policy),
133                };
134                this.spawn().await;
135                this
136            }
137        }
138    }
139
140    pub async fn init_for_test(
141        policy_config: PolicyConfig,
142        fw_config: Option<RemoteFirewallConfig>,
143    ) -> Self {
144        let metrics = Arc::new(TrafficControllerMetrics::new(&prometheus::Registry::new()));
145        Self::init(policy_config, metrics, fw_config).await
146    }
147
148    async fn spawn(&self) {
149        let policy_config = { self.policy_config.read().await.clone() };
150        Self::set_policy_config_metrics(&policy_config, self.metrics.clone());
151        let (tx, rx) = mpsc::channel(policy_config.channel_capacity);
152        // Memoized drainfile existence state. This is passed into delegation
153        // funtions to prevent them from continuing to populate blocklists
154        // if drain is set, as otherwise it will grow without bounds
155        // without the firewall running to periodically clear it.
156        let mem_drainfile_present = self
157            .fw_config
158            .as_ref()
159            .map(|config| config.drain_path.exists())
160            .unwrap_or(false);
161        self.metrics
162            .deadmans_switch_enabled
163            .set(mem_drainfile_present as i64);
164        let blocklists = match self.acl.clone() {
165            Acl::Blocklists(blocklists) => blocklists,
166            Acl::Allowlist(_) => fatal!("Allowlist ACL should not exist on spawn"),
167        };
168        let tally_loop_blocklists = blocklists.clone();
169        let clear_loop_blocklists = blocklists.clone();
170        let tally_loop_metrics = self.metrics.clone();
171        let clear_loop_metrics = self.metrics.clone();
172        let tally_loop_policy_config = policy_config.clone();
173        let tally_loop_fw_config = self.fw_config.clone();
174
175        let spam_policy = self
176            .spam_policy
177            .clone()
178            .expect("spam policy should exist on spawn");
179        let error_policy = self
180            .error_policy
181            .clone()
182            .expect("error policy should exist on spawn");
183        let spam_policy_clone = spam_policy.clone();
184        let error_policy_clone = error_policy.clone();
185
186        spawn_monitored_task!(run_tally_loop(
187            rx,
188            tally_loop_policy_config,
189            spam_policy_clone,
190            error_policy_clone,
191            tally_loop_fw_config,
192            tally_loop_blocklists,
193            tally_loop_metrics,
194            mem_drainfile_present,
195        ));
196        spawn_monitored_task!(run_clear_blocklists_loop(
197            clear_loop_blocklists,
198            clear_loop_metrics,
199        ));
200        self.open_tally_channel(tx);
201    }
202
203    pub async fn get_current_state(&self) -> TrafficControlReconfigParams {
204        let mut result = TrafficControlReconfigParams {
205            error_threshold: None,
206            spam_threshold: None,
207            dry_run: None,
208        };
209
210        if let Some(error_policy) = self.error_policy.as_ref()
211            && let TrafficControlPolicy::FreqThreshold(ref policy) = *error_policy.lock().await
212        {
213            result.error_threshold = Some(policy.client_threshold);
214        }
215
216        if let Some(spam_policy) = self.spam_policy.as_ref()
217            && let TrafficControlPolicy::FreqThreshold(ref policy) = *spam_policy.lock().await
218        {
219            result.spam_threshold = Some(policy.client_threshold);
220        }
221
222        result.dry_run = Some(self.policy_config.read().await.dry_run);
223        result
224    }
225
226    pub async fn admin_reconfigure(
227        &self,
228        params: TrafficControlReconfigParams,
229    ) -> Result<TrafficControlReconfigParams, SuiError> {
230        let TrafficControlReconfigParams {
231            error_threshold,
232            spam_threshold,
233            dry_run,
234        } = params;
235        if let Some(error_threshold) = error_threshold {
236            self.metrics
237                .error_client_threshold
238                .set(error_threshold as i64);
239            Self::update_policy_threshold(
240                self.error_policy.as_ref().unwrap(),
241                error_threshold,
242                dry_run,
243            )
244            .await?;
245        }
246        if let Some(spam_threshold) = spam_threshold {
247            self.metrics
248                .spam_client_threshold
249                .set(spam_threshold as i64);
250            Self::update_policy_threshold(
251                self.spam_policy.as_ref().unwrap(),
252                spam_threshold,
253                dry_run,
254            )
255            .await?;
256        }
257        if let Some(dry_run) = dry_run {
258            self.metrics.dry_run_enabled.set(dry_run as i64);
259            self.policy_config.write().await.dry_run = dry_run;
260        }
261
262        Ok(self.get_current_state().await)
263    }
264
265    async fn update_policy_threshold(
266        policy: &Arc<Mutex<TrafficControlPolicy>>,
267        threshold: u64,
268        dry_run: Option<bool>,
269    ) -> Result<(), SuiError> {
270        match *policy.lock().await {
271            TrafficControlPolicy::FreqThreshold(ref mut policy) => {
272                policy.client_threshold = threshold;
273                if let Some(dry_run) = dry_run {
274                    policy.config.dry_run = dry_run;
275                }
276                Ok(())
277            }
278            TrafficControlPolicy::TestNConnIP(ref mut policy) => {
279                policy.threshold = threshold;
280                if let Some(dry_run) = dry_run {
281                    policy.config.dry_run = dry_run;
282                }
283                Ok(())
284            }
285            _ => Err(SuiErrorKind::InvalidAdminRequest(
286                "Unsupported prior policy type during traffic control reconfiguration".to_string(),
287            )
288            .into()),
289        }
290    }
291
292    fn open_tally_channel(&self, tx: mpsc::Sender<TrafficTally>) {
293        self.tally_channel.lock().replace(tx);
294    }
295
296    fn set_policy_config_metrics(
297        policy_config: &PolicyConfig,
298        metrics: Arc<TrafficControllerMetrics>,
299    ) {
300        if let PolicyType::FreqThreshold(config) = &policy_config.spam_policy_type {
301            metrics
302                .spam_client_threshold
303                .set(config.client_threshold as i64);
304            metrics
305                .spam_proxied_client_threshold
306                .set(config.proxied_client_threshold as i64);
307        }
308        if let PolicyType::FreqThreshold(config) = &policy_config.error_policy_type {
309            metrics
310                .error_client_threshold
311                .set(config.client_threshold as i64);
312            metrics
313                .error_proxied_client_threshold
314                .set(config.proxied_client_threshold as i64);
315        }
316    }
317
318    pub fn tally(&self, tally: TrafficTally) {
319        if let Some(channel) = self.tally_channel.lock().as_ref() {
320            // Use try_send rather than send mainly to avoid creating backpressure
321            // on the caller if the channel is full, which may slow down the critical
322            // path. Dropping the tally on the floor should be ok, as in this case
323            // we are effectively sampling traffic, which we would need to do anyway
324            // if we are overloaded
325            match channel.try_send(tally) {
326                Err(TrySendError::Full(_)) => {
327                    warn!("TrafficController tally channel full, dropping tally");
328                    self.metrics.tally_channel_overflow.inc();
329                    // TODO: once we've verified this doesn't happen under normal
330                    // conditions, we can consider dropping the request itself given
331                    // that clearly the system is overloaded
332                }
333                Err(TrySendError::Closed(_)) => {
334                    warn!("TrafficController tally channel closed unexpectedly");
335                }
336                Ok(_) => {}
337            }
338        } else {
339            warn!("TrafficController not yet accepting tally requests.");
340        }
341    }
342
343    /// Handle check with dry-run mode considered
344    pub async fn check(&self, client: &Option<IpAddr>, proxied_client: &Option<IpAddr>) -> bool {
345        let policy_config = { self.policy_config.read().await.clone() };
346        let check_with_dry_run_maybe = |allowed| -> bool {
347            match (allowed, policy_config.dry_run) {
348                // request allowed
349                (true, _) => true,
350                // request blocked while in dry-run mode
351                (false, true) => {
352                    debug!("Dry run mode: Blocked request from client {:?}", client);
353                    self.metrics.num_dry_run_blocked_requests.inc();
354                    true
355                }
356                // request blocked
357                (false, false) => {
358                    debug!("Blocked request from client {:?}", client);
359                    self.metrics.requests_blocked_at_protocol.inc();
360                    false
361                }
362            }
363        };
364
365        match &self.acl {
366            Acl::Allowlist(allowlist) => {
367                let allowed = client.is_none() || allowlist.contains(&client.unwrap());
368                check_with_dry_run_maybe(allowed)
369            }
370            Acl::Blocklists(blocklists) => {
371                let allowed = self
372                    .check_blocklists(blocklists, client, proxied_client)
373                    .await;
374                check_with_dry_run_maybe(allowed)
375            }
376        }
377    }
378
379    /// Returns true if the connection is in blocklist, false otherwise
380    async fn check_blocklists(
381        &self,
382        blocklists: &Blocklists,
383        client: &Option<IpAddr>,
384        proxied_client: &Option<IpAddr>,
385    ) -> bool {
386        let client_check = self.check_and_clear_blocklist(
387            client,
388            blocklists.clients.clone(),
389            &self.metrics.connection_ip_blocklist_len,
390        );
391        let proxied_client_check = self.check_and_clear_blocklist(
392            proxied_client,
393            blocklists.proxied_clients.clone(),
394            &self.metrics.proxy_ip_blocklist_len,
395        );
396        let (client_check, proxied_client_check) =
397            futures::future::join(client_check, proxied_client_check).await;
398        client_check && proxied_client_check
399    }
400
401    async fn check_and_clear_blocklist(
402        &self,
403        client: &Option<IpAddr>,
404        blocklist: Blocklist,
405        blocklist_len_gauge: &IntGauge,
406    ) -> bool {
407        let client = match client {
408            Some(client) => client,
409            None => return true,
410        };
411        let now = SystemTime::now();
412        // the below two blocks cannot be nested, otherwise we will deadlock
413        // due to aquiring the lock on get, then holding across the remove
414        let (should_block, should_remove) = {
415            match blocklist.get(client) {
416                Some(expiration) if now >= *expiration => (false, true),
417                None => (false, false),
418                _ => (true, false),
419            }
420        };
421        if should_remove && blocklist.remove(client).is_some() {
422            blocklist_len_gauge.dec();
423        }
424        !should_block
425    }
426}
427
428/// Although we clear IPs from the blocklist lazily when they are checked,
429/// it's possible that over time we may accumulate a large number of stale
430/// IPs in the blocklist for clients that are added, then once blocked,
431/// never checked again. This function runs periodically to clear out any
432/// such stale IPs. This also ensures that the blocklist length metric
433/// accurately reflects TTL.
434async fn run_clear_blocklists_loop(blocklists: Blocklists, metrics: Arc<TrafficControllerMetrics>) {
435    loop {
436        tokio::time::sleep(Duration::from_secs(3)).await;
437        let now = SystemTime::now();
438        blocklists.clients.retain(|_, expiration| now < *expiration);
439        blocklists
440            .proxied_clients
441            .retain(|_, expiration| now < *expiration);
442        metrics
443            .connection_ip_blocklist_len
444            .set(blocklists.clients.len() as i64);
445        metrics
446            .proxy_ip_blocklist_len
447            .set(blocklists.proxied_clients.len() as i64);
448    }
449}
450
451async fn run_tally_loop(
452    mut receiver: mpsc::Receiver<TrafficTally>,
453    policy_config: PolicyConfig,
454    spam_policy: Arc<Mutex<TrafficControlPolicy>>,
455    error_policy: Arc<Mutex<TrafficControlPolicy>>,
456    fw_config: Option<RemoteFirewallConfig>,
457    blocklists: Blocklists,
458    metrics: Arc<TrafficControllerMetrics>,
459    mut mem_drainfile_present: bool,
460) {
461    let spam_blocklists = Arc::new(blocklists.clone());
462    let error_blocklists = Arc::new(blocklists);
463    let node_fw_client = fw_config
464        .as_ref()
465        .map(|fw_config| NodeFWClient::new(fw_config.remote_fw_url.clone()));
466
467    let timeout = fw_config
468        .as_ref()
469        .map(|fw_config| fw_config.drain_timeout_secs)
470        .unwrap_or(DEFAULT_DRAIN_TIMEOUT_SECS);
471    let mut metric_timer = Instant::now();
472
473    loop {
474        tokio::select! {
475            received = receiver.recv() => {
476                metrics.tallies.inc();
477                match received {
478                    Some(tally) => {
479                        // TODO: spawn a task to handle tallying concurrently
480                        if let Err(err) = handle_spam_tally(
481                            spam_policy.clone(),
482                            &policy_config,
483                            &node_fw_client,
484                            &fw_config,
485                            tally.clone(),
486                            spam_blocklists.clone(),
487                            metrics.clone(),
488                            mem_drainfile_present,
489                        )
490                        .await {
491                            warn!("Error handling spam tally: {}", err);
492                        }
493                        if let Err(err) = handle_error_tally(
494                            error_policy.clone(),
495                            &policy_config,
496                            &node_fw_client,
497                            &fw_config,
498                            tally,
499                            error_blocklists.clone(),
500                            metrics.clone(),
501                            mem_drainfile_present,
502                        )
503                        .await {
504                            warn!("Error handling error tally: {}", err);
505                        }
506                    }
507                    None => {
508                        info!("TrafficController tally channel closed by all senders");
509                        return;
510                    }
511                }
512            }
513            // Dead man's switch - if we suspect something is sinking all traffic to node, disable nodefw
514            _ = tokio::time::sleep(tokio::time::Duration::from_secs(timeout)) => {
515                if let Some(fw_config) = &fw_config {
516                    error!("No traffic tallies received in {} seconds.", timeout);
517                    if mem_drainfile_present {
518                        continue;
519                    }
520                    if !fw_config.drain_path.exists() {
521                        mem_drainfile_present = true;
522                        warn!("Draining Node firewall.");
523                        File::create(&fw_config.drain_path)
524                            .expect("Failed to touch nodefw drain file");
525                        metrics.deadmans_switch_enabled.set(1);
526                    }
527                }
528            }
529        }
530
531        // every N seconds, we update metrics and logging that would be too
532        // spammy to be handled while processing each tally
533        if metric_timer.elapsed() > Duration::from_secs(METRICS_INTERVAL_SECS) {
534            if let TrafficControlPolicy::FreqThreshold(ref spam_policy) = *spam_policy.lock().await
535            {
536                if let Some(highest_direct_rate) = spam_policy.highest_direct_rate() {
537                    metrics
538                        .highest_direct_spam_rate
539                        .set(highest_direct_rate.0 as i64);
540                    debug!("Recent highest direct spam rate: {:?}", highest_direct_rate);
541                }
542                if let Some(highest_proxied_rate) = spam_policy.highest_proxied_rate() {
543                    metrics
544                        .highest_proxied_spam_rate
545                        .set(highest_proxied_rate.0 as i64);
546                    debug!(
547                        "Recent highest proxied spam rate: {:?}",
548                        highest_proxied_rate
549                    );
550                }
551            }
552            if let TrafficControlPolicy::FreqThreshold(ref error_policy) =
553                *error_policy.lock().await
554            {
555                if let Some(highest_direct_rate) = error_policy.highest_direct_rate() {
556                    metrics
557                        .highest_direct_error_rate
558                        .set(highest_direct_rate.0 as i64);
559                    debug!(
560                        "Recent highest direct error rate: {:?}",
561                        highest_direct_rate
562                    );
563                }
564                if let Some(highest_proxied_rate) = error_policy.highest_proxied_rate() {
565                    metrics
566                        .highest_proxied_error_rate
567                        .set(highest_proxied_rate.0 as i64);
568                    debug!(
569                        "Recent highest proxied error rate: {:?}",
570                        highest_proxied_rate
571                    );
572                }
573            }
574            metric_timer = Instant::now();
575        }
576    }
577}
578
579async fn handle_error_tally(
580    policy: Arc<Mutex<TrafficControlPolicy>>,
581    policy_config: &PolicyConfig,
582    nodefw_client: &Option<NodeFWClient>,
583    fw_config: &Option<RemoteFirewallConfig>,
584    tally: TrafficTally,
585    blocklists: Arc<Blocklists>,
586    metrics: Arc<TrafficControllerMetrics>,
587    mem_drainfile_present: bool,
588) -> Result<(), reqwest::Error> {
589    let Some((error_weight, error_type)) = tally.clone().error_info else {
590        return Ok(());
591    };
592    if !error_weight.is_sampled() {
593        return Ok(());
594    }
595    trace!(
596        "Handling error_type {:?} from client {:?}",
597        error_type, tally.direct,
598    );
599    metrics
600        .tally_error_types
601        .with_label_values(&[error_type.as_str()])
602        .inc();
603    let resp = policy.lock().await.handle_tally(tally);
604    metrics.error_tally_handled.inc();
605    if let Some(fw_config) = fw_config
606        && fw_config.delegate_error_blocking
607        && !mem_drainfile_present
608    {
609        let client = nodefw_client
610            .as_ref()
611            .expect("Expected NodeFWClient for blocklist delegation");
612        return delegate_policy_response(
613            resp,
614            policy_config,
615            client,
616            fw_config.destination_port,
617            metrics.clone(),
618        )
619        .await;
620    }
621    handle_policy_response(resp, policy_config, blocklists, metrics).await;
622    Ok(())
623}
624
625async fn handle_spam_tally(
626    policy: Arc<Mutex<TrafficControlPolicy>>,
627    policy_config: &PolicyConfig,
628    nodefw_client: &Option<NodeFWClient>,
629    fw_config: &Option<RemoteFirewallConfig>,
630    tally: TrafficTally,
631    blocklists: Arc<Blocklists>,
632    metrics: Arc<TrafficControllerMetrics>,
633    mem_drainfile_present: bool,
634) -> Result<(), reqwest::Error> {
635    if !(tally.spam_weight.is_sampled() && policy_config.spam_sample_rate.is_sampled()) {
636        return Ok(());
637    }
638    let resp = policy.lock().await.handle_tally(tally.clone());
639    metrics.tally_handled.inc();
640    if let Some(fw_config) = fw_config
641        && fw_config.delegate_spam_blocking
642        && !mem_drainfile_present
643    {
644        let client = nodefw_client
645            .as_ref()
646            .expect("Expected NodeFWClient for blocklist delegation");
647        return delegate_policy_response(
648            resp,
649            policy_config,
650            client,
651            fw_config.destination_port,
652            metrics.clone(),
653        )
654        .await;
655    }
656    handle_policy_response(resp, policy_config, blocklists, metrics).await;
657    Ok(())
658}
659
660async fn handle_policy_response(
661    response: PolicyResponse,
662    policy_config: &PolicyConfig,
663    blocklists: Arc<Blocklists>,
664    metrics: Arc<TrafficControllerMetrics>,
665) {
666    let PolicyResponse {
667        block_client,
668        block_proxied_client,
669    } = response;
670    let PolicyConfig {
671        connection_blocklist_ttl_sec,
672        proxy_blocklist_ttl_sec,
673        ..
674    } = policy_config;
675    if let Some(client) = block_client
676        && blocklists
677            .clients
678            .insert(
679                client,
680                SystemTime::now() + Duration::from_secs(*connection_blocklist_ttl_sec),
681            )
682            .is_none()
683    {
684        // Only increment the metric if the client was not already blocked
685        debug!("Adding client {:?} to blocklist", client);
686        metrics.connection_ip_blocklist_len.inc();
687    }
688    if let Some(client) = block_proxied_client
689        && blocklists
690            .proxied_clients
691            .insert(
692                client,
693                SystemTime::now() + Duration::from_secs(*proxy_blocklist_ttl_sec),
694            )
695            .is_none()
696    {
697        // Only increment the metric if the client was not already blocked
698        debug!("Adding proxied client {:?} to blocklist", client);
699        metrics.proxy_ip_blocklist_len.inc();
700    }
701}
702
703async fn delegate_policy_response(
704    response: PolicyResponse,
705    policy_config: &PolicyConfig,
706    node_fw_client: &NodeFWClient,
707    destination_port: u16,
708    metrics: Arc<TrafficControllerMetrics>,
709) -> Result<(), reqwest::Error> {
710    let PolicyResponse {
711        block_client,
712        block_proxied_client,
713    } = response;
714    let PolicyConfig {
715        connection_blocklist_ttl_sec,
716        proxy_blocklist_ttl_sec,
717        ..
718    } = policy_config;
719    let mut addresses = vec![];
720    if let Some(client_id) = block_client {
721        debug!("Delegating client blocking to firewall");
722        addresses.push(BlockAddress {
723            source_address: client_id.to_string(),
724            destination_port,
725            ttl: *connection_blocklist_ttl_sec,
726        });
727    }
728    if let Some(ip) = block_proxied_client {
729        debug!("Delegating proxied client blocking to firewall");
730        addresses.push(BlockAddress {
731            source_address: ip.to_string(),
732            destination_port,
733            ttl: *proxy_blocklist_ttl_sec,
734        });
735    }
736    if addresses.is_empty() {
737        Ok(())
738    } else {
739        metrics
740            .blocks_delegated_to_firewall
741            .inc_by(addresses.len() as u64);
742        match node_fw_client
743            .block_addresses(BlockAddresses { addresses })
744            .await
745        {
746            Ok(()) => Ok(()),
747            Err(err) => {
748                metrics.firewall_delegation_request_fail.inc();
749                Err(err)
750            }
751        }
752    }
753}
754
/// Aggregated results of a traffic simulation run.
///
/// Per-client results are combined with `+`. Counters and
/// `total_time_blocked` are summed. `time_to_first_block` values are also
/// summed (presumably so a caller can average them; confirm at the call
/// site). `abs_time_to_first_block` keeps the smaller of the two values.
#[derive(Debug, Clone)]
pub struct TrafficSimMetrics {
    /// Presumably the number of requests issued; verify against the simulator.
    pub num_requests: u64,
    /// Presumably the number of requests rejected by the controller.
    pub num_blocked: u64,
    /// Time until the first block, if any occurred; summed when combined.
    pub time_to_first_block: Option<Duration>,
    /// Earliest time until the first block across combined results.
    pub abs_time_to_first_block: Option<Duration>,
    /// Accumulated time spent blocked; summed when combined.
    pub total_time_blocked: Duration,
    /// Presumably the number of times a client was added to a blocklist.
    pub num_blocklist_adds: u64,
}

impl Default for TrafficSimMetrics {
    /// All counters zero, no first-block times, and zero time blocked.
    fn default() -> Self {
        Self {
            num_requests: 0,
            num_blocked: 0,
            time_to_first_block: None,
            abs_time_to_first_block: None,
            total_time_blocked: Duration::ZERO,
            num_blocklist_adds: 0,
        }
    }
}

impl Add for TrafficSimMetrics {
    type Output = Self;

    /// Merges two results field by field.
    fn add(self, other: Self) -> Self {
        // Combine two optional durations with `combine` when both are present;
        // otherwise keep whichever one is present (or none).
        fn merge(
            lhs: Option<Duration>,
            rhs: Option<Duration>,
            combine: impl FnOnce(Duration, Duration) -> Duration,
        ) -> Option<Duration> {
            match (lhs, rhs) {
                (Some(x), Some(y)) => Some(combine(x, y)),
                (x, y) => x.or(y),
            }
        }

        let time_to_first_block = merge(
            self.time_to_first_block,
            other.time_to_first_block,
            |x, y| x + y,
        );
        let abs_time_to_first_block = merge(
            self.abs_time_to_first_block,
            other.abs_time_to_first_block,
            |x, y| x.min(y),
        );
        Self {
            num_requests: self.num_requests + other.num_requests,
            num_blocked: self.num_blocked + other.num_blocked,
            time_to_first_block,
            abs_time_to_first_block,
            total_time_blocked: self.total_time_blocked + other.total_time_blocked,
            num_blocklist_adds: self.num_blocklist_adds + other.num_blocklist_adds,
        }
    }
}
805
806pub struct TrafficSim {
807    pub traffic_controller: TrafficController,
808}
809
810impl TrafficSim {
811    pub async fn run(
812        policy: PolicyConfig,
813        num_clients: u8,
814        per_client_tps: usize,
815        duration: Duration,
816        report: bool,
817    ) -> TrafficSimMetrics {
818        assert!(
819            per_client_tps <= 10_000,
820            "per_client_tps must be less than 10,000. For higher values, increase num_clients"
821        );
822        assert!(num_clients < 20, "num_clients must be greater than 0");
823        assert!(num_clients > 0);
824        assert!(per_client_tps > 0);
825        assert!(duration.as_secs() > 0);
826
827        let controller = TrafficController::init_for_test(policy.clone(), None).await;
828        let tasks = (0..num_clients).map(|task_num| {
829            tokio::spawn(Self::run_single_client(
830                controller.clone(),
831                duration,
832                task_num,
833                per_client_tps,
834            ))
835        });
836
837        let status_task = if report {
838            Some(tokio::spawn(async move {
839                println!(
840                    "Running naive traffic simulation for {} seconds",
841                    duration.as_secs()
842                );
843                println!("Policy: {:#?}", policy);
844                println!("Num clients: {}", num_clients);
845                println!("TPS per client: {}", per_client_tps);
846                println!(
847                    "Target total TPS: {}",
848                    per_client_tps * num_clients as usize
849                );
850                println!("\n");
851                for _ in 0..duration.as_secs() {
852                    print!(".");
853                    tokio::time::sleep(Duration::from_secs(1)).await;
854                }
855                println!();
856            }))
857        } else {
858            None
859        };
860
861        let metrics = futures::future::join_all(tasks).await.into_iter().fold(
862            TrafficSimMetrics::default(),
863            |acc, run_client_ret| {
864                if let Ok(metrics) = run_client_ret {
865                    acc + metrics
866                } else {
867                    error!(
868                        "Error running traffic sim client: {:?}",
869                        run_client_ret.err()
870                    );
871                    acc
872                }
873            },
874        );
875
876        if report {
877            status_task.unwrap().await.unwrap();
878            Self::report_metrics(metrics.clone(), duration, per_client_tps, num_clients);
879        }
880        metrics
881    }
882
883    async fn run_single_client(
884        controller: TrafficController,
885        duration: Duration,
886        task_num: u8,
887        per_client_tps: usize,
888    ) -> TrafficSimMetrics {
889        // Do an initial sleep for a random amount of time to smooth
890        // out the traffic. This shouldn't be strictly necessary and
891        // we can remove if we want more determinism
892        let sleep_time = Duration::from_micros(rand::thread_rng().gen_range(0..100));
893        tokio::time::sleep(sleep_time).await;
894
895        // collectors
896        let mut num_requests = 0;
897        let mut num_blocked = 0;
898        let mut time_to_first_block = None;
899        let mut total_time_blocked = Duration::from_micros(0);
900        let mut num_blocklist_adds = 0;
901        // state variables
902        let mut currently_blocked = false;
903        let mut time_blocked_start = Instant::now();
904        let start = Instant::now();
905
906        while start.elapsed() < duration {
907            let client = Some(IpAddr::V4(Ipv4Addr::new(127, 0, 0, task_num)));
908            let allowed = controller.check(&client, &None).await;
909            if allowed {
910                if currently_blocked {
911                    total_time_blocked += time_blocked_start.elapsed();
912                    currently_blocked = false;
913                }
914                controller.tally(TrafficTally::new(
915                    client,
916                    // TODO add proxy IP for testing
917                    None,
918                    // TODO add weight adjustments
919                    None,
920                    Weight::one(),
921                ));
922            } else {
923                if !currently_blocked {
924                    time_blocked_start = Instant::now();
925                    currently_blocked = true;
926                    num_blocklist_adds += 1;
927                    if time_to_first_block.is_none() {
928                        time_to_first_block = Some(start.elapsed());
929                    }
930                }
931                num_blocked += 1;
932            }
933            num_requests += 1;
934            tokio::time::sleep(Duration::from_micros(1_000_000 / per_client_tps as u64)).await;
935        }
936        TrafficSimMetrics {
937            num_requests,
938            num_blocked,
939            time_to_first_block,
940            abs_time_to_first_block: time_to_first_block,
941            total_time_blocked,
942            num_blocklist_adds,
943        }
944    }
945
946    fn report_metrics(
947        metrics: TrafficSimMetrics,
948        duration: Duration,
949        per_client_tps: usize,
950        num_clients: u8,
951    ) {
952        println!("TrafficSim metrics:");
953        println!("-------------------");
954        // The below two should be near equal
955        println!(
956            "Num expected requests: {}",
957            per_client_tps * (num_clients as usize) * duration.as_secs() as usize
958        );
959        println!("Num actual requests: {}", metrics.num_requests);
960        // This reflects the number of requests that were blocked, but note that once a client
961        // is added to the blocklist, all subsequent requests from that client are blocked
962        // until ttl is expired.
963        println!("Num blocked requests: {}", metrics.num_blocked);
964        // This metric on the other hand reflects the number of times a client was added to the blocklist
965        // and thus can be compared with the expectation based on the policy block threshold and ttl
966        println!(
967            "Num times added to blocklist: {}",
968            metrics.num_blocklist_adds
969        );
970        // This averages the duration for the first request to be blocked across all clients,
971        // which is useful for understanding if the policy is rate limiting based on expectation
972        let avg_first_block_time = metrics
973            .time_to_first_block
974            .map(|ttf| ttf / num_clients as u32);
975        println!("Average time to first block: {:?}", avg_first_block_time);
976        // This is the time it took for the first request to be blocked across all clients,
977        // and is instead more useful for understanding false positives in terms of rate and magnitude.
978        println!(
979            "Abolute time to first block (across all clients): {:?}",
980            metrics.abs_time_to_first_block
981        );
982        // Useful for ensuring that TTL is respected
983        let avg_time_blocked = if metrics.num_blocklist_adds > 0 {
984            metrics.total_time_blocked.as_millis() as u64 / metrics.num_blocklist_adds
985        } else {
986            0
987        };
988        println!(
989            "Average time blocked (ttl): {:?}",
990            Duration::from_millis(avg_time_blocked)
991        );
992    }
993}
994
995pub fn parse_ip(ip: &str) -> Option<IpAddr> {
996    ip.parse::<IpAddr>().ok().or_else(|| {
997        ip.parse::<SocketAddr>()
998            .ok()
999            .map(|socket_addr| socket_addr.ip())
1000            .or_else(|| {
1001                error!("Failed to parse value of {:?} to ip address or socket.", ip,);
1002                None
1003            })
1004    })
1005}