1pub mod metrics;
5pub mod nodefw_client;
6pub mod nodefw_test_server;
7pub mod policies;
8
9use dashmap::DashMap;
10use fs::File;
11use mysten_common::fatal;
12use prometheus::IntGauge;
13use std::fs;
14use std::net::{IpAddr, Ipv4Addr, SocketAddr};
15use std::ops::Add;
16use std::sync::Arc;
17use sui_types::error::{SuiError, SuiErrorKind};
18
19use self::metrics::TrafficControllerMetrics;
20use crate::traffic_controller::nodefw_client::{BlockAddress, BlockAddresses, NodeFWClient};
21use crate::traffic_controller::policies::{
22 Policy, PolicyResponse, TrafficControlPolicy, TrafficTally,
23};
24use mysten_metrics::spawn_monitored_task;
25use parking_lot::Mutex as ParkingLotMutex;
26use rand::Rng;
27use std::fmt::Debug;
28use std::time::{Duration, Instant, SystemTime};
29use sui_types::traffic_control::{
30 PolicyConfig, PolicyType, RemoteFirewallConfig, TrafficControlReconfigParams, Weight,
31};
32use tokio::sync::mpsc::error::TrySendError;
33use tokio::sync::{Mutex, RwLock, mpsc};
34use tracing::{debug, error, info, trace, warn};
35
/// Minimum number of seconds between refreshes of the "highest rate" gauges
/// performed by the tally loop.
pub const METRICS_INTERVAL_SECS: u64 = 2;
/// Tally-silence timeout (seconds) used by the tally loop when no
/// `RemoteFirewallConfig` supplies `drain_timeout_secs`.
pub const DEFAULT_DRAIN_TIMEOUT_SECS: u64 = 300;
38
/// Blocked IP address -> wall-clock time at which its block expires.
type Blocklist = Arc<DashMap<IpAddr, SystemTime>>;

/// The two blocklists used when the controller runs in blocklist mode.
///
/// `clients` is keyed by the `client` IP passed to `check`, and `proxied_clients`
/// by the `proxied_client` IP. Both maps are `Arc`-shared between the controller,
/// the tally loop (which inserts entries) and the periodic cleanup loop (which
/// drops expired entries).
#[derive(Clone)]
pub struct Blocklists {
    clients: Blocklist,
    proxied_clients: Blocklist,
}
46
/// Access-control mode of a `TrafficController`, chosen once at `init`.
#[derive(Clone)]
pub enum Acl {
    /// Dynamic mode: IPs are blocked at runtime by the spam/error policies until
    /// their blocklist entries expire.
    Blocklists(Blocklists),
    /// Static mode: only listed IPs are admitted (a request with no client IP is
    /// also admitted). No policies or background tasks exist in this mode.
    Allowlist(Vec<IpAddr>),
}
56
/// Tracks per-IP traffic and decides whether incoming requests are admitted.
///
/// Clones share the same tally channel, metrics, policies and policy config
/// (all behind `Arc`); in blocklist mode the blocklists are `Arc`-shared as well.
#[derive(Clone)]
pub struct TrafficController {
    /// Sender into the background tally loop; `None` until `spawn` opens it
    /// (which never happens in allowlist mode).
    tally_channel: Arc<ParkingLotMutex<Option<mpsc::Sender<TrafficTally>>>>,
    /// Allowlist or blocklists.
    acl: Acl,
    metrics: Arc<TrafficControllerMetrics>,
    /// Present only in blocklist mode.
    spam_policy: Option<Arc<Mutex<TrafficControlPolicy>>>,
    /// Present only in blocklist mode.
    error_policy: Option<Arc<Mutex<TrafficControlPolicy>>>,
    /// Live config; `dry_run` can be changed at runtime by `admin_reconfigure`.
    policy_config: Arc<RwLock<PolicyConfig>>,
    /// Remote node-firewall settings; when set, block decisions may be delegated.
    fw_config: Option<RemoteFirewallConfig>,
}
67
68impl Debug for TrafficController {
69 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70 f.debug_struct("TrafficController")
76 .field(
77 "connection_ip_blocklist_len",
78 &self.metrics.connection_ip_blocklist_len.get(),
79 )
80 .field(
81 "proxy_ip_blocklist_len",
82 &self.metrics.proxy_ip_blocklist_len.get(),
83 )
84 .finish()
85 }
86}
87
88impl TrafficController {
89 pub async fn init(
90 policy_config: PolicyConfig,
91 metrics: Arc<TrafficControllerMetrics>,
92 fw_config: Option<RemoteFirewallConfig>,
93 ) -> Self {
94 metrics.dry_run_enabled.set(policy_config.dry_run as i64);
95 match policy_config.allow_list.clone() {
96 Some(allow_list) => {
97 let allowlist = allow_list
98 .into_iter()
99 .map(|ip_str| {
100 parse_ip(&ip_str).unwrap_or_else(|| {
101 fatal!("Failed to parse allowlist IP address: {:?}", ip_str)
102 })
103 })
104 .collect();
105 Self {
106 tally_channel: Arc::new(ParkingLotMutex::new(None)),
107 acl: Acl::Allowlist(allowlist),
108 metrics,
109 policy_config: Arc::new(RwLock::new(policy_config)),
110 fw_config,
111 spam_policy: None,
112 error_policy: None,
113 }
114 }
115 None => {
116 let spam_policy = Arc::new(Mutex::new(
117 TrafficControlPolicy::from_spam_config(policy_config.clone()).await,
118 ));
119 let error_policy = Arc::new(Mutex::new(
120 TrafficControlPolicy::from_error_config(policy_config.clone()).await,
121 ));
122 let this = Self {
123 tally_channel: Arc::new(ParkingLotMutex::new(None)),
124 acl: Acl::Blocklists(Blocklists {
125 clients: Arc::new(DashMap::new()),
126 proxied_clients: Arc::new(DashMap::new()),
127 }),
128 metrics,
129 policy_config: Arc::new(RwLock::new(policy_config)),
130 fw_config,
131 spam_policy: Some(spam_policy),
132 error_policy: Some(error_policy),
133 };
134 this.spawn().await;
135 this
136 }
137 }
138 }
139
140 pub async fn init_for_test(
141 policy_config: PolicyConfig,
142 fw_config: Option<RemoteFirewallConfig>,
143 ) -> Self {
144 let metrics = Arc::new(TrafficControllerMetrics::new(&prometheus::Registry::new()));
145 Self::init(policy_config, metrics, fw_config).await
146 }
147
148 async fn spawn(&self) {
149 let policy_config = { self.policy_config.read().await.clone() };
150 Self::set_policy_config_metrics(&policy_config, self.metrics.clone());
151 let (tx, rx) = mpsc::channel(policy_config.channel_capacity);
152 let mem_drainfile_present = self
157 .fw_config
158 .as_ref()
159 .map(|config| config.drain_path.exists())
160 .unwrap_or(false);
161 self.metrics
162 .deadmans_switch_enabled
163 .set(mem_drainfile_present as i64);
164 let blocklists = match self.acl.clone() {
165 Acl::Blocklists(blocklists) => blocklists,
166 Acl::Allowlist(_) => fatal!("Allowlist ACL should not exist on spawn"),
167 };
168 let tally_loop_blocklists = blocklists.clone();
169 let clear_loop_blocklists = blocklists.clone();
170 let tally_loop_metrics = self.metrics.clone();
171 let clear_loop_metrics = self.metrics.clone();
172 let tally_loop_policy_config = policy_config.clone();
173 let tally_loop_fw_config = self.fw_config.clone();
174
175 let spam_policy = self
176 .spam_policy
177 .clone()
178 .expect("spam policy should exist on spawn");
179 let error_policy = self
180 .error_policy
181 .clone()
182 .expect("error policy should exist on spawn");
183 let spam_policy_clone = spam_policy.clone();
184 let error_policy_clone = error_policy.clone();
185
186 spawn_monitored_task!(run_tally_loop(
187 rx,
188 tally_loop_policy_config,
189 spam_policy_clone,
190 error_policy_clone,
191 tally_loop_fw_config,
192 tally_loop_blocklists,
193 tally_loop_metrics,
194 mem_drainfile_present,
195 ));
196 spawn_monitored_task!(run_clear_blocklists_loop(
197 clear_loop_blocklists,
198 clear_loop_metrics,
199 ));
200 self.open_tally_channel(tx);
201 }
202
203 pub async fn get_current_state(&self) -> TrafficControlReconfigParams {
204 let mut result = TrafficControlReconfigParams {
205 error_threshold: None,
206 spam_threshold: None,
207 dry_run: None,
208 };
209
210 if let Some(error_policy) = self.error_policy.as_ref()
211 && let TrafficControlPolicy::FreqThreshold(ref policy) = *error_policy.lock().await
212 {
213 result.error_threshold = Some(policy.client_threshold);
214 }
215
216 if let Some(spam_policy) = self.spam_policy.as_ref()
217 && let TrafficControlPolicy::FreqThreshold(ref policy) = *spam_policy.lock().await
218 {
219 result.spam_threshold = Some(policy.client_threshold);
220 }
221
222 result.dry_run = Some(self.policy_config.read().await.dry_run);
223 result
224 }
225
226 pub async fn admin_reconfigure(
227 &self,
228 params: TrafficControlReconfigParams,
229 ) -> Result<TrafficControlReconfigParams, SuiError> {
230 let TrafficControlReconfigParams {
231 error_threshold,
232 spam_threshold,
233 dry_run,
234 } = params;
235 if let Some(error_threshold) = error_threshold {
236 self.metrics
237 .error_client_threshold
238 .set(error_threshold as i64);
239 Self::update_policy_threshold(
240 self.error_policy.as_ref().unwrap(),
241 error_threshold,
242 dry_run,
243 )
244 .await?;
245 }
246 if let Some(spam_threshold) = spam_threshold {
247 self.metrics
248 .spam_client_threshold
249 .set(spam_threshold as i64);
250 Self::update_policy_threshold(
251 self.spam_policy.as_ref().unwrap(),
252 spam_threshold,
253 dry_run,
254 )
255 .await?;
256 }
257 if let Some(dry_run) = dry_run {
258 self.metrics.dry_run_enabled.set(dry_run as i64);
259 self.policy_config.write().await.dry_run = dry_run;
260 }
261
262 Ok(self.get_current_state().await)
263 }
264
265 async fn update_policy_threshold(
266 policy: &Arc<Mutex<TrafficControlPolicy>>,
267 threshold: u64,
268 dry_run: Option<bool>,
269 ) -> Result<(), SuiError> {
270 match *policy.lock().await {
271 TrafficControlPolicy::FreqThreshold(ref mut policy) => {
272 policy.client_threshold = threshold;
273 if let Some(dry_run) = dry_run {
274 policy.config.dry_run = dry_run;
275 }
276 Ok(())
277 }
278 TrafficControlPolicy::TestNConnIP(ref mut policy) => {
279 policy.threshold = threshold;
280 if let Some(dry_run) = dry_run {
281 policy.config.dry_run = dry_run;
282 }
283 Ok(())
284 }
285 _ => Err(SuiErrorKind::InvalidAdminRequest(
286 "Unsupported prior policy type during traffic control reconfiguration".to_string(),
287 )
288 .into()),
289 }
290 }
291
292 fn open_tally_channel(&self, tx: mpsc::Sender<TrafficTally>) {
293 self.tally_channel.lock().replace(tx);
294 }
295
296 fn set_policy_config_metrics(
297 policy_config: &PolicyConfig,
298 metrics: Arc<TrafficControllerMetrics>,
299 ) {
300 if let PolicyType::FreqThreshold(config) = &policy_config.spam_policy_type {
301 metrics
302 .spam_client_threshold
303 .set(config.client_threshold as i64);
304 metrics
305 .spam_proxied_client_threshold
306 .set(config.proxied_client_threshold as i64);
307 }
308 if let PolicyType::FreqThreshold(config) = &policy_config.error_policy_type {
309 metrics
310 .error_client_threshold
311 .set(config.client_threshold as i64);
312 metrics
313 .error_proxied_client_threshold
314 .set(config.proxied_client_threshold as i64);
315 }
316 }
317
318 pub fn tally(&self, tally: TrafficTally) {
319 if let Some(channel) = self.tally_channel.lock().as_ref() {
320 match channel.try_send(tally) {
326 Err(TrySendError::Full(_)) => {
327 warn!("TrafficController tally channel full, dropping tally");
328 self.metrics.tally_channel_overflow.inc();
329 }
333 Err(TrySendError::Closed(_)) => {
334 warn!("TrafficController tally channel closed unexpectedly");
335 }
336 Ok(_) => {}
337 }
338 } else {
339 warn!("TrafficController not yet accepting tally requests.");
340 }
341 }
342
343 pub async fn check(&self, client: &Option<IpAddr>, proxied_client: &Option<IpAddr>) -> bool {
345 let policy_config = { self.policy_config.read().await.clone() };
346 let check_with_dry_run_maybe = |allowed| -> bool {
347 match (allowed, policy_config.dry_run) {
348 (true, _) => true,
350 (false, true) => {
352 debug!("Dry run mode: Blocked request from client {:?}", client);
353 self.metrics.num_dry_run_blocked_requests.inc();
354 true
355 }
356 (false, false) => {
358 debug!("Blocked request from client {:?}", client);
359 self.metrics.requests_blocked_at_protocol.inc();
360 false
361 }
362 }
363 };
364
365 match &self.acl {
366 Acl::Allowlist(allowlist) => {
367 let allowed = client.is_none() || allowlist.contains(&client.unwrap());
368 check_with_dry_run_maybe(allowed)
369 }
370 Acl::Blocklists(blocklists) => {
371 let allowed = self
372 .check_blocklists(blocklists, client, proxied_client)
373 .await;
374 check_with_dry_run_maybe(allowed)
375 }
376 }
377 }
378
379 async fn check_blocklists(
381 &self,
382 blocklists: &Blocklists,
383 client: &Option<IpAddr>,
384 proxied_client: &Option<IpAddr>,
385 ) -> bool {
386 let client_check = self.check_and_clear_blocklist(
387 client,
388 blocklists.clients.clone(),
389 &self.metrics.connection_ip_blocklist_len,
390 );
391 let proxied_client_check = self.check_and_clear_blocklist(
392 proxied_client,
393 blocklists.proxied_clients.clone(),
394 &self.metrics.proxy_ip_blocklist_len,
395 );
396 let (client_check, proxied_client_check) =
397 futures::future::join(client_check, proxied_client_check).await;
398 client_check && proxied_client_check
399 }
400
401 async fn check_and_clear_blocklist(
402 &self,
403 client: &Option<IpAddr>,
404 blocklist: Blocklist,
405 blocklist_len_gauge: &IntGauge,
406 ) -> bool {
407 let client = match client {
408 Some(client) => client,
409 None => return true,
410 };
411 let now = SystemTime::now();
412 let (should_block, should_remove) = {
415 match blocklist.get(client) {
416 Some(expiration) if now >= *expiration => (false, true),
417 None => (false, false),
418 _ => (true, false),
419 }
420 };
421 if should_remove && blocklist.remove(client).is_some() {
422 blocklist_len_gauge.dec();
423 }
424 !should_block
425 }
426}
427
428async fn run_clear_blocklists_loop(blocklists: Blocklists, metrics: Arc<TrafficControllerMetrics>) {
435 loop {
436 tokio::time::sleep(Duration::from_secs(3)).await;
437 let now = SystemTime::now();
438 blocklists.clients.retain(|_, expiration| now < *expiration);
439 blocklists
440 .proxied_clients
441 .retain(|_, expiration| now < *expiration);
442 metrics
443 .connection_ip_blocklist_len
444 .set(blocklists.clients.len() as i64);
445 metrics
446 .proxy_ip_blocklist_len
447 .set(blocklists.proxied_clients.len() as i64);
448 }
449}
450
451async fn run_tally_loop(
452 mut receiver: mpsc::Receiver<TrafficTally>,
453 policy_config: PolicyConfig,
454 spam_policy: Arc<Mutex<TrafficControlPolicy>>,
455 error_policy: Arc<Mutex<TrafficControlPolicy>>,
456 fw_config: Option<RemoteFirewallConfig>,
457 blocklists: Blocklists,
458 metrics: Arc<TrafficControllerMetrics>,
459 mut mem_drainfile_present: bool,
460) {
461 let spam_blocklists = Arc::new(blocklists.clone());
462 let error_blocklists = Arc::new(blocklists);
463 let node_fw_client = fw_config
464 .as_ref()
465 .map(|fw_config| NodeFWClient::new(fw_config.remote_fw_url.clone()));
466
467 let timeout = fw_config
468 .as_ref()
469 .map(|fw_config| fw_config.drain_timeout_secs)
470 .unwrap_or(DEFAULT_DRAIN_TIMEOUT_SECS);
471 let mut metric_timer = Instant::now();
472
473 loop {
474 tokio::select! {
475 received = receiver.recv() => {
476 metrics.tallies.inc();
477 match received {
478 Some(tally) => {
479 if let Err(err) = handle_spam_tally(
481 spam_policy.clone(),
482 &policy_config,
483 &node_fw_client,
484 &fw_config,
485 tally.clone(),
486 spam_blocklists.clone(),
487 metrics.clone(),
488 mem_drainfile_present,
489 )
490 .await {
491 warn!("Error handling spam tally: {}", err);
492 }
493 if let Err(err) = handle_error_tally(
494 error_policy.clone(),
495 &policy_config,
496 &node_fw_client,
497 &fw_config,
498 tally,
499 error_blocklists.clone(),
500 metrics.clone(),
501 mem_drainfile_present,
502 )
503 .await {
504 warn!("Error handling error tally: {}", err);
505 }
506 }
507 None => {
508 info!("TrafficController tally channel closed by all senders");
509 return;
510 }
511 }
512 }
513 _ = tokio::time::sleep(tokio::time::Duration::from_secs(timeout)) => {
515 if let Some(fw_config) = &fw_config {
516 error!("No traffic tallies received in {} seconds.", timeout);
517 if mem_drainfile_present {
518 continue;
519 }
520 if !fw_config.drain_path.exists() {
521 mem_drainfile_present = true;
522 warn!("Draining Node firewall.");
523 File::create(&fw_config.drain_path)
524 .expect("Failed to touch nodefw drain file");
525 metrics.deadmans_switch_enabled.set(1);
526 }
527 }
528 }
529 }
530
531 if metric_timer.elapsed() > Duration::from_secs(METRICS_INTERVAL_SECS) {
534 if let TrafficControlPolicy::FreqThreshold(ref spam_policy) = *spam_policy.lock().await
535 {
536 if let Some(highest_direct_rate) = spam_policy.highest_direct_rate() {
537 metrics
538 .highest_direct_spam_rate
539 .set(highest_direct_rate.0 as i64);
540 debug!("Recent highest direct spam rate: {:?}", highest_direct_rate);
541 }
542 if let Some(highest_proxied_rate) = spam_policy.highest_proxied_rate() {
543 metrics
544 .highest_proxied_spam_rate
545 .set(highest_proxied_rate.0 as i64);
546 debug!(
547 "Recent highest proxied spam rate: {:?}",
548 highest_proxied_rate
549 );
550 }
551 }
552 if let TrafficControlPolicy::FreqThreshold(ref error_policy) =
553 *error_policy.lock().await
554 {
555 if let Some(highest_direct_rate) = error_policy.highest_direct_rate() {
556 metrics
557 .highest_direct_error_rate
558 .set(highest_direct_rate.0 as i64);
559 debug!(
560 "Recent highest direct error rate: {:?}",
561 highest_direct_rate
562 );
563 }
564 if let Some(highest_proxied_rate) = error_policy.highest_proxied_rate() {
565 metrics
566 .highest_proxied_error_rate
567 .set(highest_proxied_rate.0 as i64);
568 debug!(
569 "Recent highest proxied error rate: {:?}",
570 highest_proxied_rate
571 );
572 }
573 }
574 metric_timer = Instant::now();
575 }
576 }
577}
578
579async fn handle_error_tally(
580 policy: Arc<Mutex<TrafficControlPolicy>>,
581 policy_config: &PolicyConfig,
582 nodefw_client: &Option<NodeFWClient>,
583 fw_config: &Option<RemoteFirewallConfig>,
584 tally: TrafficTally,
585 blocklists: Arc<Blocklists>,
586 metrics: Arc<TrafficControllerMetrics>,
587 mem_drainfile_present: bool,
588) -> Result<(), reqwest::Error> {
589 let Some((error_weight, error_type)) = tally.clone().error_info else {
590 return Ok(());
591 };
592 if !error_weight.is_sampled() {
593 return Ok(());
594 }
595 trace!(
596 "Handling error_type {:?} from client {:?}",
597 error_type, tally.direct,
598 );
599 metrics
600 .tally_error_types
601 .with_label_values(&[error_type.as_str()])
602 .inc();
603 let resp = policy.lock().await.handle_tally(tally);
604 metrics.error_tally_handled.inc();
605 if let Some(fw_config) = fw_config
606 && fw_config.delegate_error_blocking
607 && !mem_drainfile_present
608 {
609 let client = nodefw_client
610 .as_ref()
611 .expect("Expected NodeFWClient for blocklist delegation");
612 return delegate_policy_response(
613 resp,
614 policy_config,
615 client,
616 fw_config.destination_port,
617 metrics.clone(),
618 )
619 .await;
620 }
621 handle_policy_response(resp, policy_config, blocklists, metrics).await;
622 Ok(())
623}
624
625async fn handle_spam_tally(
626 policy: Arc<Mutex<TrafficControlPolicy>>,
627 policy_config: &PolicyConfig,
628 nodefw_client: &Option<NodeFWClient>,
629 fw_config: &Option<RemoteFirewallConfig>,
630 tally: TrafficTally,
631 blocklists: Arc<Blocklists>,
632 metrics: Arc<TrafficControllerMetrics>,
633 mem_drainfile_present: bool,
634) -> Result<(), reqwest::Error> {
635 if !(tally.spam_weight.is_sampled() && policy_config.spam_sample_rate.is_sampled()) {
636 return Ok(());
637 }
638 let resp = policy.lock().await.handle_tally(tally.clone());
639 metrics.tally_handled.inc();
640 if let Some(fw_config) = fw_config
641 && fw_config.delegate_spam_blocking
642 && !mem_drainfile_present
643 {
644 let client = nodefw_client
645 .as_ref()
646 .expect("Expected NodeFWClient for blocklist delegation");
647 return delegate_policy_response(
648 resp,
649 policy_config,
650 client,
651 fw_config.destination_port,
652 metrics.clone(),
653 )
654 .await;
655 }
656 handle_policy_response(resp, policy_config, blocklists, metrics).await;
657 Ok(())
658}
659
660async fn handle_policy_response(
661 response: PolicyResponse,
662 policy_config: &PolicyConfig,
663 blocklists: Arc<Blocklists>,
664 metrics: Arc<TrafficControllerMetrics>,
665) {
666 let PolicyResponse {
667 block_client,
668 block_proxied_client,
669 } = response;
670 let PolicyConfig {
671 connection_blocklist_ttl_sec,
672 proxy_blocklist_ttl_sec,
673 ..
674 } = policy_config;
675 if let Some(client) = block_client
676 && blocklists
677 .clients
678 .insert(
679 client,
680 SystemTime::now() + Duration::from_secs(*connection_blocklist_ttl_sec),
681 )
682 .is_none()
683 {
684 debug!("Adding client {:?} to blocklist", client);
686 metrics.connection_ip_blocklist_len.inc();
687 }
688 if let Some(client) = block_proxied_client
689 && blocklists
690 .proxied_clients
691 .insert(
692 client,
693 SystemTime::now() + Duration::from_secs(*proxy_blocklist_ttl_sec),
694 )
695 .is_none()
696 {
697 debug!("Adding proxied client {:?} to blocklist", client);
699 metrics.proxy_ip_blocklist_len.inc();
700 }
701}
702
703async fn delegate_policy_response(
704 response: PolicyResponse,
705 policy_config: &PolicyConfig,
706 node_fw_client: &NodeFWClient,
707 destination_port: u16,
708 metrics: Arc<TrafficControllerMetrics>,
709) -> Result<(), reqwest::Error> {
710 let PolicyResponse {
711 block_client,
712 block_proxied_client,
713 } = response;
714 let PolicyConfig {
715 connection_blocklist_ttl_sec,
716 proxy_blocklist_ttl_sec,
717 ..
718 } = policy_config;
719 let mut addresses = vec![];
720 if let Some(client_id) = block_client {
721 debug!("Delegating client blocking to firewall");
722 addresses.push(BlockAddress {
723 source_address: client_id.to_string(),
724 destination_port,
725 ttl: *connection_blocklist_ttl_sec,
726 });
727 }
728 if let Some(ip) = block_proxied_client {
729 debug!("Delegating proxied client blocking to firewall");
730 addresses.push(BlockAddress {
731 source_address: ip.to_string(),
732 destination_port,
733 ttl: *proxy_blocklist_ttl_sec,
734 });
735 }
736 if addresses.is_empty() {
737 Ok(())
738 } else {
739 metrics
740 .blocks_delegated_to_firewall
741 .inc_by(addresses.len() as u64);
742 match node_fw_client
743 .block_addresses(BlockAddresses { addresses })
744 .await
745 {
746 Ok(()) => Ok(()),
747 Err(err) => {
748 metrics.firewall_delegation_request_fail.inc();
749 Err(err)
750 }
751 }
752 }
753}
754
/// Aggregated results of a `TrafficSim` run, or of a single simulated client.
///
/// Instances are combined with `+`:
/// * counters and `total_time_blocked` are summed;
/// * `time_to_first_block` is summed, so it can later be averaged over clients;
/// * `abs_time_to_first_block` keeps the earliest value.
///
/// `Default` is derived: all counters are zero, both first-block times are `None`,
/// and `total_time_blocked` is `Duration::ZERO`.
#[derive(Debug, Clone, Default)]
pub struct TrafficSimMetrics {
    /// Number of `check` calls issued.
    pub num_requests: u64,
    /// Number of `check` calls that were rejected.
    pub num_blocked: u64,
    /// Time from start until the first rejection (summed across clients by `+`).
    pub time_to_first_block: Option<Duration>,
    /// Earliest time-to-first-block across all combined clients.
    pub abs_time_to_first_block: Option<Duration>,
    /// Total time spent in a blocked state.
    pub total_time_blocked: Duration,
    /// Number of transitions from allowed to blocked.
    pub num_blocklist_adds: u64,
}

impl Add for TrafficSimMetrics {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            num_requests: self.num_requests + other.num_requests,
            num_blocked: self.num_blocked + other.num_blocked,
            time_to_first_block: match (self.time_to_first_block, other.time_to_first_block) {
                (Some(a), Some(b)) => Some(a + b),
                (a, b) => a.or(b),
            },
            abs_time_to_first_block: match (
                self.abs_time_to_first_block,
                other.abs_time_to_first_block,
            ) {
                (Some(a), Some(b)) => Some(a.min(b)),
                (a, b) => a.or(b),
            },
            total_time_blocked: self.total_time_blocked + other.total_time_blocked,
            num_blocklist_adds: self.num_blocklist_adds + other.num_blocklist_adds,
        }
    }
}
805
/// Simple load-simulation harness for a `TrafficController`; see `TrafficSim::run`.
///
/// NOTE(review): `run` builds its own controller and does not read the
/// `traffic_controller` field.
pub struct TrafficSim {
    pub traffic_controller: TrafficController,
}
809
810impl TrafficSim {
811 pub async fn run(
812 policy: PolicyConfig,
813 num_clients: u8,
814 per_client_tps: usize,
815 duration: Duration,
816 report: bool,
817 ) -> TrafficSimMetrics {
818 assert!(
819 per_client_tps <= 10_000,
820 "per_client_tps must be less than 10,000. For higher values, increase num_clients"
821 );
822 assert!(num_clients < 20, "num_clients must be greater than 0");
823 assert!(num_clients > 0);
824 assert!(per_client_tps > 0);
825 assert!(duration.as_secs() > 0);
826
827 let controller = TrafficController::init_for_test(policy.clone(), None).await;
828 let tasks = (0..num_clients).map(|task_num| {
829 tokio::spawn(Self::run_single_client(
830 controller.clone(),
831 duration,
832 task_num,
833 per_client_tps,
834 ))
835 });
836
837 let status_task = if report {
838 Some(tokio::spawn(async move {
839 println!(
840 "Running naive traffic simulation for {} seconds",
841 duration.as_secs()
842 );
843 println!("Policy: {:#?}", policy);
844 println!("Num clients: {}", num_clients);
845 println!("TPS per client: {}", per_client_tps);
846 println!(
847 "Target total TPS: {}",
848 per_client_tps * num_clients as usize
849 );
850 println!("\n");
851 for _ in 0..duration.as_secs() {
852 print!(".");
853 tokio::time::sleep(Duration::from_secs(1)).await;
854 }
855 println!();
856 }))
857 } else {
858 None
859 };
860
861 let metrics = futures::future::join_all(tasks).await.into_iter().fold(
862 TrafficSimMetrics::default(),
863 |acc, run_client_ret| {
864 if let Ok(metrics) = run_client_ret {
865 acc + metrics
866 } else {
867 error!(
868 "Error running traffic sim client: {:?}",
869 run_client_ret.err()
870 );
871 acc
872 }
873 },
874 );
875
876 if report {
877 status_task.unwrap().await.unwrap();
878 Self::report_metrics(metrics.clone(), duration, per_client_tps, num_clients);
879 }
880 metrics
881 }
882
883 async fn run_single_client(
884 controller: TrafficController,
885 duration: Duration,
886 task_num: u8,
887 per_client_tps: usize,
888 ) -> TrafficSimMetrics {
889 let sleep_time = Duration::from_micros(rand::thread_rng().gen_range(0..100));
893 tokio::time::sleep(sleep_time).await;
894
895 let mut num_requests = 0;
897 let mut num_blocked = 0;
898 let mut time_to_first_block = None;
899 let mut total_time_blocked = Duration::from_micros(0);
900 let mut num_blocklist_adds = 0;
901 let mut currently_blocked = false;
903 let mut time_blocked_start = Instant::now();
904 let start = Instant::now();
905
906 while start.elapsed() < duration {
907 let client = Some(IpAddr::V4(Ipv4Addr::new(127, 0, 0, task_num)));
908 let allowed = controller.check(&client, &None).await;
909 if allowed {
910 if currently_blocked {
911 total_time_blocked += time_blocked_start.elapsed();
912 currently_blocked = false;
913 }
914 controller.tally(TrafficTally::new(
915 client,
916 None,
918 None,
920 Weight::one(),
921 ));
922 } else {
923 if !currently_blocked {
924 time_blocked_start = Instant::now();
925 currently_blocked = true;
926 num_blocklist_adds += 1;
927 if time_to_first_block.is_none() {
928 time_to_first_block = Some(start.elapsed());
929 }
930 }
931 num_blocked += 1;
932 }
933 num_requests += 1;
934 tokio::time::sleep(Duration::from_micros(1_000_000 / per_client_tps as u64)).await;
935 }
936 TrafficSimMetrics {
937 num_requests,
938 num_blocked,
939 time_to_first_block,
940 abs_time_to_first_block: time_to_first_block,
941 total_time_blocked,
942 num_blocklist_adds,
943 }
944 }
945
946 fn report_metrics(
947 metrics: TrafficSimMetrics,
948 duration: Duration,
949 per_client_tps: usize,
950 num_clients: u8,
951 ) {
952 println!("TrafficSim metrics:");
953 println!("-------------------");
954 println!(
956 "Num expected requests: {}",
957 per_client_tps * (num_clients as usize) * duration.as_secs() as usize
958 );
959 println!("Num actual requests: {}", metrics.num_requests);
960 println!("Num blocked requests: {}", metrics.num_blocked);
964 println!(
967 "Num times added to blocklist: {}",
968 metrics.num_blocklist_adds
969 );
970 let avg_first_block_time = metrics
973 .time_to_first_block
974 .map(|ttf| ttf / num_clients as u32);
975 println!("Average time to first block: {:?}", avg_first_block_time);
976 println!(
979 "Abolute time to first block (across all clients): {:?}",
980 metrics.abs_time_to_first_block
981 );
982 let avg_time_blocked = if metrics.num_blocklist_adds > 0 {
984 metrics.total_time_blocked.as_millis() as u64 / metrics.num_blocklist_adds
985 } else {
986 0
987 };
988 println!(
989 "Average time blocked (ttl): {:?}",
990 Duration::from_millis(avg_time_blocked)
991 );
992 }
993}
994
995pub fn parse_ip(ip: &str) -> Option<IpAddr> {
996 ip.parse::<IpAddr>().ok().or_else(|| {
997 ip.parse::<SocketAddr>()
998 .ok()
999 .map(|socket_addr| socket_addr.ip())
1000 .or_else(|| {
1001 error!("Failed to parse value of {:?} to ip address or socket.", ip,);
1002 None
1003 })
1004 })
1005}