1use std::{collections::HashMap, net::IpAddr, sync::Arc};
6
7use count_min_sketch::CountMinSketch32;
8use mysten_metrics::spawn_monitored_task;
9use parking_lot::RwLock;
10use std::cmp::Reverse;
11use std::collections::{BinaryHeap, VecDeque};
12use std::fmt::Debug;
13use std::hash::Hash;
14use std::time::Duration;
15use std::time::{Instant, SystemTime};
16use sui_types::traffic_control::{FreqThresholdConfig, PolicyConfig, PolicyType, Weight};
17use tracing::{info, trace};
18
/// How many per-IP entries each highest-rate heap (direct and proxied) of a
/// `FreqThresholdPolicy`'s traffic sketch may hold.
const HIGHEST_RATES_CAPACITY: usize = 20;
20
/// Which IP of a tally a sketch key refers to.
///
/// `Direct` keys are built from `TrafficTally::direct`; `ThroughFullnode` keys from
/// `TrafficTally::through_fullnode`. Keeping the kind inside the key lets both kinds of
/// client share one set of sketches without their counts colliding.
#[derive(Debug, PartialEq, Eq, Hash)]
enum ClientType {
    Direct,
    ThroughFullnode,
}
27
28#[derive(Hash, Eq, PartialEq, Debug)]
29struct SketchKey {
30 salt: u64,
31 ip_addr: IpAddr,
32 client_type: ClientType,
33}
34
/// Bounded per-IP record of the highest observed request rates.
///
/// Both heaps wrap entries in `Reverse`, so `peek()` yields the *smallest* tracked rate. That
/// entry is the one evicted when a larger rate arrives and the heap already holds `capacity`
/// entries.
struct HighestRates {
    /// `(rate, ip)` entries for directly connected clients.
    direct: BinaryHeap<Reverse<(u64, IpAddr)>>,
    /// `(rate, ip)` entries for clients seen through a fullnode.
    proxied: BinaryHeap<Reverse<(u64, IpAddr)>>,
    /// Maximum number of entries kept in each heap.
    capacity: usize,
}
40
41pub struct TrafficSketch {
42 sketches: VecDeque<CountMinSketch32<SketchKey>>,
52 window_size: Duration,
53 update_interval: Duration,
54 last_reset_time: Instant,
55 current_sketch_index: usize,
56 highest_rates: HighestRates,
65}
66
67impl TrafficSketch {
68 pub fn new(
69 window_size: Duration,
70 update_interval: Duration,
71 sketch_capacity: usize,
72 sketch_probability: f64,
73 sketch_tolerance: f64,
74 highest_rates_capacity: usize,
75 ) -> Self {
76 let num_sketches = window_size.as_secs() / update_interval.as_secs();
78 let new_window_size = Duration::from_secs(num_sketches * update_interval.as_secs());
79 if new_window_size != window_size {
80 info!(
81 "Rounding traffic sketch window size down to {} seconds to make it an integer multiple of update interval {} seconds.",
82 new_window_size.as_secs(),
83 update_interval.as_secs(),
84 );
85 }
86 let window_size = new_window_size;
87
88 assert!(
89 window_size < Duration::from_secs(600),
90 "window_size too large. Max 600 seconds"
91 );
92 assert!(
93 update_interval < window_size,
94 "Update interval may not be larger than window size"
95 );
96 assert!(
97 update_interval >= Duration::from_secs(1),
98 "Update interval too short, must be at least 1 second"
99 );
100 assert!(
101 num_sketches <= 10,
102 "Given parameters require too many sketches to be stored. Reduce window size or increase update interval."
103 );
104 let mem_estimate = (num_sketches as usize)
105 * CountMinSketch32::<IpAddr>::estimate_memory(
106 sketch_capacity,
107 sketch_probability,
108 sketch_tolerance,
109 )
110 .expect("Failed to estimate memory for CountMinSketch32");
111 assert!(
112 mem_estimate < 128_000_000,
113 "Memory estimate for traffic sketch exceeds 128MB. Reduce window size or increase update interval."
114 );
115
116 let mut sketches = VecDeque::with_capacity(num_sketches as usize);
117 for _ in 0..num_sketches {
118 sketches.push_back(
119 CountMinSketch32::<SketchKey>::new(
120 sketch_capacity,
121 sketch_probability,
122 sketch_tolerance,
123 )
124 .expect("Failed to create CountMinSketch32"),
125 );
126 }
127 Self {
128 sketches,
129 window_size,
130 update_interval,
131 last_reset_time: Instant::now(),
132 current_sketch_index: 0,
133 highest_rates: HighestRates {
134 direct: BinaryHeap::with_capacity(highest_rates_capacity),
135 proxied: BinaryHeap::with_capacity(highest_rates_capacity),
136 capacity: highest_rates_capacity,
137 },
138 }
139 }
140
141 fn increment_count(&mut self, key: &SketchKey) {
142 let current_time = Instant::now();
144 let mut elapsed = current_time.duration_since(self.last_reset_time);
145 while elapsed >= self.update_interval {
146 self.rotate_window();
147 elapsed -= self.update_interval;
148 }
149 self.sketches[self.current_sketch_index].increment(key);
151 }
152
153 fn get_request_rate(&mut self, key: &SketchKey) -> f64 {
154 let count: u32 = self
155 .sketches
156 .iter()
157 .map(|sketch| sketch.estimate(key))
158 .sum();
159 let rate = count as f64 / self.window_size.as_secs() as f64;
160 self.update_highest_rates(key, rate);
161 rate
162 }
163
164 fn update_highest_rates(&mut self, key: &SketchKey, rate: f64) {
165 match key.client_type {
166 ClientType::Direct => {
167 Self::update_highest_rate(
168 &mut self.highest_rates.direct,
169 key.ip_addr,
170 rate,
171 self.highest_rates.capacity,
172 );
173 }
174 ClientType::ThroughFullnode => {
175 Self::update_highest_rate(
176 &mut self.highest_rates.proxied,
177 key.ip_addr,
178 rate,
179 self.highest_rates.capacity,
180 );
181 }
182 }
183 }
184
185 fn update_highest_rate(
186 rate_heap: &mut BinaryHeap<Reverse<(u64, IpAddr)>>,
187 ip_addr: IpAddr,
188 rate: f64,
189 capacity: usize,
190 ) {
191 rate_heap.retain(|&Reverse((_, key))| key != ip_addr);
194
195 let rate = rate as u64;
196 if rate_heap.len() < capacity {
197 rate_heap.push(Reverse((rate, ip_addr)));
198 } else if let Some(&Reverse((smallest_score, _))) = rate_heap.peek()
199 && rate > smallest_score
200 {
201 rate_heap.pop();
202 rate_heap.push(Reverse((rate, ip_addr)));
203 }
204 }
205
206 pub fn highest_direct_rate(&self) -> Option<(u64, IpAddr)> {
207 self.highest_rates
208 .direct
209 .iter()
210 .map(|Reverse(v)| v)
211 .max_by(|a, b| a.0.partial_cmp(&b.0).expect("Failed to compare rates"))
212 .copied()
213 }
214
215 pub fn highest_proxied_rate(&self) -> Option<(u64, IpAddr)> {
216 self.highest_rates
217 .proxied
218 .iter()
219 .map(|Reverse(v)| v)
220 .max_by(|a, b| a.0.partial_cmp(&b.0).expect("Failed to compare rates"))
221 .copied()
222 }
223
224 fn rotate_window(&mut self) {
225 self.current_sketch_index = (self.current_sketch_index + 1) % self.sketches.len();
226 self.sketches[self.current_sketch_index].clear();
227 self.last_reset_time = Instant::now();
228 }
229}
230
231#[derive(Clone, Debug)]
232pub struct TrafficTally {
233 pub direct: Option<IpAddr>,
234 pub through_fullnode: Option<IpAddr>,
235 pub error_info: Option<(Weight, String)>,
236 pub spam_weight: Weight,
237 pub timestamp: SystemTime,
238}
239
240impl TrafficTally {
241 pub fn new(
242 direct: Option<IpAddr>,
243 through_fullnode: Option<IpAddr>,
244 error_info: Option<(Weight, String)>,
245 spam_weight: Weight,
246 ) -> Self {
247 Self {
248 direct,
249 through_fullnode,
250 error_info,
251 spam_weight,
252 timestamp: SystemTime::now(),
253 }
254 }
255}
256
/// Blocking decision returned by a policy for a single tally.
///
/// `Default` yields "block nobody".
#[derive(Debug, Default, Clone)]
pub struct PolicyResponse {
    /// Directly connected client to block, if any.
    pub block_client: Option<IpAddr>,
    /// Client seen through a fullnode to block, if any.
    pub block_proxied_client: Option<IpAddr>,
}
262
263pub trait Policy {
264 fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse;
267 fn policy_config(&self) -> &PolicyConfig;
268}
269
270pub enum TrafficControlPolicy {
273 FreqThreshold(FreqThresholdPolicy),
274 NoOp(NoOpPolicy),
275 TestNConnIP(TestNConnIPPolicy),
277 TestPanicOnInvocation(TestPanicOnInvocationPolicy),
278}
279
280impl Policy for TrafficControlPolicy {
281 fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse {
282 match self {
283 TrafficControlPolicy::NoOp(policy) => policy.handle_tally(tally),
284 TrafficControlPolicy::FreqThreshold(policy) => policy.handle_tally(tally),
285 TrafficControlPolicy::TestNConnIP(policy) => policy.handle_tally(tally),
286 TrafficControlPolicy::TestPanicOnInvocation(policy) => policy.handle_tally(tally),
287 }
288 }
289
290 fn policy_config(&self) -> &PolicyConfig {
291 match self {
292 TrafficControlPolicy::NoOp(policy) => policy.policy_config(),
293 TrafficControlPolicy::FreqThreshold(policy) => policy.policy_config(),
294 TrafficControlPolicy::TestNConnIP(policy) => policy.policy_config(),
295 TrafficControlPolicy::TestPanicOnInvocation(policy) => policy.policy_config(),
296 }
297 }
298}
299
300impl TrafficControlPolicy {
301 pub async fn from_spam_config(policy_config: PolicyConfig) -> Self {
302 Self::from_config(policy_config.clone().spam_policy_type, policy_config).await
303 }
304 pub async fn from_error_config(policy_config: PolicyConfig) -> Self {
305 Self::from_config(policy_config.clone().error_policy_type, policy_config).await
306 }
307 pub async fn from_config(policy_type: PolicyType, policy_config: PolicyConfig) -> Self {
308 match policy_type {
309 PolicyType::NoOp => Self::NoOp(NoOpPolicy::new(policy_config)),
310 PolicyType::FreqThreshold(freq_threshold_config) => Self::FreqThreshold(
311 FreqThresholdPolicy::new(policy_config, freq_threshold_config),
312 ),
313 PolicyType::TestNConnIP(n) => {
314 Self::TestNConnIP(TestNConnIPPolicy::new(policy_config, n).await)
315 }
316 PolicyType::TestPanicOnInvocation => {
317 Self::TestPanicOnInvocation(TestPanicOnInvocationPolicy::new(policy_config))
318 }
319 }
320 }
321}
322
323pub struct FreqThresholdPolicy {
326 pub config: PolicyConfig,
327 pub client_threshold: u64,
328 pub proxied_client_threshold: u64,
329 sketch: TrafficSketch,
330 salt: u64,
336}
337
338impl FreqThresholdPolicy {
339 pub fn new(
340 config: PolicyConfig,
341 FreqThresholdConfig {
342 client_threshold,
343 proxied_client_threshold,
344 window_size_secs,
345 update_interval_secs,
346 sketch_capacity,
347 sketch_probability,
348 sketch_tolerance,
349 }: FreqThresholdConfig,
350 ) -> Self {
351 let sketch = TrafficSketch::new(
352 Duration::from_secs(window_size_secs),
353 Duration::from_secs(update_interval_secs),
354 sketch_capacity,
355 sketch_probability,
356 sketch_tolerance,
357 HIGHEST_RATES_CAPACITY,
358 );
359 Self {
360 config,
361 sketch,
362 client_threshold,
363 proxied_client_threshold,
364 salt: rand::random(),
365 }
366 }
367
368 pub fn highest_direct_rate(&self) -> Option<(u64, IpAddr)> {
369 self.sketch.highest_direct_rate()
370 }
371
372 pub fn highest_proxied_rate(&self) -> Option<(u64, IpAddr)> {
373 self.sketch.highest_proxied_rate()
374 }
375
376 pub fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse {
377 let block_client = if let Some(source) = tally.direct {
378 let key = SketchKey {
379 salt: self.salt,
380 ip_addr: source,
381 client_type: ClientType::Direct,
382 };
383 self.sketch.increment_count(&key);
384 let req_rate = self.sketch.get_request_rate(&key);
385 trace!(
386 "FreqThresholdPolicy handling tally -- req_rate: {:?}, client_threshold: {:?}, client: {:?}",
387 req_rate, self.client_threshold, source,
388 );
389 if req_rate >= self.client_threshold as f64 {
390 Some(source)
391 } else {
392 None
393 }
394 } else {
395 None
396 };
397 let block_proxied_client = if let Some(source) = tally.through_fullnode {
398 let key = SketchKey {
399 salt: self.salt,
400 ip_addr: source,
401 client_type: ClientType::ThroughFullnode,
402 };
403 self.sketch.increment_count(&key);
404 if self.sketch.get_request_rate(&key) >= self.proxied_client_threshold as f64 {
405 Some(source)
406 } else {
407 None
408 }
409 } else {
410 None
411 };
412 PolicyResponse {
413 block_client,
414 block_proxied_client,
415 }
416 }
417
418 fn policy_config(&self) -> &PolicyConfig {
419 &self.config
420 }
421}
422
423#[derive(Clone)]
426pub struct NoOpPolicy {
427 config: PolicyConfig,
428}
429
430impl NoOpPolicy {
431 pub fn new(config: PolicyConfig) -> Self {
432 Self { config }
433 }
434
435 fn handle_tally(&mut self, _tally: TrafficTally) -> PolicyResponse {
436 PolicyResponse::default()
437 }
438
439 fn policy_config(&self) -> &PolicyConfig {
440 &self.config
441 }
442}
443
444#[derive(Clone)]
445pub struct TestNConnIPPolicy {
446 pub threshold: u64,
447 pub config: PolicyConfig,
448 frequencies: Arc<RwLock<HashMap<IpAddr, u64>>>,
449}
450
451impl TestNConnIPPolicy {
452 pub async fn new(config: PolicyConfig, threshold: u64) -> Self {
453 let frequencies = Arc::new(RwLock::new(HashMap::new()));
454 let frequencies_clone = frequencies.clone();
455 spawn_monitored_task!(run_clear_frequencies(
456 frequencies_clone,
457 config.connection_blocklist_ttl_sec * 2,
458 ));
459 Self {
460 config,
461 frequencies,
462 threshold,
463 }
464 }
465
466 fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse {
467 let client = if let Some(client) = tally.direct {
468 client
469 } else {
470 return PolicyResponse::default();
471 };
472
473 let mut frequencies = self.frequencies.write();
475 let count = frequencies.entry(client).or_insert(0);
476 *count += 1;
477 PolicyResponse {
478 block_client: if *count >= self.threshold {
479 Some(client)
480 } else {
481 None
482 },
483 block_proxied_client: None,
484 }
485 }
486
487 fn policy_config(&self) -> &PolicyConfig {
488 &self.config
489 }
490}
491
492async fn run_clear_frequencies(frequencies: Arc<RwLock<HashMap<IpAddr, u64>>>, window_secs: u64) {
493 loop {
494 tokio::time::sleep(tokio::time::Duration::from_secs(window_secs)).await;
495 frequencies.write().clear();
496 }
497}
498
499#[derive(Clone)]
500pub struct TestPanicOnInvocationPolicy {
501 config: PolicyConfig,
502}
503
504impl TestPanicOnInvocationPolicy {
505 pub fn new(config: PolicyConfig) -> Self {
506 Self { config }
507 }
508
509 fn handle_tally(&mut self, _: TrafficTally) -> PolicyResponse {
510 panic!("Tally for this policy should never be invoked")
511 }
512
513 fn policy_config(&self) -> &PolicyConfig {
514 &self.config
515 }
516}
517
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};
    use sui_macros::sim_test;
    use sui_types::traffic_control::{
        DEFAULT_SKETCH_CAPACITY, DEFAULT_SKETCH_PROBABILITY, DEFAULT_SKETCH_TOLERANCE,
    };

    /// Walks a `FreqThresholdPolicy` through several sub-window rotations and checks the
    /// blocking decisions and the reported highest rates along the way.
    #[sim_test]
    async fn test_freq_threshold_policy() {
        // Direct clients are blocked at 5 req/s and proxied clients at 2 req/s, averaged over
        // a 5 second window made of five 1 second sub-windows.
        let mut policy = FreqThresholdPolicy::new(
            PolicyConfig::default(),
            FreqThresholdConfig {
                client_threshold: 5,
                proxied_client_threshold: 2,
                window_size_secs: 5,
                update_interval_secs: 1,
                ..Default::default()
            },
        );
        // All three tallies share the same direct IP (8.7.6.5) but carry different proxied
        // IPs, so direct counts accumulate across alice, bob and charlie.
        let alice = TrafficTally {
            direct: Some(IpAddr::V4(Ipv4Addr::new(8, 7, 6, 5))),
            through_fullnode: Some(IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4))),
            error_info: None,
            spam_weight: Weight::one(),
            timestamp: SystemTime::now(),
        };
        let bob = TrafficTally {
            direct: Some(IpAddr::V4(Ipv4Addr::new(8, 7, 6, 5))),
            through_fullnode: Some(IpAddr::V4(Ipv4Addr::new(4, 3, 2, 1))),
            error_info: None,
            spam_weight: Weight::one(),
            timestamp: SystemTime::now(),
        };
        let charlie = TrafficTally {
            direct: Some(IpAddr::V4(Ipv4Addr::new(8, 7, 6, 5))),
            through_fullnode: Some(IpAddr::V4(Ipv4Addr::new(5, 6, 7, 8))),
            error_info: None,
            spam_weight: Weight::one(),
            timestamp: SystemTime::now(),
        };

        // 2 requests / 5 s = 0.4 req/s: below both thresholds.
        for _ in 0..2 {
            let response = policy.handle_tally(alice.clone());
            assert_eq!(response.block_proxied_client, None);
            assert_eq!(response.block_client, None);
        }

        // Reported rates are truncated to integers, so 0.4 req/s shows up as 0.
        let (direct_rate, direct_ip_addr) = policy.highest_direct_rate().unwrap();
        let (proxied_rate, proxied_ip_addr) = policy.highest_proxied_rate().unwrap();
        assert_eq!(direct_ip_addr, alice.direct.unwrap());
        assert!(direct_rate < 1);
        assert_eq!(proxied_ip_addr, alice.through_fullnode.unwrap());
        assert!(proxied_rate < 1);

        // Bob's 10th proxied request reaches 10 / 5 = 2 req/s, hitting the proxied threshold.
        // The shared direct IP is at 12 / 5 = 2.4 req/s, still below 5.
        for _ in 0..9 {
            let response = policy.handle_tally(bob.clone());
            assert_eq!(response.block_client, None);
            assert_eq!(response.block_proxied_client, None);
        }
        let response = policy.handle_tally(bob.clone());
        assert_eq!(response.block_client, None);
        assert_eq!(response.block_proxied_client, bob.through_fullnode);

        let (direct_rate, direct_ip_addr) = policy.highest_direct_rate().unwrap();
        let (proxied_rate, proxied_ip_addr) = policy.highest_proxied_rate().unwrap();
        assert_eq!(direct_ip_addr, bob.direct.unwrap());
        assert_eq!(direct_rate, 2);
        assert_eq!(proxied_ip_addr, bob.through_fullnode.unwrap());
        assert_eq!(proxied_rate, 2);

        // Move two sub-windows forward. The first sub-window's counts are still inside the
        // 5 s window, so earlier requests keep contributing.
        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
        for _ in 0..2 {
            let response = policy.handle_tally(alice.clone());
            assert_eq!(response.block_client, None);
            assert_eq!(response.block_proxied_client, None);
        }
        // Bob's proxied count is already at the threshold, so further requests stay blocked.
        let _ = policy.handle_tally(bob.clone());
        let response = policy.handle_tally(bob.clone());
        assert_eq!(response.block_client, None);
        assert_eq!(response.block_proxied_client, bob.through_fullnode);

        // The direct IP is shared, so the direct leader is alice's (== bob's) IP.
        let (direct_rate, direct_ip_addr) = policy.highest_direct_rate().unwrap();
        let (proxied_rate, proxied_ip_addr) = policy.highest_proxied_rate().unwrap();
        assert_eq!(direct_ip_addr, alice.direct.unwrap());
        assert_eq!(direct_rate, 3);
        assert_eq!(proxied_ip_addr, bob.through_fullnode.unwrap());
        assert_eq!(proxied_rate, 2);

        // Alice's proxied count stays below 10 (at most 9 / 5 = 1.8 req/s) during this batch.
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        for i in 0..5 {
            let response = policy.handle_tally(alice.clone());
            assert_eq!(response.block_client, None, "Blocked at i = {}", i);
            assert_eq!(response.block_proxied_client, None);
        }

        // Alice's 10th proxied request in the window reaches 2 req/s and gets her blocked.
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        let response = policy.handle_tally(alice.clone());
        assert_eq!(response.block_client, None);
        assert_eq!(response.block_proxied_client, alice.through_fullnode);

        // NOTE(review): alice and bob both report a proxied rate of 2 here. Bob winning
        // depends on the tie-break over the heap's internal iteration order.
        let (direct_rate, direct_ip_addr) = policy.highest_direct_rate().unwrap();
        let (proxied_rate, proxied_ip_addr) = policy.highest_proxied_rate().unwrap();
        assert_eq!(direct_ip_addr, alice.direct.unwrap());
        assert_eq!(direct_rate, 4);
        assert_eq!(proxied_ip_addr, bob.through_fullnode.unwrap());
        assert_eq!(proxied_rate, 2);

        // Charlie's third request pushes the shared direct IP to 25 / 5 = 5 req/s, which
        // blocks the direct client; charlie's own proxied rate stays low.
        for i in 0..2 {
            let response = policy.handle_tally(charlie.clone());
            assert_eq!(response.block_client, None, "Blocked at i = {}", i);
            assert_eq!(response.block_proxied_client, None);
        }
        let response = policy.handle_tally(charlie.clone());
        assert_eq!(response.block_proxied_client, None);
        assert_eq!(response.block_client, charlie.direct);

        // One more rotation evicts the first sub-window (its 12 direct requests), dropping
        // the direct rate back under the threshold.
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        for i in 0..3 {
            let response = policy.handle_tally(charlie.clone());
            assert_eq!(response.block_client, None, "Blocked at i = {}", i);
            assert_eq!(response.block_proxied_client, None);
        }

        // After a full window of inactivity all counts are gone. Highest-rate entries are only
        // refreshed when a client is seen again, so charlie's last proxied rate (6 / 5,
        // truncated to 1) remains the top proxied entry.
        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
        let _ = policy.handle_tally(alice.clone());
        let _ = policy.handle_tally(bob.clone());
        let (direct_rate, direct_ip_addr) = policy.highest_direct_rate().unwrap();
        let (proxied_rate, proxied_ip_addr) = policy.highest_proxied_rate().unwrap();
        assert_eq!(direct_ip_addr, alice.direct.unwrap());
        assert_eq!(direct_rate, 0);
        assert_eq!(proxied_ip_addr, charlie.through_fullnode.unwrap());
        assert_eq!(proxied_rate, 1);
    }

    /// Checks that the default sketch parameters with a 30 s window and 5 s interval stay
    /// under the 128MB budget that `TrafficSketch::new` asserts.
    #[sim_test]
    async fn test_traffic_sketch_mem_estimate() {
        let window_size = Duration::from_secs(30);
        let update_interval = Duration::from_secs(5);
        let mem_estimate = CountMinSketch32::<IpAddr>::estimate_memory(
            DEFAULT_SKETCH_CAPACITY,
            DEFAULT_SKETCH_PROBABILITY,
            DEFAULT_SKETCH_TOLERANCE,
        )
        .unwrap()
        * (window_size.as_secs() / update_interval.as_secs()) as usize;
        assert!(
            mem_estimate < 128_000_000,
            "Memory estimate {mem_estimate} for traffic sketch exceeds 128MB."
        );
    }
}