1pub mod metrics;
5pub mod nodefw_client;
6pub mod nodefw_test_server;
7pub mod policies;
8
9use dashmap::DashMap;
10use fs::File;
11use mysten_common::fatal;
12use prometheus::IntGauge;
13use std::fs;
14use std::net::{IpAddr, Ipv4Addr, SocketAddr};
15use std::ops::Add;
16use std::sync::Arc;
17use sui_types::error::{SuiError, SuiErrorKind};
18
19use self::metrics::TrafficControllerMetrics;
20use crate::traffic_controller::nodefw_client::{BlockAddress, BlockAddresses, NodeFWClient};
21use crate::traffic_controller::policies::{
22 Policy, PolicyResponse, TrafficControlPolicy, TrafficTally,
23};
24use mysten_metrics::spawn_monitored_task;
25use parking_lot::Mutex as ParkingLotMutex;
26use rand::Rng;
27use std::fmt::Debug;
28use std::time::{Duration, Instant, SystemTime};
29use sui_types::traffic_control::{
30 PolicyConfig, PolicyType, RemoteFirewallConfig, TrafficControlReconfigParams, Weight,
31};
32use tokio::sync::mpsc::error::TrySendError;
33use tokio::sync::{Mutex, RwLock, mpsc};
34use tracing::{debug, error, info, trace, warn};
35
/// Minimum number of seconds between publications of the highest observed spam/error
/// rates by the tally loop.
pub const METRICS_INTERVAL_SECS: u64 = 2;
/// Dead man's switch timeout (seconds) used by the tally loop when no remote firewall
/// config supplies one: five minutes.
pub const DEFAULT_DRAIN_TIMEOUT_SECS: u64 = 5 * 60;
38
/// Shared map from a blocked IP address to the time at which its block expires. An entry
/// whose expiration is at or before the current time is treated as no longer blocked and
/// is removed lazily by `check` or periodically by the blocklist-clearing loop.
type Blocklist = Arc<DashMap<IpAddr, SystemTime>>;

/// The pair of blocklists consulted by `TrafficController::check` in blocklist mode.
/// Cloning is cheap: both maps are behind `Arc`, so clones share the same entries.
#[derive(Clone)]
pub struct Blocklists {
    /// Blocked direct (connection-level) client IPs; TTL comes from
    /// `PolicyConfig::connection_blocklist_ttl_sec`.
    clients: Blocklist,
    /// Blocked proxied client IPs; TTL comes from `PolicyConfig::proxy_blocklist_ttl_sec`.
    proxied_clients: Blocklist,
}
46
/// Access-control mode of a `TrafficController`, fixed at construction time.
#[derive(Clone)]
pub enum Acl {
    /// Dynamic blocklists filled by the background tally loop. A client is treated as
    /// blocked while it has an unexpired entry in the relevant list.
    Blocklists(Blocklists),
    /// Static allowlist parsed from `PolicyConfig::allow_list` at init. A request whose
    /// direct client IP is known but not listed is treated as blocked; requests with no
    /// client IP are allowed. No tally channel or background tasks exist in this mode.
    Allowlist(Vec<IpAddr>),
}
56
/// Per-node traffic controller: decides whether requests are admitted (`check`) and
/// accepts traffic tallies (`tally`) that feed the spam/error policies. Cloning is cheap;
/// clones share all state through `Arc`s.
#[derive(Clone)]
pub struct TrafficController {
    /// Sender half of the tally channel. Stays `None` until `spawn` opens it, which only
    /// happens in blocklist mode; while `None`, `tally` drops tallies with a warning.
    tally_channel: Arc<ParkingLotMutex<Option<mpsc::Sender<TrafficTally>>>>,
    /// Blocklist or allowlist mode.
    acl: Acl,
    metrics: Arc<TrafficControllerMetrics>,
    /// Spam policy; `Some` only in blocklist mode.
    spam_policy: Option<Arc<Mutex<TrafficControlPolicy>>>,
    /// Error policy; `Some` only in blocklist mode.
    error_policy: Option<Arc<Mutex<TrafficControlPolicy>>>,
    /// Live policy config; `dry_run` can be changed at runtime via `admin_reconfigure`.
    policy_config: Arc<RwLock<PolicyConfig>>,
    /// Optional remote node-firewall config used for block delegation and the
    /// dead man's switch drain file.
    fw_config: Option<RemoteFirewallConfig>,
}
67
68impl Debug for TrafficController {
69 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70 f.debug_struct("TrafficController")
76 .field(
77 "connection_ip_blocklist_len",
78 &self.metrics.connection_ip_blocklist_len.get(),
79 )
80 .field(
81 "proxy_ip_blocklist_len",
82 &self.metrics.proxy_ip_blocklist_len.get(),
83 )
84 .finish()
85 }
86}
87
88impl TrafficController {
89 pub async fn init(
90 policy_config: PolicyConfig,
91 metrics: Arc<TrafficControllerMetrics>,
92 fw_config: Option<RemoteFirewallConfig>,
93 ) -> Self {
94 metrics.dry_run_enabled.set(policy_config.dry_run as i64);
95 match policy_config.allow_list.clone() {
96 Some(allow_list) => {
97 let allowlist = allow_list
98 .into_iter()
99 .map(|ip_str| {
100 parse_ip(&ip_str).unwrap_or_else(|| {
101 fatal!("Failed to parse allowlist IP address: {:?}", ip_str)
102 })
103 })
104 .collect();
105 Self {
106 tally_channel: Arc::new(ParkingLotMutex::new(None)),
107 acl: Acl::Allowlist(allowlist),
108 metrics,
109 policy_config: Arc::new(RwLock::new(policy_config)),
110 fw_config,
111 spam_policy: None,
112 error_policy: None,
113 }
114 }
115 None => {
116 let spam_policy = Arc::new(Mutex::new(
117 TrafficControlPolicy::from_spam_config(policy_config.clone()).await,
118 ));
119 let error_policy = Arc::new(Mutex::new(
120 TrafficControlPolicy::from_error_config(policy_config.clone()).await,
121 ));
122 let this = Self {
123 tally_channel: Arc::new(ParkingLotMutex::new(None)),
124 acl: Acl::Blocklists(Blocklists {
125 clients: Arc::new(DashMap::new()),
126 proxied_clients: Arc::new(DashMap::new()),
127 }),
128 metrics,
129 policy_config: Arc::new(RwLock::new(policy_config)),
130 fw_config,
131 spam_policy: Some(spam_policy),
132 error_policy: Some(error_policy),
133 };
134 this.spawn().await;
135 this
136 }
137 }
138 }
139
140 pub async fn init_for_test(
141 policy_config: PolicyConfig,
142 fw_config: Option<RemoteFirewallConfig>,
143 ) -> Self {
144 let metrics = Arc::new(TrafficControllerMetrics::new(&prometheus::Registry::new()));
145 Self::init(policy_config, metrics, fw_config).await
146 }
147
148 async fn spawn(&self) {
149 let policy_config = { self.policy_config.read().await.clone() };
150 Self::set_policy_config_metrics(&policy_config, self.metrics.clone());
151 let (tx, rx) = mpsc::channel(policy_config.channel_capacity);
152 let mem_drainfile_present = self
157 .fw_config
158 .as_ref()
159 .map(|config| config.drain_path.exists())
160 .unwrap_or(false);
161 self.metrics
162 .deadmans_switch_enabled
163 .set(mem_drainfile_present as i64);
164 let blocklists = match self.acl.clone() {
165 Acl::Blocklists(blocklists) => blocklists,
166 Acl::Allowlist(_) => fatal!("Allowlist ACL should not exist on spawn"),
167 };
168 let tally_loop_blocklists = blocklists.clone();
169 let clear_loop_blocklists = blocklists.clone();
170 let tally_loop_metrics = self.metrics.clone();
171 let clear_loop_metrics = self.metrics.clone();
172 let tally_loop_policy_config = policy_config.clone();
173 let tally_loop_fw_config = self.fw_config.clone();
174
175 let spam_policy = self
176 .spam_policy
177 .clone()
178 .expect("spam policy should exist on spawn");
179 let error_policy = self
180 .error_policy
181 .clone()
182 .expect("error policy should exist on spawn");
183 let spam_policy_clone = spam_policy.clone();
184 let error_policy_clone = error_policy.clone();
185
186 spawn_monitored_task!(run_tally_loop(
187 rx,
188 tally_loop_policy_config,
189 spam_policy_clone,
190 error_policy_clone,
191 tally_loop_fw_config,
192 tally_loop_blocklists,
193 tally_loop_metrics,
194 mem_drainfile_present,
195 ));
196 spawn_monitored_task!(run_clear_blocklists_loop(
197 clear_loop_blocklists,
198 clear_loop_metrics,
199 ));
200 self.open_tally_channel(tx);
201 }
202
203 pub async fn get_current_state(&self) -> TrafficControlReconfigParams {
204 let mut result = TrafficControlReconfigParams {
205 error_threshold: None,
206 spam_threshold: None,
207 dry_run: None,
208 };
209
210 if let Some(error_policy) = self.error_policy.as_ref()
211 && let TrafficControlPolicy::FreqThreshold(ref policy) = *error_policy.lock().await
212 {
213 result.error_threshold = Some(policy.client_threshold);
214 }
215
216 if let Some(spam_policy) = self.spam_policy.as_ref()
217 && let TrafficControlPolicy::FreqThreshold(ref policy) = *spam_policy.lock().await
218 {
219 result.spam_threshold = Some(policy.client_threshold);
220 }
221
222 result.dry_run = Some(self.policy_config.read().await.dry_run);
223 result
224 }
225
226 pub async fn admin_reconfigure(
227 &self,
228 params: TrafficControlReconfigParams,
229 ) -> Result<TrafficControlReconfigParams, SuiError> {
230 let TrafficControlReconfigParams {
231 error_threshold,
232 spam_threshold,
233 dry_run,
234 } = params;
235 if let Some(error_threshold) = error_threshold {
236 self.metrics
237 .error_client_threshold
238 .set(error_threshold as i64);
239 Self::update_policy_threshold(
240 self.error_policy.as_ref().unwrap(),
241 error_threshold,
242 dry_run,
243 )
244 .await?;
245 }
246 if let Some(spam_threshold) = spam_threshold {
247 self.metrics
248 .spam_client_threshold
249 .set(spam_threshold as i64);
250 Self::update_policy_threshold(
251 self.spam_policy.as_ref().unwrap(),
252 spam_threshold,
253 dry_run,
254 )
255 .await?;
256 }
257 if let Some(dry_run) = dry_run {
258 self.metrics.dry_run_enabled.set(dry_run as i64);
259 self.policy_config.write().await.dry_run = dry_run;
260 }
261
262 Ok(self.get_current_state().await)
263 }
264
265 async fn update_policy_threshold(
266 policy: &Arc<Mutex<TrafficControlPolicy>>,
267 threshold: u64,
268 dry_run: Option<bool>,
269 ) -> Result<(), SuiError> {
270 match *policy.lock().await {
271 TrafficControlPolicy::FreqThreshold(ref mut policy) => {
272 policy.client_threshold = threshold;
273 if let Some(dry_run) = dry_run {
274 policy.config.dry_run = dry_run;
275 }
276 Ok(())
277 }
278 TrafficControlPolicy::TestNConnIP(ref mut policy) => {
279 policy.threshold = threshold;
280 if let Some(dry_run) = dry_run {
281 policy.config.dry_run = dry_run;
282 }
283 Ok(())
284 }
285 _ => Err(SuiErrorKind::InvalidAdminRequest(
286 "Unsupported prior policy type during traffic control reconfiguration".to_string(),
287 )
288 .into()),
289 }
290 }
291
292 fn open_tally_channel(&self, tx: mpsc::Sender<TrafficTally>) {
293 self.tally_channel.lock().replace(tx);
294 }
295
296 fn set_policy_config_metrics(
297 policy_config: &PolicyConfig,
298 metrics: Arc<TrafficControllerMetrics>,
299 ) {
300 if let PolicyType::FreqThreshold(config) = &policy_config.spam_policy_type {
301 metrics
302 .spam_client_threshold
303 .set(config.client_threshold as i64);
304 metrics
305 .spam_proxied_client_threshold
306 .set(config.proxied_client_threshold as i64);
307 }
308 if let PolicyType::FreqThreshold(config) = &policy_config.error_policy_type {
309 metrics
310 .error_client_threshold
311 .set(config.client_threshold as i64);
312 metrics
313 .error_proxied_client_threshold
314 .set(config.proxied_client_threshold as i64);
315 }
316 }
317
318 pub fn tally(&self, tally: TrafficTally) {
319 if let Some(channel) = self.tally_channel.lock().as_ref() {
320 match channel.try_send(tally) {
326 Err(TrySendError::Full(_)) => {
327 warn!("TrafficController tally channel full, dropping tally");
328 self.metrics.tally_channel_overflow.inc();
329 }
333 Err(TrySendError::Closed(_)) => {
334 warn!("TrafficController tally channel closed unexpectedly");
335 }
336 Ok(_) => {}
337 }
338 } else {
339 warn!("TrafficController not yet accepting tally requests.");
340 }
341 }
342
343 pub async fn check(&self, client: &Option<IpAddr>, proxied_client: &Option<IpAddr>) -> bool {
345 let policy_config = { self.policy_config.read().await.clone() };
346
347 let allowed = match &self.acl {
348 Acl::Allowlist(allowlist) => client.is_none() || allowlist.contains(&client.unwrap()),
349 Acl::Blocklists(blocklists) => {
350 self.check_blocklists(blocklists, client, proxied_client)
351 .await
352 }
353 };
354
355 match (allowed, policy_config.dry_run) {
356 (true, _) => true,
358 (false, true) => {
360 debug!("Dry run mode: Blocked request from client {:?}", client);
361 self.record_blocked_request(client, true);
362 true
363 }
364 (false, false) => {
366 debug!("Blocked request from client {:?}", client);
367 self.record_blocked_request(client, false);
368 false
369 }
370 }
371 }
372
373 fn record_blocked_request(&self, client: &Option<IpAddr>, dry_run: bool) {
374 let dry_run_str = if dry_run { "true" } else { "false" };
375
376 let ip_label = if let Some(ip) = client {
378 let bucket = ip
379 .to_string()
380 .bytes()
381 .fold(0u8, |acc, b| acc.wrapping_add(b))
382 % 100;
383 let bucket_label = format!("bucket_{}", bucket);
384 trace!("IP {} maps to {}", ip, bucket_label);
385 bucket_label
386 } else {
387 "unknown".to_string()
388 };
389
390 self.metrics
391 .requests_blocked_at_protocol
392 .with_label_values(&[dry_run_str, &ip_label])
393 .inc();
394 }
395
396 async fn check_blocklists(
398 &self,
399 blocklists: &Blocklists,
400 client: &Option<IpAddr>,
401 proxied_client: &Option<IpAddr>,
402 ) -> bool {
403 let client_check = self.check_and_clear_blocklist(
404 client,
405 blocklists.clients.clone(),
406 &self.metrics.connection_ip_blocklist_len,
407 );
408 let proxied_client_check = self.check_and_clear_blocklist(
409 proxied_client,
410 blocklists.proxied_clients.clone(),
411 &self.metrics.proxy_ip_blocklist_len,
412 );
413 let (client_check, proxied_client_check) =
414 futures::future::join(client_check, proxied_client_check).await;
415 client_check && proxied_client_check
416 }
417
418 async fn check_and_clear_blocklist(
419 &self,
420 client: &Option<IpAddr>,
421 blocklist: Blocklist,
422 blocklist_len_gauge: &IntGauge,
423 ) -> bool {
424 let client = match client {
425 Some(client) => client,
426 None => return true,
427 };
428 let now = SystemTime::now();
429 let (should_block, should_remove) = {
432 match blocklist.get(client) {
433 Some(expiration) if now >= *expiration => (false, true),
434 None => (false, false),
435 _ => (true, false),
436 }
437 };
438 if should_remove && blocklist.remove(client).is_some() {
439 blocklist_len_gauge.dec();
440 }
441 !should_block
442 }
443}
444
445async fn run_clear_blocklists_loop(blocklists: Blocklists, metrics: Arc<TrafficControllerMetrics>) {
452 loop {
453 tokio::time::sleep(Duration::from_secs(3)).await;
454 let now = SystemTime::now();
455 blocklists.clients.retain(|_, expiration| now < *expiration);
456 blocklists
457 .proxied_clients
458 .retain(|_, expiration| now < *expiration);
459 metrics
460 .connection_ip_blocklist_len
461 .set(blocklists.clients.len() as i64);
462 metrics
463 .proxy_ip_blocklist_len
464 .set(blocklists.proxied_clients.len() as i64);
465 }
466}
467
468async fn run_tally_loop(
469 mut receiver: mpsc::Receiver<TrafficTally>,
470 policy_config: PolicyConfig,
471 spam_policy: Arc<Mutex<TrafficControlPolicy>>,
472 error_policy: Arc<Mutex<TrafficControlPolicy>>,
473 fw_config: Option<RemoteFirewallConfig>,
474 blocklists: Blocklists,
475 metrics: Arc<TrafficControllerMetrics>,
476 mut mem_drainfile_present: bool,
477) {
478 let spam_blocklists = Arc::new(blocklists.clone());
479 let error_blocklists = Arc::new(blocklists);
480 let node_fw_client = fw_config
481 .as_ref()
482 .map(|fw_config| NodeFWClient::new(fw_config.remote_fw_url.clone()));
483
484 let timeout = fw_config
485 .as_ref()
486 .map(|fw_config| fw_config.drain_timeout_secs)
487 .unwrap_or(DEFAULT_DRAIN_TIMEOUT_SECS);
488 let mut metric_timer = Instant::now();
489
490 loop {
491 tokio::select! {
492 received = receiver.recv() => {
493 match received {
494 Some(tally) => {
495 let method = tally.method.as_deref().unwrap_or("unknown");
497 metrics.tallies
498 .with_label_values(&[method])
499 .inc();
500
501 if let Err(err) = handle_spam_tally(
503 spam_policy.clone(),
504 &policy_config,
505 &node_fw_client,
506 &fw_config,
507 tally.clone(),
508 spam_blocklists.clone(),
509 metrics.clone(),
510 mem_drainfile_present,
511 )
512 .await {
513 warn!("Error handling spam tally: {}", err);
514 }
515 if let Err(err) = handle_error_tally(
516 error_policy.clone(),
517 &policy_config,
518 &node_fw_client,
519 &fw_config,
520 tally,
521 error_blocklists.clone(),
522 metrics.clone(),
523 mem_drainfile_present,
524 )
525 .await {
526 warn!("Error handling error tally: {}", err);
527 }
528 }
529 None => {
530 info!("TrafficController tally channel closed by all senders");
531 return;
532 }
533 }
534 }
535 _ = tokio::time::sleep(tokio::time::Duration::from_secs(timeout)) => {
537 if let Some(fw_config) = &fw_config {
538 error!("No traffic tallies received in {} seconds.", timeout);
539 if mem_drainfile_present {
540 continue;
541 }
542 if !fw_config.drain_path.exists() {
543 mem_drainfile_present = true;
544 warn!("Draining Node firewall.");
545 File::create(&fw_config.drain_path)
546 .expect("Failed to touch nodefw drain file");
547 metrics.deadmans_switch_enabled.set(1);
548 }
549 }
550 }
551 }
552
553 if metric_timer.elapsed() > Duration::from_secs(METRICS_INTERVAL_SECS) {
556 if let TrafficControlPolicy::FreqThreshold(ref spam_policy) = *spam_policy.lock().await
557 {
558 if let Some(highest_direct_rate) = spam_policy.highest_direct_rate() {
559 metrics
560 .highest_direct_spam_rate
561 .set(highest_direct_rate.0 as i64);
562 debug!("Recent highest direct spam rate: {:?}", highest_direct_rate);
563 }
564 if let Some(highest_proxied_rate) = spam_policy.highest_proxied_rate() {
565 metrics
566 .highest_proxied_spam_rate
567 .set(highest_proxied_rate.0 as i64);
568 debug!(
569 "Recent highest proxied spam rate: {:?}",
570 highest_proxied_rate
571 );
572 }
573 }
574 if let TrafficControlPolicy::FreqThreshold(ref error_policy) =
575 *error_policy.lock().await
576 {
577 if let Some(highest_direct_rate) = error_policy.highest_direct_rate() {
578 metrics
579 .highest_direct_error_rate
580 .set(highest_direct_rate.0 as i64);
581 debug!(
582 "Recent highest direct error rate: {:?}",
583 highest_direct_rate
584 );
585 }
586 if let Some(highest_proxied_rate) = error_policy.highest_proxied_rate() {
587 metrics
588 .highest_proxied_error_rate
589 .set(highest_proxied_rate.0 as i64);
590 debug!(
591 "Recent highest proxied error rate: {:?}",
592 highest_proxied_rate
593 );
594 }
595 }
596 metric_timer = Instant::now();
597 }
598 }
599}
600
601async fn handle_error_tally(
602 policy: Arc<Mutex<TrafficControlPolicy>>,
603 policy_config: &PolicyConfig,
604 nodefw_client: &Option<NodeFWClient>,
605 fw_config: &Option<RemoteFirewallConfig>,
606 tally: TrafficTally,
607 blocklists: Arc<Blocklists>,
608 metrics: Arc<TrafficControllerMetrics>,
609 mem_drainfile_present: bool,
610) -> Result<(), reqwest::Error> {
611 let Some((error_weight, error_type)) = tally.clone().error_info else {
612 return Ok(());
613 };
614 if !error_weight.is_sampled() {
615 return Ok(());
616 }
617 trace!(
618 "Handling error_type {:?} from client {:?}",
619 error_type, tally.direct,
620 );
621 metrics
622 .tally_error_types
623 .with_label_values(&[error_type.as_str()])
624 .inc();
625 let resp = policy.lock().await.handle_tally(tally);
626 metrics.error_tally_handled.inc();
627 if let Some(fw_config) = fw_config
628 && fw_config.delegate_error_blocking
629 && !mem_drainfile_present
630 {
631 let client = nodefw_client
632 .as_ref()
633 .expect("Expected NodeFWClient for blocklist delegation");
634 return delegate_policy_response(
635 resp,
636 policy_config,
637 client,
638 fw_config.destination_port,
639 metrics.clone(),
640 )
641 .await;
642 }
643 handle_policy_response(resp, policy_config, blocklists, metrics).await;
644 Ok(())
645}
646
647async fn handle_spam_tally(
648 policy: Arc<Mutex<TrafficControlPolicy>>,
649 policy_config: &PolicyConfig,
650 nodefw_client: &Option<NodeFWClient>,
651 fw_config: &Option<RemoteFirewallConfig>,
652 tally: TrafficTally,
653 blocklists: Arc<Blocklists>,
654 metrics: Arc<TrafficControllerMetrics>,
655 mem_drainfile_present: bool,
656) -> Result<(), reqwest::Error> {
657 if !(tally.spam_weight.is_sampled() && policy_config.spam_sample_rate.is_sampled()) {
658 return Ok(());
659 }
660 let resp = policy.lock().await.handle_tally(tally.clone());
661 metrics.tally_handled.inc();
662 if let Some(fw_config) = fw_config
663 && fw_config.delegate_spam_blocking
664 && !mem_drainfile_present
665 {
666 let client = nodefw_client
667 .as_ref()
668 .expect("Expected NodeFWClient for blocklist delegation");
669 return delegate_policy_response(
670 resp,
671 policy_config,
672 client,
673 fw_config.destination_port,
674 metrics.clone(),
675 )
676 .await;
677 }
678 handle_policy_response(resp, policy_config, blocklists, metrics).await;
679 Ok(())
680}
681
682async fn handle_policy_response(
683 response: PolicyResponse,
684 policy_config: &PolicyConfig,
685 blocklists: Arc<Blocklists>,
686 metrics: Arc<TrafficControllerMetrics>,
687) {
688 let PolicyResponse {
689 block_client,
690 block_proxied_client,
691 } = response;
692 let PolicyConfig {
693 connection_blocklist_ttl_sec,
694 proxy_blocklist_ttl_sec,
695 ..
696 } = policy_config;
697 if let Some(client) = block_client
698 && blocklists
699 .clients
700 .insert(
701 client,
702 SystemTime::now() + Duration::from_secs(*connection_blocklist_ttl_sec),
703 )
704 .is_none()
705 {
706 debug!("Adding client {:?} to blocklist", client);
708 metrics.connection_ip_blocklist_len.inc();
709 }
710 if let Some(client) = block_proxied_client
711 && blocklists
712 .proxied_clients
713 .insert(
714 client,
715 SystemTime::now() + Duration::from_secs(*proxy_blocklist_ttl_sec),
716 )
717 .is_none()
718 {
719 debug!("Adding proxied client {:?} to blocklist", client);
721 metrics.proxy_ip_blocklist_len.inc();
722 }
723}
724
725async fn delegate_policy_response(
726 response: PolicyResponse,
727 policy_config: &PolicyConfig,
728 node_fw_client: &NodeFWClient,
729 destination_port: u16,
730 metrics: Arc<TrafficControllerMetrics>,
731) -> Result<(), reqwest::Error> {
732 let PolicyResponse {
733 block_client,
734 block_proxied_client,
735 } = response;
736 let PolicyConfig {
737 connection_blocklist_ttl_sec,
738 proxy_blocklist_ttl_sec,
739 ..
740 } = policy_config;
741 let mut addresses = vec![];
742 if let Some(client_id) = block_client {
743 debug!("Delegating client blocking to firewall");
744 addresses.push(BlockAddress {
745 source_address: client_id.to_string(),
746 destination_port,
747 ttl: *connection_blocklist_ttl_sec,
748 });
749 }
750 if let Some(ip) = block_proxied_client {
751 debug!("Delegating proxied client blocking to firewall");
752 addresses.push(BlockAddress {
753 source_address: ip.to_string(),
754 destination_port,
755 ttl: *proxy_blocklist_ttl_sec,
756 });
757 }
758 if addresses.is_empty() {
759 Ok(())
760 } else {
761 metrics
762 .blocks_delegated_to_firewall
763 .inc_by(addresses.len() as u64);
764 match node_fw_client
765 .block_addresses(BlockAddresses { addresses })
766 .await
767 {
768 Ok(()) => Ok(()),
769 Err(err) => {
770 metrics.firewall_delegation_request_fail.inc();
771 Err(err)
772 }
773 }
774 }
775}
776
/// Results of a traffic simulation, either for one simulated client or aggregated over
/// many with `+`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TrafficSimMetrics {
    /// Number of requests issued.
    pub num_requests: u64,
    /// Number of requests rejected by the controller.
    pub num_blocked: u64,
    /// Time from start to the first block. When aggregated, this is the SUM over clients
    /// that were blocked at least once.
    pub time_to_first_block: Option<Duration>,
    /// Time from start to the first block. When aggregated, this is the MINIMUM over
    /// clients.
    pub abs_time_to_first_block: Option<Duration>,
    /// Total wall-clock time spent blocked.
    pub total_time_blocked: Duration,
    /// Number of distinct blocked periods observed (transitions allowed -> blocked).
    pub num_blocklist_adds: u64,
}

impl Add for TrafficSimMetrics {
    type Output = Self;

    /// Combines two results: counters and blocked time are summed,
    /// `time_to_first_block` is summed, and `abs_time_to_first_block` takes the minimum.
    /// A missing (`None`) value on one side defers to the other side.
    fn add(self, other: Self) -> Self {
        fn merge(
            a: Option<Duration>,
            b: Option<Duration>,
            combine: impl FnOnce(Duration, Duration) -> Duration,
        ) -> Option<Duration> {
            match (a, b) {
                (Some(a), Some(b)) => Some(combine(a, b)),
                (a, b) => a.or(b),
            }
        }

        Self {
            num_requests: self.num_requests + other.num_requests,
            num_blocked: self.num_blocked + other.num_blocked,
            time_to_first_block: merge(
                self.time_to_first_block,
                other.time_to_first_block,
                |a, b| a + b,
            ),
            abs_time_to_first_block: merge(
                self.abs_time_to_first_block,
                other.abs_time_to_first_block,
                |a, b| a.min(b),
            ),
            total_time_blocked: self.total_time_blocked + other.total_time_blocked,
            num_blocklist_adds: self.num_blocklist_adds + other.num_blocklist_adds,
        }
    }
}
827
828pub struct TrafficSim {
829 pub traffic_controller: TrafficController,
830}
831
832impl TrafficSim {
833 pub async fn run(
834 policy: PolicyConfig,
835 num_clients: u8,
836 per_client_tps: usize,
837 duration: Duration,
838 report: bool,
839 ) -> TrafficSimMetrics {
840 assert!(
841 per_client_tps <= 10_000,
842 "per_client_tps must be less than 10,000. For higher values, increase num_clients"
843 );
844 assert!(num_clients < 20, "num_clients must be greater than 0");
845 assert!(num_clients > 0);
846 assert!(per_client_tps > 0);
847 assert!(duration.as_secs() > 0);
848
849 let controller = TrafficController::init_for_test(policy.clone(), None).await;
850 let tasks = (0..num_clients).map(|task_num| {
851 tokio::spawn(Self::run_single_client(
852 controller.clone(),
853 duration,
854 task_num,
855 per_client_tps,
856 ))
857 });
858
859 let status_task = if report {
860 Some(tokio::spawn(async move {
861 println!(
862 "Running naive traffic simulation for {} seconds",
863 duration.as_secs()
864 );
865 println!("Policy: {:#?}", policy);
866 println!("Num clients: {}", num_clients);
867 println!("TPS per client: {}", per_client_tps);
868 println!(
869 "Target total TPS: {}",
870 per_client_tps * num_clients as usize
871 );
872 println!("\n");
873 for _ in 0..duration.as_secs() {
874 print!(".");
875 tokio::time::sleep(Duration::from_secs(1)).await;
876 }
877 println!();
878 }))
879 } else {
880 None
881 };
882
883 let metrics = futures::future::join_all(tasks).await.into_iter().fold(
884 TrafficSimMetrics::default(),
885 |acc, run_client_ret| {
886 if let Ok(metrics) = run_client_ret {
887 acc + metrics
888 } else {
889 error!(
890 "Error running traffic sim client: {:?}",
891 run_client_ret.err()
892 );
893 acc
894 }
895 },
896 );
897
898 if report {
899 status_task.unwrap().await.unwrap();
900 Self::report_metrics(metrics.clone(), duration, per_client_tps, num_clients);
901 }
902 metrics
903 }
904
905 async fn run_single_client(
906 controller: TrafficController,
907 duration: Duration,
908 task_num: u8,
909 per_client_tps: usize,
910 ) -> TrafficSimMetrics {
911 let sleep_time = Duration::from_micros(rand::thread_rng().gen_range(0..100));
915 tokio::time::sleep(sleep_time).await;
916
917 let mut num_requests = 0;
919 let mut num_blocked = 0;
920 let mut time_to_first_block = None;
921 let mut total_time_blocked = Duration::from_micros(0);
922 let mut num_blocklist_adds = 0;
923 let mut currently_blocked = false;
925 let mut time_blocked_start = Instant::now();
926 let start = Instant::now();
927
928 while start.elapsed() < duration {
929 let client = Some(IpAddr::V4(Ipv4Addr::new(127, 0, 0, task_num)));
930 let allowed = controller.check(&client, &None).await;
931 if allowed {
932 if currently_blocked {
933 total_time_blocked += time_blocked_start.elapsed();
934 currently_blocked = false;
935 }
936 controller.tally(TrafficTally::new(
937 client,
938 None,
940 None,
942 Weight::one(),
943 ));
944 } else {
945 if !currently_blocked {
946 time_blocked_start = Instant::now();
947 currently_blocked = true;
948 num_blocklist_adds += 1;
949 if time_to_first_block.is_none() {
950 time_to_first_block = Some(start.elapsed());
951 }
952 }
953 num_blocked += 1;
954 }
955 num_requests += 1;
956 tokio::time::sleep(Duration::from_micros(1_000_000 / per_client_tps as u64)).await;
957 }
958 TrafficSimMetrics {
959 num_requests,
960 num_blocked,
961 time_to_first_block,
962 abs_time_to_first_block: time_to_first_block,
963 total_time_blocked,
964 num_blocklist_adds,
965 }
966 }
967
968 fn report_metrics(
969 metrics: TrafficSimMetrics,
970 duration: Duration,
971 per_client_tps: usize,
972 num_clients: u8,
973 ) {
974 println!("TrafficSim metrics:");
975 println!("-------------------");
976 println!(
978 "Num expected requests: {}",
979 per_client_tps * (num_clients as usize) * duration.as_secs() as usize
980 );
981 println!("Num actual requests: {}", metrics.num_requests);
982 println!("Num blocked requests: {}", metrics.num_blocked);
986 println!(
989 "Num times added to blocklist: {}",
990 metrics.num_blocklist_adds
991 );
992 let avg_first_block_time = metrics
995 .time_to_first_block
996 .map(|ttf| ttf / num_clients as u32);
997 println!("Average time to first block: {:?}", avg_first_block_time);
998 println!(
1001 "Abolute time to first block (across all clients): {:?}",
1002 metrics.abs_time_to_first_block
1003 );
1004 let avg_time_blocked = if metrics.num_blocklist_adds > 0 {
1006 metrics.total_time_blocked.as_millis() as u64 / metrics.num_blocklist_adds
1007 } else {
1008 0
1009 };
1010 println!(
1011 "Average time blocked (ttl): {:?}",
1012 Duration::from_millis(avg_time_blocked)
1013 );
1014 }
1015}
1016
1017pub fn parse_ip(ip: &str) -> Option<IpAddr> {
1018 ip.parse::<IpAddr>().ok().or_else(|| {
1019 ip.parse::<SocketAddr>()
1020 .ok()
1021 .map(|socket_addr| socket_addr.ip())
1022 .or_else(|| {
1023 error!("Failed to parse value of {:?} to ip address or socket.", ip,);
1024 None
1025 })
1026 })
1027}