1use std::{collections::HashMap, net::IpAddr, sync::Arc};
6
7use count_min_sketch::CountMinSketch32;
8use mysten_metrics::spawn_monitored_task;
9use parking_lot::RwLock;
10use std::cmp::Reverse;
11use std::collections::{BinaryHeap, VecDeque};
12use std::fmt::Debug;
13use std::hash::Hash;
14use std::time::Duration;
15use std::time::{Instant, SystemTime};
16use sui_types::traffic_control::{FreqThresholdConfig, PolicyConfig, PolicyType, Weight};
17use tracing::{info, trace};
18
/// Maximum number of top-rate entries each `FreqThresholdPolicy` sketch keeps
/// per client type (one heap for direct clients, one for proxied clients).
const HIGHEST_RATES_CAPACITY: usize = 20;
20
/// Which address of a [`TrafficTally`] a sketch entry counts, so direct and
/// fullnode-proxied traffic from the same IP are tallied separately.
#[derive(Hash, Eq, PartialEq, Debug)]
enum ClientType {
    /// Counts the tally's `direct` address.
    Direct,
    /// Counts the tally's `through_fullnode` address.
    ThroughFullnode,
}
27
/// Key under which a request is counted in the traffic sketches: one entry per
/// (address, client type), combined with a per-policy salt.
#[derive(Hash, Eq, PartialEq, Debug)]
struct SketchKey {
    /// Per-policy value (set from `rand::random()` in `FreqThresholdPolicy::new`).
    /// NOTE(review): presumably makes sketch hash collisions unpredictable to
    /// clients — confirm.
    salt: u64,
    /// Address being counted.
    ip_addr: IpAddr,
    /// Whether `ip_addr` came from the direct or the proxied field of a tally.
    client_type: ClientType,
}
34
/// Bounded per-client-type records of the highest observed request rates.
///
/// Each heap stores `Reverse((rate, ip))`, which turns the max-heap into a
/// min-heap so the smallest tracked rate can be evicted once `capacity` is hit.
/// Entries are only replaced when the same address is seen again or a larger
/// rate evicts them, so they can be stale.
struct HighestRates {
    /// Rates recorded for `ClientType::Direct` keys.
    direct: BinaryHeap<Reverse<(u64, IpAddr)>>,
    /// Rates recorded for `ClientType::ThroughFullnode` keys.
    proxied: BinaryHeap<Reverse<(u64, IpAddr)>>,
    /// Maximum number of entries kept in each heap.
    capacity: usize,
}
40
/// Sliding-window request counter: a ring of count-min sketches, one per update
/// interval, whose estimates are summed to approximate per-key request rates
/// over `window_size`.
pub struct TrafficSketch {
    /// One sketch per update interval, used as a fixed-size ring indexed by
    /// `current_sketch_index` (only populated in `new`).
    sketches: VecDeque<CountMinSketch32<SketchKey>>,
    /// Length of the counting window, rounded down in `new` to a whole multiple
    /// of `update_interval`. Used as the divisor when computing rates.
    window_size: Duration,
    /// How often the ring advances and the next sketch is cleared.
    update_interval: Duration,
    /// Reference instant from which elapsed update intervals are measured
    /// before rotating.
    last_reset_time: Instant,
    /// Index into `sketches` of the sketch currently receiving increments.
    current_sketch_index: usize,
    /// Highest rates observed so far, per client type.
    highest_rates: HighestRates,
}
66
67impl TrafficSketch {
68 pub fn new(
69 window_size: Duration,
70 update_interval: Duration,
71 sketch_capacity: usize,
72 sketch_probability: f64,
73 sketch_tolerance: f64,
74 highest_rates_capacity: usize,
75 ) -> Self {
76 let num_sketches = window_size.as_secs() / update_interval.as_secs();
78 let new_window_size = Duration::from_secs(num_sketches * update_interval.as_secs());
79 if new_window_size != window_size {
80 info!(
81 "Rounding traffic sketch window size down to {} seconds to make it an integer multiple of update interval {} seconds.",
82 new_window_size.as_secs(),
83 update_interval.as_secs(),
84 );
85 }
86 let window_size = new_window_size;
87
88 assert!(
89 window_size < Duration::from_secs(600),
90 "window_size too large. Max 600 seconds"
91 );
92 assert!(
93 update_interval < window_size,
94 "Update interval may not be larger than window size"
95 );
96 assert!(
97 update_interval >= Duration::from_secs(1),
98 "Update interval too short, must be at least 1 second"
99 );
100 assert!(num_sketches <= 10, "Given parameters require too many sketches to be stored. Reduce window size or increase update interval.");
101 let mem_estimate = (num_sketches as usize)
102 * CountMinSketch32::<IpAddr>::estimate_memory(
103 sketch_capacity,
104 sketch_probability,
105 sketch_tolerance,
106 )
107 .expect("Failed to estimate memory for CountMinSketch32");
108 assert!(mem_estimate < 128_000_000, "Memory estimate for traffic sketch exceeds 128MB. Reduce window size or increase update interval.");
109
110 let mut sketches = VecDeque::with_capacity(num_sketches as usize);
111 for _ in 0..num_sketches {
112 sketches.push_back(
113 CountMinSketch32::<SketchKey>::new(
114 sketch_capacity,
115 sketch_probability,
116 sketch_tolerance,
117 )
118 .expect("Failed to create CountMinSketch32"),
119 );
120 }
121 Self {
122 sketches,
123 window_size,
124 update_interval,
125 last_reset_time: Instant::now(),
126 current_sketch_index: 0,
127 highest_rates: HighestRates {
128 direct: BinaryHeap::with_capacity(highest_rates_capacity),
129 proxied: BinaryHeap::with_capacity(highest_rates_capacity),
130 capacity: highest_rates_capacity,
131 },
132 }
133 }
134
135 fn increment_count(&mut self, key: &SketchKey) {
136 let current_time = Instant::now();
138 let mut elapsed = current_time.duration_since(self.last_reset_time);
139 while elapsed >= self.update_interval {
140 self.rotate_window();
141 elapsed -= self.update_interval;
142 }
143 self.sketches[self.current_sketch_index].increment(key);
145 }
146
147 fn get_request_rate(&mut self, key: &SketchKey) -> f64 {
148 let count: u32 = self
149 .sketches
150 .iter()
151 .map(|sketch| sketch.estimate(key))
152 .sum();
153 let rate = count as f64 / self.window_size.as_secs() as f64;
154 self.update_highest_rates(key, rate);
155 rate
156 }
157
158 fn update_highest_rates(&mut self, key: &SketchKey, rate: f64) {
159 match key.client_type {
160 ClientType::Direct => {
161 Self::update_highest_rate(
162 &mut self.highest_rates.direct,
163 key.ip_addr,
164 rate,
165 self.highest_rates.capacity,
166 );
167 }
168 ClientType::ThroughFullnode => {
169 Self::update_highest_rate(
170 &mut self.highest_rates.proxied,
171 key.ip_addr,
172 rate,
173 self.highest_rates.capacity,
174 );
175 }
176 }
177 }
178
179 fn update_highest_rate(
180 rate_heap: &mut BinaryHeap<Reverse<(u64, IpAddr)>>,
181 ip_addr: IpAddr,
182 rate: f64,
183 capacity: usize,
184 ) {
185 rate_heap.retain(|&Reverse((_, key))| key != ip_addr);
188
189 let rate = rate as u64;
190 if rate_heap.len() < capacity {
191 rate_heap.push(Reverse((rate, ip_addr)));
192 } else if let Some(&Reverse((smallest_score, _))) = rate_heap.peek() {
193 if rate > smallest_score {
194 rate_heap.pop();
195 rate_heap.push(Reverse((rate, ip_addr)));
196 }
197 }
198 }
199
200 pub fn highest_direct_rate(&self) -> Option<(u64, IpAddr)> {
201 self.highest_rates
202 .direct
203 .iter()
204 .map(|Reverse(v)| v)
205 .max_by(|a, b| a.0.partial_cmp(&b.0).expect("Failed to compare rates"))
206 .copied()
207 }
208
209 pub fn highest_proxied_rate(&self) -> Option<(u64, IpAddr)> {
210 self.highest_rates
211 .proxied
212 .iter()
213 .map(|Reverse(v)| v)
214 .max_by(|a, b| a.0.partial_cmp(&b.0).expect("Failed to compare rates"))
215 .copied()
216 }
217
218 fn rotate_window(&mut self) {
219 self.current_sketch_index = (self.current_sketch_index + 1) % self.sketches.len();
220 self.sketches[self.current_sketch_index].clear();
221 self.last_reset_time = Instant::now();
222 }
223}
224
/// A record of one observed request, fed to a [`Policy`].
#[derive(Clone, Debug)]
pub struct TrafficTally {
    /// Directly connected peer address, if present.
    pub direct: Option<IpAddr>,
    /// Client address reported as proxied through a fullnode, if present.
    pub through_fullnode: Option<IpAddr>,
    /// NOTE(review): presumably the error's weight plus a description of it — confirm.
    pub error_info: Option<(Weight, String)>,
    /// Spam weight of the request. Note: `FreqThresholdPolicy::handle_tally`
    /// counts each tally once and does not read this field.
    pub spam_weight: Weight,
    /// Creation time; `TrafficTally::new` sets it to `SystemTime::now()`.
    pub timestamp: SystemTime,
}
233
234impl TrafficTally {
235 pub fn new(
236 direct: Option<IpAddr>,
237 through_fullnode: Option<IpAddr>,
238 error_info: Option<(Weight, String)>,
239 spam_weight: Weight,
240 ) -> Self {
241 Self {
242 direct,
243 through_fullnode,
244 error_info,
245 spam_weight,
246 timestamp: SystemTime::now(),
247 }
248 }
249}
250
/// Outcome of handling a tally: the addresses the policy wants blocked.
/// `None` means no block is requested for that kind of client; the derived
/// `Default` requests no blocks at all.
#[derive(Clone, Debug, Default)]
pub struct PolicyResponse {
    /// Direct client address to block.
    pub block_client: Option<IpAddr>,
    /// Fullnode-proxied client address to block.
    pub block_proxied_client: Option<IpAddr>,
}
256
/// A traffic-control policy: consumes request tallies and decides whom to block.
pub trait Policy {
    /// Processes one tally and returns the addresses (if any) to block.
    fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse;
    /// Returns the configuration this policy was constructed with.
    fn policy_config(&self) -> &PolicyConfig;
}
263
/// Statically dispatched set of the available policy implementations.
pub enum TrafficControlPolicy {
    /// Blocks clients whose estimated request rate reaches a threshold.
    FreqThreshold(FreqThresholdPolicy),
    /// Never blocks anyone.
    NoOp(NoOpPolicy),
    /// Test policy (per its name): blocks a direct client after a fixed
    /// number of tallies.
    TestNConnIP(TestNConnIPPolicy),
    /// Test policy (per its name): panics if a tally is ever handled.
    TestPanicOnInvocation(TestPanicOnInvocationPolicy),
}
273
274impl Policy for TrafficControlPolicy {
275 fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse {
276 match self {
277 TrafficControlPolicy::NoOp(policy) => policy.handle_tally(tally),
278 TrafficControlPolicy::FreqThreshold(policy) => policy.handle_tally(tally),
279 TrafficControlPolicy::TestNConnIP(policy) => policy.handle_tally(tally),
280 TrafficControlPolicy::TestPanicOnInvocation(policy) => policy.handle_tally(tally),
281 }
282 }
283
284 fn policy_config(&self) -> &PolicyConfig {
285 match self {
286 TrafficControlPolicy::NoOp(policy) => policy.policy_config(),
287 TrafficControlPolicy::FreqThreshold(policy) => policy.policy_config(),
288 TrafficControlPolicy::TestNConnIP(policy) => policy.policy_config(),
289 TrafficControlPolicy::TestPanicOnInvocation(policy) => policy.policy_config(),
290 }
291 }
292}
293
294impl TrafficControlPolicy {
295 pub async fn from_spam_config(policy_config: PolicyConfig) -> Self {
296 Self::from_config(policy_config.clone().spam_policy_type, policy_config).await
297 }
298 pub async fn from_error_config(policy_config: PolicyConfig) -> Self {
299 Self::from_config(policy_config.clone().error_policy_type, policy_config).await
300 }
301 pub async fn from_config(policy_type: PolicyType, policy_config: PolicyConfig) -> Self {
302 match policy_type {
303 PolicyType::NoOp => Self::NoOp(NoOpPolicy::new(policy_config)),
304 PolicyType::FreqThreshold(freq_threshold_config) => Self::FreqThreshold(
305 FreqThresholdPolicy::new(policy_config, freq_threshold_config),
306 ),
307 PolicyType::TestNConnIP(n) => {
308 Self::TestNConnIP(TestNConnIPPolicy::new(policy_config, n).await)
309 }
310 PolicyType::TestPanicOnInvocation => {
311 Self::TestPanicOnInvocation(TestPanicOnInvocationPolicy::new(policy_config))
312 }
313 }
314 }
315}
316
/// Blocks clients whose sketch-estimated request rate reaches a threshold,
/// tracking direct and fullnode-proxied addresses separately.
pub struct FreqThresholdPolicy {
    /// Configuration this policy was created with.
    pub config: PolicyConfig,
    /// Requests per second at or above which a direct client is blocked.
    pub client_threshold: u64,
    /// Requests per second at or above which a proxied client is blocked.
    pub proxied_client_threshold: u64,
    /// Sliding-window rate estimator shared by both client types.
    sketch: TrafficSketch,
    /// Random per-policy salt included in every `SketchKey`.
    salt: u64,
}
331
332impl FreqThresholdPolicy {
333 pub fn new(
334 config: PolicyConfig,
335 FreqThresholdConfig {
336 client_threshold,
337 proxied_client_threshold,
338 window_size_secs,
339 update_interval_secs,
340 sketch_capacity,
341 sketch_probability,
342 sketch_tolerance,
343 }: FreqThresholdConfig,
344 ) -> Self {
345 let sketch = TrafficSketch::new(
346 Duration::from_secs(window_size_secs),
347 Duration::from_secs(update_interval_secs),
348 sketch_capacity,
349 sketch_probability,
350 sketch_tolerance,
351 HIGHEST_RATES_CAPACITY,
352 );
353 Self {
354 config,
355 sketch,
356 client_threshold,
357 proxied_client_threshold,
358 salt: rand::random(),
359 }
360 }
361
362 pub fn highest_direct_rate(&self) -> Option<(u64, IpAddr)> {
363 self.sketch.highest_direct_rate()
364 }
365
366 pub fn highest_proxied_rate(&self) -> Option<(u64, IpAddr)> {
367 self.sketch.highest_proxied_rate()
368 }
369
370 pub fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse {
371 let block_client = if let Some(source) = tally.direct {
372 let key = SketchKey {
373 salt: self.salt,
374 ip_addr: source,
375 client_type: ClientType::Direct,
376 };
377 self.sketch.increment_count(&key);
378 let req_rate = self.sketch.get_request_rate(&key);
379 trace!(
380 "FreqThresholdPolicy handling tally -- req_rate: {:?}, client_threshold: {:?}, client: {:?}",
381 req_rate,
382 self.client_threshold,
383 source,
384 );
385 if req_rate >= self.client_threshold as f64 {
386 Some(source)
387 } else {
388 None
389 }
390 } else {
391 None
392 };
393 let block_proxied_client = if let Some(source) = tally.through_fullnode {
394 let key = SketchKey {
395 salt: self.salt,
396 ip_addr: source,
397 client_type: ClientType::ThroughFullnode,
398 };
399 self.sketch.increment_count(&key);
400 if self.sketch.get_request_rate(&key) >= self.proxied_client_threshold as f64 {
401 Some(source)
402 } else {
403 None
404 }
405 } else {
406 None
407 };
408 PolicyResponse {
409 block_client,
410 block_proxied_client,
411 }
412 }
413
414 fn policy_config(&self) -> &PolicyConfig {
415 &self.config
416 }
417}
418
/// Policy that never blocks any client.
#[derive(Clone)]
pub struct NoOpPolicy {
    /// Configuration this policy was created with.
    config: PolicyConfig,
}
425
426impl NoOpPolicy {
427 pub fn new(config: PolicyConfig) -> Self {
428 Self { config }
429 }
430
431 fn handle_tally(&mut self, _tally: TrafficTally) -> PolicyResponse {
432 PolicyResponse::default()
433 }
434
435 fn policy_config(&self) -> &PolicyConfig {
436 &self.config
437 }
438}
439
/// Test policy: blocks a direct client once it has been tallied `threshold`
/// times. The counts are reset periodically by a background task started in
/// `new`.
#[derive(Clone)]
pub struct TestNConnIPPolicy {
    /// Tally count at which a direct client is blocked.
    pub threshold: u64,
    /// Configuration this policy was created with.
    pub config: PolicyConfig,
    /// Per-address tally counts. Shared with the clearing task and with any
    /// clones of this policy.
    frequencies: Arc<RwLock<HashMap<IpAddr, u64>>>,
}
446
447impl TestNConnIPPolicy {
448 pub async fn new(config: PolicyConfig, threshold: u64) -> Self {
449 let frequencies = Arc::new(RwLock::new(HashMap::new()));
450 let frequencies_clone = frequencies.clone();
451 spawn_monitored_task!(run_clear_frequencies(
452 frequencies_clone,
453 config.connection_blocklist_ttl_sec * 2,
454 ));
455 Self {
456 config,
457 frequencies,
458 threshold,
459 }
460 }
461
462 fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse {
463 let client = if let Some(client) = tally.direct {
464 client
465 } else {
466 return PolicyResponse::default();
467 };
468
469 let mut frequencies = self.frequencies.write();
471 let count = frequencies.entry(client).or_insert(0);
472 *count += 1;
473 PolicyResponse {
474 block_client: if *count >= self.threshold {
475 Some(client)
476 } else {
477 None
478 },
479 block_proxied_client: None,
480 }
481 }
482
483 fn policy_config(&self) -> &PolicyConfig {
484 &self.config
485 }
486}
487
488async fn run_clear_frequencies(frequencies: Arc<RwLock<HashMap<IpAddr, u64>>>, window_secs: u64) {
489 loop {
490 tokio::time::sleep(tokio::time::Duration::from_secs(window_secs)).await;
491 frequencies.write().clear();
492 }
493}
494
/// Test policy whose `handle_tally` always panics. Useful for asserting that
/// no tally is routed to it.
#[derive(Clone)]
pub struct TestPanicOnInvocationPolicy {
    /// Configuration this policy was created with.
    config: PolicyConfig,
}
499
500impl TestPanicOnInvocationPolicy {
501 pub fn new(config: PolicyConfig) -> Self {
502 Self { config }
503 }
504
505 fn handle_tally(&mut self, _: TrafficTally) -> PolicyResponse {
506 panic!("Tally for this policy should never be invoked")
507 }
508
509 fn policy_config(&self) -> &PolicyConfig {
510 &self.config
511 }
512}
513
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};
    use sui_macros::sim_test;
    use sui_types::traffic_control::{
        DEFAULT_SKETCH_CAPACITY, DEFAULT_SKETCH_PROBABILITY, DEFAULT_SKETCH_TOLERANCE,
    };

    /// Exercises FreqThresholdPolicy under simulated time.
    ///
    /// Setup: 5s window, 1s interval (5 sketches). Rate = count / 5, truncated
    /// when recorded as a highest rate. A direct client is blocked at rate >= 5,
    /// a proxied client at rate >= 2.
    ///
    /// All three clients share direct IP 8.7.6.5 (called D below). They differ
    /// in proxied IP: A = alice, B = bob, C = charlie.
    ///
    /// The counts in the comments assume no count-min collisions and that
    /// simulated time only advances at the explicit sleeps.
    #[sim_test]
    async fn test_freq_threshold_policy() {
        let mut policy = FreqThresholdPolicy::new(
            PolicyConfig::default(),
            FreqThresholdConfig {
                client_threshold: 5,
                proxied_client_threshold: 2,
                window_size_secs: 5,
                update_interval_secs: 1,
                ..Default::default()
            },
        );
        let alice = TrafficTally {
            direct: Some(IpAddr::V4(Ipv4Addr::new(8, 7, 6, 5))),
            through_fullnode: Some(IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4))),
            error_info: None,
            spam_weight: Weight::one(),
            timestamp: SystemTime::now(),
        };
        let bob = TrafficTally {
            direct: Some(IpAddr::V4(Ipv4Addr::new(8, 7, 6, 5))),
            through_fullnode: Some(IpAddr::V4(Ipv4Addr::new(4, 3, 2, 1))),
            error_info: None,
            spam_weight: Weight::one(),
            timestamp: SystemTime::now(),
        };
        let charlie = TrafficTally {
            direct: Some(IpAddr::V4(Ipv4Addr::new(8, 7, 6, 5))),
            through_fullnode: Some(IpAddr::V4(Ipv4Addr::new(5, 6, 7, 8))),
            error_info: None,
            spam_weight: Weight::one(),
            timestamp: SystemTime::now(),
        };

        // t=0: D and A each reach count 2 (rate 0.4): nothing blocked.
        for _ in 0..2 {
            let response = policy.handle_tally(alice.clone());
            assert_eq!(response.block_proxied_client, None);
            assert_eq!(response.block_client, None);
        }

        let (direct_rate, direct_ip_addr) = policy.highest_direct_rate().unwrap();
        let (proxied_rate, proxied_ip_addr) = policy.highest_proxied_rate().unwrap();
        assert_eq!(direct_ip_addr, alice.direct.unwrap());
        assert!(direct_rate < 1);
        assert_eq!(proxied_ip_addr, alice.through_fullnode.unwrap());
        assert!(proxied_rate < 1);

        // t=0: B climbs to 9 (1.8 < 2), so no block yet.
        for _ in 0..9 {
            let response = policy.handle_tally(bob.clone());
            assert_eq!(response.block_client, None);
            assert_eq!(response.block_proxied_client, None);
        }
        // The 10th bob tally gives B = 10 (2.0), so B is blocked.
        // D = 12 (2.4 < 5).
        let response = policy.handle_tally(bob.clone());
        assert_eq!(response.block_client, None);
        assert_eq!(response.block_proxied_client, bob.through_fullnode);

        // Highest rates: D truncates to 2; B is 2 versus A's 0.
        let (direct_rate, direct_ip_addr) = policy.highest_direct_rate().unwrap();
        let (proxied_rate, proxied_ip_addr) = policy.highest_proxied_rate().unwrap();
        assert_eq!(direct_ip_addr, bob.direct.unwrap());
        assert_eq!(direct_rate, 2);
        assert_eq!(proxied_ip_addr, bob.through_fullnode.unwrap());
        assert_eq!(proxied_rate, 2);

        // t=2: two rotations clear the next two sketches; the t=0 counts stay
        // in the window.
        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
        // D reaches 14 (2.8) and A reaches 4 (0.8): nothing blocked.
        for _ in 0..2 {
            let response = policy.handle_tally(alice.clone());
            assert_eq!(response.block_client, None);
            assert_eq!(response.block_proxied_client, None);
        }
        // B goes to 11 then 12 (2.2, then 2.4), so it stays blocked.
        // D reaches 16 (3.2 < 5).
        let _ = policy.handle_tally(bob.clone());
        let response = policy.handle_tally(bob.clone());
        assert_eq!(response.block_client, None);
        assert_eq!(response.block_proxied_client, bob.through_fullnode);

        let (direct_rate, direct_ip_addr) = policy.highest_direct_rate().unwrap();
        let (proxied_rate, proxied_ip_addr) = policy.highest_proxied_rate().unwrap();
        assert_eq!(direct_ip_addr, alice.direct.unwrap());
        assert_eq!(direct_rate, 3);
        assert_eq!(proxied_ip_addr, bob.through_fullnode.unwrap());
        assert_eq!(proxied_rate, 2);

        // t=3: D goes 17..21 (at most 4.2 < 5) and A goes 5..9 (at most 1.8 < 2).
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        for i in 0..5 {
            let response = policy.handle_tally(alice.clone());
            assert_eq!(response.block_client, None, "Blocked at i = {}", i);
            assert_eq!(response.block_proxied_client, None);
        }

        // t=4: D = 22 (4.4) is still allowed; A = 10 (2.0) is blocked.
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        let response = policy.handle_tally(alice.clone());
        assert_eq!(response.block_client, None);
        assert_eq!(response.block_proxied_client, alice.through_fullnode);

        // A (2.0) and B's entry from t=2 (2.4) both truncate to 2 here. Which
        // address is reported depends on heap iteration order.
        // NOTE(review): this tie makes the B assertion order-sensitive.
        let (direct_rate, direct_ip_addr) = policy.highest_direct_rate().unwrap();
        let (proxied_rate, proxied_ip_addr) = policy.highest_proxied_rate().unwrap();
        assert_eq!(direct_ip_addr, alice.direct.unwrap());
        assert_eq!(direct_rate, 4);
        assert_eq!(proxied_ip_addr, bob.through_fullnode.unwrap());
        assert_eq!(proxied_rate, 2);

        // t=4: D goes to 23, 24 (4.6, 4.8) and C goes 1..2: nothing blocked.
        for i in 0..2 {
            let response = policy.handle_tally(charlie.clone());
            assert_eq!(response.block_client, None, "Blocked at i = {}", i);
            assert_eq!(response.block_proxied_client, None);
        }
        // D = 25 (5.0 >= 5), so the direct client is blocked. C = 3 (0.6).
        let response = policy.handle_tally(charlie.clone());
        assert_eq!(response.block_proxied_client, None);
        assert_eq!(response.block_client, charlie.direct);

        // t=5: the rotation clears the t=0 sketch (D loses 12, leaving 13).
        // D goes 14..16 (at most 3.2) and C goes 4..6 (at most 1.2).
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        for i in 0..3 {
            let response = policy.handle_tally(charlie.clone());
            assert_eq!(response.block_client, None, "Blocked at i = {}", i);
            assert_eq!(response.block_proxied_client, None);
        }

        // t=10: a full window has passed, so every sketch is cleared.
        // The alice and bob tallies refresh D, A and B to rate 0, but C keeps
        // its earlier entry (6/5, truncated to 1) because heap entries are
        // not expired.
        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
        let _ = policy.handle_tally(alice.clone());
        let _ = policy.handle_tally(bob.clone());
        let (direct_rate, direct_ip_addr) = policy.highest_direct_rate().unwrap();
        let (proxied_rate, proxied_ip_addr) = policy.highest_proxied_rate().unwrap();
        assert_eq!(direct_ip_addr, alice.direct.unwrap());
        assert_eq!(direct_rate, 0);
        assert_eq!(proxied_ip_addr, charlie.through_fullnode.unwrap());
        assert_eq!(proxied_rate, 1);
    }

    /// Mirrors the memory check in `TrafficSketch::new` with the default sketch
    /// parameters, a 30s window and a 5s interval (6 sketches).
    #[sim_test]
    async fn test_traffic_sketch_mem_estimate() {
        let window_size = Duration::from_secs(30);
        let update_interval = Duration::from_secs(5);
        let mem_estimate = CountMinSketch32::<IpAddr>::estimate_memory(
            DEFAULT_SKETCH_CAPACITY,
            DEFAULT_SKETCH_PROBABILITY,
            DEFAULT_SKETCH_TOLERANCE,
        )
        .unwrap()
            * (window_size.as_secs() / update_interval.as_secs()) as usize;
        assert!(
            mem_estimate < 128_000_000,
            "Memory estimate {mem_estimate} for traffic sketch exceeds 128MB."
        );
    }
}