1use std::{collections::HashMap, net::IpAddr, sync::Arc};
6
7use count_min_sketch::CountMinSketch32;
8use mysten_metrics::spawn_monitored_task;
9use parking_lot::RwLock;
10use std::cmp::Reverse;
11use std::collections::{BinaryHeap, VecDeque};
12use std::fmt::Debug;
13use std::hash::Hash;
14use std::time::Duration;
15use std::time::{Instant, SystemTime};
16use sui_types::traffic_control::{FreqThresholdConfig, PolicyConfig, PolicyType, Weight};
17use tracing::{info, trace};
18
/// Maximum number of `(rate, ip)` entries kept in each highest-rates heap of
/// the `TrafficSketch` that `FreqThresholdPolicy::new` constructs.
const HIGHEST_RATES_CAPACITY: usize = 20;
20
/// Distinguishes which address of a tally a sketch entry refers to, so that
/// the same IP is counted separately for each role.
#[derive(Hash, Eq, PartialEq, Debug)]
enum ClientType {
    /// Key built from `TrafficTally::direct`.
    Direct,
    /// Key built from `TrafficTally::through_fullnode`.
    ThroughFullnode,
}
27
/// Key under which requests are counted in the count-min sketches.
#[derive(Hash, Eq, PartialEq, Debug)]
struct SketchKey {
    /// Random value chosen once per `FreqThresholdPolicy` and mixed into every
    /// key. NOTE(review): presumably this makes sketch bucket positions
    /// unpredictable to clients -- confirm intent.
    salt: u64,
    /// Address being counted.
    ip_addr: IpAddr,
    /// Whether `ip_addr` was seen as a direct or a proxied client.
    client_type: ClientType,
}
34
/// Bounded record of the highest request rates observed so far.
///
/// Both heaps hold `Reverse<(rate, ip)>`, which turns `BinaryHeap` into a
/// min-heap: the entry with the lowest rate sits on top and is the one evicted
/// when a higher rate arrives while the heap is full.
struct HighestRates {
    /// Rates recorded for `ClientType::Direct` keys.
    direct: BinaryHeap<Reverse<(u64, IpAddr)>>,
    /// Rates recorded for `ClientType::ThroughFullnode` keys.
    proxied: BinaryHeap<Reverse<(u64, IpAddr)>>,
    /// Maximum number of entries allowed in each heap.
    capacity: usize,
}
40
/// Sliding-window request counter built from a ring of count-min sketches.
///
/// Each sketch accumulates the requests of one update interval; the request
/// rate of a key is the sum of its estimates over all sketches divided by the
/// window size in seconds.
pub struct TrafficSketch {
    /// One sketch per update interval; used as a ring indexed by
    /// `current_sketch_index`.
    sketches: VecDeque<CountMinSketch32<SketchKey>>,
    /// Total time span covered by all sketches together (a whole multiple of
    /// `update_interval`).
    window_size: Duration,
    /// Time span covered by a single sketch.
    update_interval: Duration,
    /// Instant of the most recent rotation (or of construction).
    last_reset_time: Instant,
    /// Index of the sketch currently receiving increments.
    current_sketch_index: usize,
    /// Highest rates observed so far, tracked per client type.
    highest_rates: HighestRates,
}
66
67impl TrafficSketch {
68 pub fn new(
69 window_size: Duration,
70 update_interval: Duration,
71 sketch_capacity: usize,
72 sketch_probability: f64,
73 sketch_tolerance: f64,
74 highest_rates_capacity: usize,
75 ) -> Self {
76 let num_sketches = window_size.as_secs() / update_interval.as_secs();
78 let new_window_size = Duration::from_secs(num_sketches * update_interval.as_secs());
79 if new_window_size != window_size {
80 info!(
81 "Rounding traffic sketch window size down to {} seconds to make it an integer multiple of update interval {} seconds.",
82 new_window_size.as_secs(),
83 update_interval.as_secs(),
84 );
85 }
86 let window_size = new_window_size;
87
88 assert!(
89 window_size < Duration::from_secs(600),
90 "window_size too large. Max 600 seconds"
91 );
92 assert!(
93 update_interval < window_size,
94 "Update interval may not be larger than window size"
95 );
96 assert!(
97 update_interval >= Duration::from_secs(1),
98 "Update interval too short, must be at least 1 second"
99 );
100 assert!(
101 num_sketches <= 10,
102 "Given parameters require too many sketches to be stored. Reduce window size or increase update interval."
103 );
104 let mem_estimate = (num_sketches as usize)
105 * CountMinSketch32::<IpAddr>::estimate_memory(
106 sketch_capacity,
107 sketch_probability,
108 sketch_tolerance,
109 )
110 .expect("Failed to estimate memory for CountMinSketch32");
111 assert!(
112 mem_estimate < 128_000_000,
113 "Memory estimate for traffic sketch exceeds 128MB. Reduce window size or increase update interval."
114 );
115
116 let mut sketches = VecDeque::with_capacity(num_sketches as usize);
117 for _ in 0..num_sketches {
118 sketches.push_back(
119 CountMinSketch32::<SketchKey>::new(
120 sketch_capacity,
121 sketch_probability,
122 sketch_tolerance,
123 )
124 .expect("Failed to create CountMinSketch32"),
125 );
126 }
127 Self {
128 sketches,
129 window_size,
130 update_interval,
131 last_reset_time: Instant::now(),
132 current_sketch_index: 0,
133 highest_rates: HighestRates {
134 direct: BinaryHeap::with_capacity(highest_rates_capacity),
135 proxied: BinaryHeap::with_capacity(highest_rates_capacity),
136 capacity: highest_rates_capacity,
137 },
138 }
139 }
140
141 fn increment_count(&mut self, key: &SketchKey) {
142 let current_time = Instant::now();
144 let mut elapsed = current_time.duration_since(self.last_reset_time);
145 while elapsed >= self.update_interval {
146 self.rotate_window();
147 elapsed -= self.update_interval;
148 }
149 self.sketches[self.current_sketch_index].increment(key);
151 }
152
153 fn get_request_rate(&mut self, key: &SketchKey) -> f64 {
154 let count: u32 = self
155 .sketches
156 .iter()
157 .map(|sketch| sketch.estimate(key))
158 .sum();
159 let rate = count as f64 / self.window_size.as_secs() as f64;
160 self.update_highest_rates(key, rate);
161 rate
162 }
163
164 fn update_highest_rates(&mut self, key: &SketchKey, rate: f64) {
165 match key.client_type {
166 ClientType::Direct => {
167 Self::update_highest_rate(
168 &mut self.highest_rates.direct,
169 key.ip_addr,
170 rate,
171 self.highest_rates.capacity,
172 );
173 }
174 ClientType::ThroughFullnode => {
175 Self::update_highest_rate(
176 &mut self.highest_rates.proxied,
177 key.ip_addr,
178 rate,
179 self.highest_rates.capacity,
180 );
181 }
182 }
183 }
184
185 fn update_highest_rate(
186 rate_heap: &mut BinaryHeap<Reverse<(u64, IpAddr)>>,
187 ip_addr: IpAddr,
188 rate: f64,
189 capacity: usize,
190 ) {
191 rate_heap.retain(|&Reverse((_, key))| key != ip_addr);
194
195 let rate = rate as u64;
196 if rate_heap.len() < capacity {
197 rate_heap.push(Reverse((rate, ip_addr)));
198 } else if let Some(&Reverse((smallest_score, _))) = rate_heap.peek()
199 && rate > smallest_score
200 {
201 rate_heap.pop();
202 rate_heap.push(Reverse((rate, ip_addr)));
203 }
204 }
205
206 pub fn highest_direct_rate(&self) -> Option<(u64, IpAddr)> {
207 self.highest_rates
208 .direct
209 .iter()
210 .map(|Reverse(v)| v)
211 .max_by(|a, b| a.0.partial_cmp(&b.0).expect("Failed to compare rates"))
212 .copied()
213 }
214
215 pub fn highest_proxied_rate(&self) -> Option<(u64, IpAddr)> {
216 self.highest_rates
217 .proxied
218 .iter()
219 .map(|Reverse(v)| v)
220 .max_by(|a, b| a.0.partial_cmp(&b.0).expect("Failed to compare rates"))
221 .copied()
222 }
223
224 fn rotate_window(&mut self) {
225 self.current_sketch_index = (self.current_sketch_index + 1) % self.sketches.len();
226 self.sketches[self.current_sketch_index].clear();
227 self.last_reset_time = Instant::now();
228 }
229}
230
/// One observed request, as handed to a [`Policy`] for evaluation.
#[derive(Clone, Debug)]
pub struct TrafficTally {
    /// Address counted as the direct client, if known.
    pub direct: Option<IpAddr>,
    /// Address counted as the proxied client, if known.
    /// NOTE(review): presumably the end client behind a fullnode -- confirm.
    pub through_fullnode: Option<IpAddr>,
    /// Weight and description of an error associated with the request, if any.
    /// Not read by any policy in this file.
    pub error_info: Option<(Weight, String)>,
    /// Spam weight of the request. Not read by any policy in this file.
    pub spam_weight: Weight,
    /// Wall-clock time at which the tally was created.
    pub timestamp: SystemTime,
    /// Optional method name attached via `with_method`.
    pub method: Option<String>,
}
240
241impl TrafficTally {
242 pub fn new(
243 direct: Option<IpAddr>,
244 through_fullnode: Option<IpAddr>,
245 error_info: Option<(Weight, String)>,
246 spam_weight: Weight,
247 ) -> Self {
248 Self {
249 direct,
250 through_fullnode,
251 error_info,
252 spam_weight,
253 timestamp: SystemTime::now(),
254 method: None,
255 }
256 }
257
258 pub fn with_method(mut self, method: String) -> Self {
259 self.method = Some(method);
260 self
261 }
262}
263
/// Verdict returned by a [`Policy`] after handling a tally.
///
/// The default value has both fields set to `None`.
#[derive(Clone, Debug, Default)]
pub struct PolicyResponse {
    /// Direct client address the policy asks to block, if any.
    pub block_client: Option<IpAddr>,
    /// Proxied client address the policy asks to block, if any.
    pub block_proxied_client: Option<IpAddr>,
}
269
/// Common interface of traffic-control policies.
pub trait Policy {
    /// Processes one tally, updating internal state, and reports which
    /// addresses (if any) should be blocked.
    fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse;
    /// Returns the configuration the policy was created with.
    fn policy_config(&self) -> &PolicyConfig;
}
276
/// Closed set of available policies; implements [`Policy`] by delegating to
/// the wrapped value.
pub enum TrafficControlPolicy {
    /// Blocks addresses whose request rate reaches a configured threshold.
    FreqThreshold(FreqThresholdPolicy),
    /// Never blocks anything.
    NoOp(NoOpPolicy),
    /// Blocks a direct client once its request count reaches a threshold.
    TestNConnIP(TestNConnIPPolicy),
    /// Panics whenever a tally is handled.
    TestPanicOnInvocation(TestPanicOnInvocationPolicy),
}
286
287impl Policy for TrafficControlPolicy {
288 fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse {
289 match self {
290 TrafficControlPolicy::NoOp(policy) => policy.handle_tally(tally),
291 TrafficControlPolicy::FreqThreshold(policy) => policy.handle_tally(tally),
292 TrafficControlPolicy::TestNConnIP(policy) => policy.handle_tally(tally),
293 TrafficControlPolicy::TestPanicOnInvocation(policy) => policy.handle_tally(tally),
294 }
295 }
296
297 fn policy_config(&self) -> &PolicyConfig {
298 match self {
299 TrafficControlPolicy::NoOp(policy) => policy.policy_config(),
300 TrafficControlPolicy::FreqThreshold(policy) => policy.policy_config(),
301 TrafficControlPolicy::TestNConnIP(policy) => policy.policy_config(),
302 TrafficControlPolicy::TestPanicOnInvocation(policy) => policy.policy_config(),
303 }
304 }
305}
306
307impl TrafficControlPolicy {
308 pub async fn from_spam_config(policy_config: PolicyConfig) -> Self {
309 Self::from_config(policy_config.clone().spam_policy_type, policy_config).await
310 }
311 pub async fn from_error_config(policy_config: PolicyConfig) -> Self {
312 Self::from_config(policy_config.clone().error_policy_type, policy_config).await
313 }
314 pub async fn from_config(policy_type: PolicyType, policy_config: PolicyConfig) -> Self {
315 match policy_type {
316 PolicyType::NoOp => Self::NoOp(NoOpPolicy::new(policy_config)),
317 PolicyType::FreqThreshold(freq_threshold_config) => Self::FreqThreshold(
318 FreqThresholdPolicy::new(policy_config, freq_threshold_config),
319 ),
320 PolicyType::TestNConnIP(n) => {
321 Self::TestNConnIP(TestNConnIPPolicy::new(policy_config, n).await)
322 }
323 PolicyType::TestPanicOnInvocation => {
324 Self::TestPanicOnInvocation(TestPanicOnInvocationPolicy::new(policy_config))
325 }
326 }
327 }
328}
329
/// Policy that asks to block an address once its estimated request rate
/// reaches a configured threshold.
pub struct FreqThresholdPolicy {
    /// Configuration returned by `policy_config`.
    pub config: PolicyConfig,
    /// Rate (requests per second) at or above which a direct client is blocked.
    pub client_threshold: u64,
    /// Rate (requests per second) at or above which a proxied client is blocked.
    pub proxied_client_threshold: u64,
    /// Sliding-window counter holding the request counts.
    sketch: TrafficSketch,
    /// Random value generated once in `new` and placed in every sketch key.
    salt: u64,
}
344
345impl FreqThresholdPolicy {
346 pub fn new(
347 config: PolicyConfig,
348 FreqThresholdConfig {
349 client_threshold,
350 proxied_client_threshold,
351 window_size_secs,
352 update_interval_secs,
353 sketch_capacity,
354 sketch_probability,
355 sketch_tolerance,
356 }: FreqThresholdConfig,
357 ) -> Self {
358 let sketch = TrafficSketch::new(
359 Duration::from_secs(window_size_secs),
360 Duration::from_secs(update_interval_secs),
361 sketch_capacity,
362 sketch_probability,
363 sketch_tolerance,
364 HIGHEST_RATES_CAPACITY,
365 );
366 Self {
367 config,
368 sketch,
369 client_threshold,
370 proxied_client_threshold,
371 salt: rand::random(),
372 }
373 }
374
375 pub fn highest_direct_rate(&self) -> Option<(u64, IpAddr)> {
376 self.sketch.highest_direct_rate()
377 }
378
379 pub fn highest_proxied_rate(&self) -> Option<(u64, IpAddr)> {
380 self.sketch.highest_proxied_rate()
381 }
382
383 pub fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse {
384 let block_client = if let Some(source) = tally.direct {
385 let key = SketchKey {
386 salt: self.salt,
387 ip_addr: source,
388 client_type: ClientType::Direct,
389 };
390 self.sketch.increment_count(&key);
391 let req_rate = self.sketch.get_request_rate(&key);
392 trace!(
393 "FreqThresholdPolicy handling tally -- req_rate: {:?}, client_threshold: {:?}, client: {:?}",
394 req_rate, self.client_threshold, source,
395 );
396 if req_rate >= self.client_threshold as f64 {
397 Some(source)
398 } else {
399 None
400 }
401 } else {
402 None
403 };
404 let block_proxied_client = if let Some(source) = tally.through_fullnode {
405 let key = SketchKey {
406 salt: self.salt,
407 ip_addr: source,
408 client_type: ClientType::ThroughFullnode,
409 };
410 self.sketch.increment_count(&key);
411 if self.sketch.get_request_rate(&key) >= self.proxied_client_threshold as f64 {
412 Some(source)
413 } else {
414 None
415 }
416 } else {
417 None
418 };
419 PolicyResponse {
420 block_client,
421 block_proxied_client,
422 }
423 }
424
425 fn policy_config(&self) -> &PolicyConfig {
426 &self.config
427 }
428}
429
/// Policy that never asks for anything to be blocked.
#[derive(Clone)]
pub struct NoOpPolicy {
    /// Configuration returned by `policy_config`; not otherwise used.
    config: PolicyConfig,
}
436
437impl NoOpPolicy {
438 pub fn new(config: PolicyConfig) -> Self {
439 Self { config }
440 }
441
442 fn handle_tally(&mut self, _tally: TrafficTally) -> PolicyResponse {
443 PolicyResponse::default()
444 }
445
446 fn policy_config(&self) -> &PolicyConfig {
447 &self.config
448 }
449}
450
/// Test policy that blocks a direct client once it has been seen `threshold`
/// times since the counters were last cleared.
#[derive(Clone)]
pub struct TestNConnIPPolicy {
    /// Number of tallies from one address at which blocking starts.
    pub threshold: u64,
    /// Configuration returned by `policy_config`.
    pub config: PolicyConfig,
    /// Per-address tally counts, shared with the background task that
    /// periodically clears them.
    frequencies: Arc<RwLock<HashMap<IpAddr, u64>>>,
}
457
458impl TestNConnIPPolicy {
459 pub async fn new(config: PolicyConfig, threshold: u64) -> Self {
460 let frequencies = Arc::new(RwLock::new(HashMap::new()));
461 let frequencies_clone = frequencies.clone();
462 spawn_monitored_task!(run_clear_frequencies(
463 frequencies_clone,
464 config.connection_blocklist_ttl_sec * 2,
465 ));
466 Self {
467 config,
468 frequencies,
469 threshold,
470 }
471 }
472
473 fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse {
474 let client = if let Some(client) = tally.direct {
475 client
476 } else {
477 return PolicyResponse::default();
478 };
479
480 let mut frequencies = self.frequencies.write();
482 let count = frequencies.entry(client).or_insert(0);
483 *count += 1;
484 PolicyResponse {
485 block_client: if *count >= self.threshold {
486 Some(client)
487 } else {
488 None
489 },
490 block_proxied_client: None,
491 }
492 }
493
494 fn policy_config(&self) -> &PolicyConfig {
495 &self.config
496 }
497}
498
499async fn run_clear_frequencies(frequencies: Arc<RwLock<HashMap<IpAddr, u64>>>, window_secs: u64) {
500 loop {
501 tokio::time::sleep(tokio::time::Duration::from_secs(window_secs)).await;
502 frequencies.write().clear();
503 }
504}
505
/// Test policy whose tally handler always panics.
#[derive(Clone)]
pub struct TestPanicOnInvocationPolicy {
    /// Configuration returned by `policy_config`; not otherwise used.
    config: PolicyConfig,
}
510
511impl TestPanicOnInvocationPolicy {
512 pub fn new(config: PolicyConfig) -> Self {
513 Self { config }
514 }
515
516 fn handle_tally(&mut self, _: TrafficTally) -> PolicyResponse {
517 panic!("Tally for this policy should never be invoked")
518 }
519
520 fn policy_config(&self) -> &PolicyConfig {
521 &self.config
522 }
523}
524
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};
    use sui_macros::sim_test;
    use sui_types::traffic_control::{
        DEFAULT_SKETCH_CAPACITY, DEFAULT_SKETCH_PROBABILITY, DEFAULT_SKETCH_TOLERANCE,
    };

    /// Walks a `FreqThresholdPolicy` through several bursts separated by
    /// sleeps and checks both the blocking verdicts and the highest-rate
    /// tracker after each step.
    ///
    /// NOTE(review): the expected numbers assume that sketch estimates are
    /// exact for this handful of keys and that time only advances through the
    /// explicit sleeps -- confirm against the simulator's behaviour.
    #[sim_test]
    async fn test_freq_threshold_policy() {
        // A 5 s window with a 1 s update interval, i.e. five one-second
        // sketches. Direct clients are blocked at 5 req/s, proxied at 2 req/s.
        let mut policy = FreqThresholdPolicy::new(
            PolicyConfig::default(),
            FreqThresholdConfig {
                client_threshold: 5,
                proxied_client_threshold: 2,
                window_size_secs: 5,
                update_interval_secs: 1,
                ..Default::default()
            },
        );
        // All three tallies share the same direct address (8.7.6.5) and differ
        // only in their proxied address, so every tally adds to the same
        // direct counter.
        let alice = TrafficTally {
            direct: Some(IpAddr::V4(Ipv4Addr::new(8, 7, 6, 5))),
            through_fullnode: Some(IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4))),
            error_info: None,
            spam_weight: Weight::one(),
            timestamp: SystemTime::now(),
            method: None,
        };
        let bob = TrafficTally {
            direct: Some(IpAddr::V4(Ipv4Addr::new(8, 7, 6, 5))),
            through_fullnode: Some(IpAddr::V4(Ipv4Addr::new(4, 3, 2, 1))),
            error_info: None,
            spam_weight: Weight::one(),
            timestamp: SystemTime::now(),
            method: None,
        };
        let charlie = TrafficTally {
            direct: Some(IpAddr::V4(Ipv4Addr::new(8, 7, 6, 5))),
            through_fullnode: Some(IpAddr::V4(Ipv4Addr::new(5, 6, 7, 8))),
            error_info: None,
            spam_weight: Weight::one(),
            timestamp: SystemTime::now(),
            method: None,
        };

        // Two requests in a 5 s window are 0.4 req/s: below both thresholds.
        for _ in 0..2 {
            let response = policy.handle_tally(alice.clone());
            assert_eq!(response.block_proxied_client, None);
            assert_eq!(response.block_client, None);
        }

        // The tracker stores rates truncated to integers, so 0.4 shows up as 0.
        let (direct_rate, direct_ip_addr) = policy.highest_direct_rate().unwrap();
        let (proxied_rate, proxied_ip_addr) = policy.highest_proxied_rate().unwrap();
        assert_eq!(direct_ip_addr, alice.direct.unwrap());
        assert!(direct_rate < 1);
        assert_eq!(proxied_ip_addr, alice.through_fullnode.unwrap());
        assert!(proxied_rate < 1);

        // Nine requests through bob's proxied address are 1.8 req/s: still
        // below the proxied threshold of 2.
        for _ in 0..9 {
            let response = policy.handle_tally(bob.clone());
            assert_eq!(response.block_client, None);
            assert_eq!(response.block_proxied_client, None);
        }
        // The tenth request brings bob's proxied address to 2.0 req/s, which
        // meets the threshold. The shared direct address is at 12 requests
        // (2.4 req/s), still below 5.
        let response = policy.handle_tally(bob.clone());
        assert_eq!(response.block_client, None);
        assert_eq!(response.block_proxied_client, bob.through_fullnode);

        let (direct_rate, direct_ip_addr) = policy.highest_direct_rate().unwrap();
        let (proxied_rate, proxied_ip_addr) = policy.highest_proxied_rate().unwrap();
        assert_eq!(direct_ip_addr, bob.direct.unwrap());
        assert_eq!(direct_rate, 2);
        assert_eq!(proxied_ip_addr, bob.through_fullnode.unwrap());
        assert_eq!(proxied_rate, 2);

        // Two seconds later the first burst is still inside the 5 s window.
        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
        for _ in 0..2 {
            let response = policy.handle_tally(alice.clone());
            assert_eq!(response.block_client, None);
            assert_eq!(response.block_proxied_client, None);
        }
        // Bob's earlier requests still count, so he stays at or above the
        // proxied threshold.
        let _ = policy.handle_tally(bob.clone());
        let response = policy.handle_tally(bob.clone());
        assert_eq!(response.block_client, None);
        assert_eq!(response.block_proxied_client, bob.through_fullnode);

        let (direct_rate, direct_ip_addr) = policy.highest_direct_rate().unwrap();
        let (proxied_rate, proxied_ip_addr) = policy.highest_proxied_rate().unwrap();
        assert_eq!(direct_ip_addr, alice.direct.unwrap());
        // 16 requests on the shared direct address: 3.2 req/s, truncated to 3.
        assert_eq!(direct_rate, 3);
        assert_eq!(proxied_ip_addr, bob.through_fullnode.unwrap());
        // 12 requests through bob's proxied address: 2.4 req/s, truncated to 2.
        assert_eq!(proxied_rate, 2);

        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        // Five more requests keep the direct address under 5 req/s and alice's
        // proxied address under 2 req/s.
        for i in 0..5 {
            let response = policy.handle_tally(alice.clone());
            assert_eq!(response.block_client, None, "Blocked at i = {}", i);
            assert_eq!(response.block_proxied_client, None);
        }

        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        // One more request brings alice's proxied address to 10 requests in
        // the window (2.0 req/s), which meets the proxied threshold.
        let response = policy.handle_tally(alice.clone());
        assert_eq!(response.block_client, None);
        assert_eq!(response.block_proxied_client, alice.through_fullnode);

        let (direct_rate, direct_ip_addr) = policy.highest_direct_rate().unwrap();
        let (proxied_rate, proxied_ip_addr) = policy.highest_proxied_rate().unwrap();
        assert_eq!(direct_ip_addr, alice.direct.unwrap());
        assert_eq!(direct_rate, 4);
        // NOTE(review): alice and bob both appear to be recorded with rate 2
        // here; which one is reported then depends on the tracker's
        // tie-breaking -- confirm.
        assert_eq!(proxied_ip_addr, bob.through_fullnode.unwrap());
        assert_eq!(proxied_rate, 2);

        // Charlie's requests also count toward the shared direct address.
        for i in 0..2 {
            let response = policy.handle_tally(charlie.clone());
            assert_eq!(response.block_client, None, "Blocked at i = {}", i);
            assert_eq!(response.block_proxied_client, None);
        }
        // The third one pushes the direct address to the 5 req/s threshold.
        let response = policy.handle_tally(charlie.clone());
        assert_eq!(response.block_proxied_client, None);
        assert_eq!(response.block_client, charlie.direct);

        // One second later the oldest sketch (holding the first burst) is
        // rotated out, so the direct address drops back below the threshold.
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        for i in 0..3 {
            let response = policy.handle_tally(charlie.clone());
            assert_eq!(response.block_client, None, "Blocked at i = {}", i);
            assert_eq!(response.block_proxied_client, None);
        }

        // Sleeping for the whole window length expires every sketch.
        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
        let _ = policy.handle_tally(alice.clone());
        let _ = policy.handle_tally(bob.clone());
        let (direct_rate, direct_ip_addr) = policy.highest_direct_rate().unwrap();
        let (proxied_rate, proxied_ip_addr) = policy.highest_proxied_rate().unwrap();
        assert_eq!(direct_ip_addr, alice.direct.unwrap());
        assert_eq!(direct_rate, 0);
        // Tracker entries are refreshed only when their address is seen again.
        // Charlie sent nothing after the sleep, so his last recorded rate (1)
        // is still stored and now outranks alice's and bob's fresh zeros.
        assert_eq!(proxied_ip_addr, charlie.through_fullnode.unwrap());
        assert_eq!(proxied_rate, 1);
    }

    /// Checks that the default sketch parameters, with a 30 s window and a 5 s
    /// update interval (six sketches), stay below the 128_000_000-byte limit
    /// that `TrafficSketch::new` enforces.
    #[sim_test]
    async fn test_traffic_sketch_mem_estimate() {
        let window_size = Duration::from_secs(30);
        let update_interval = Duration::from_secs(5);
        let mem_estimate = CountMinSketch32::<IpAddr>::estimate_memory(
            DEFAULT_SKETCH_CAPACITY,
            DEFAULT_SKETCH_PROBABILITY,
            DEFAULT_SKETCH_TOLERANCE,
        )
        .unwrap()
            * (window_size.as_secs() / update_interval.as_secs()) as usize;
        assert!(
            mem_estimate < 128_000_000,
            "Memory estimate {mem_estimate} for traffic sketch exceeds 128MB."
        );
    }
}