sui_core/traffic_controller/
policies.rs

1// Copyright (c) 2021, Facebook, Inc. and its affiliates
2// Copyright (c) Mysten Labs, Inc.
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{collections::HashMap, net::IpAddr, sync::Arc};
6
7use count_min_sketch::CountMinSketch32;
8use mysten_metrics::spawn_monitored_task;
9use parking_lot::RwLock;
10use std::cmp::Reverse;
11use std::collections::{BinaryHeap, VecDeque};
12use std::fmt::Debug;
13use std::hash::Hash;
14use std::time::Duration;
15use std::time::{Instant, SystemTime};
16use sui_types::traffic_control::{FreqThresholdConfig, PolicyConfig, PolicyType, Weight};
17use tracing::{info, trace};
18
/// Number of highest-rate entries retained per client type. Passed as
/// `highest_rates_capacity` to `TrafficSketch::new` by `FreqThresholdPolicy::new`.
const HIGHEST_RATES_CAPACITY: usize = 20;
20
/// Distinguishes the two kinds of request origin that are counted separately:
/// the address connected directly, and the address reported for a client
/// that came through a fullnode.
#[derive(Debug, PartialEq, Eq, Hash)]
enum ClientType {
    Direct,
    ThroughFullnode,
}
27
28#[derive(Hash, Eq, PartialEq, Debug)]
29struct SketchKey {
30    salt: u64,
31    ip_addr: IpAddr,
32    client_type: ClientType,
33}
34
/// Bounded record of the highest request rates observed, kept separately for
/// direct and proxied clients.
///
/// Each heap is a min-heap (via `Reverse`) of `(rate, ip)` pairs, so the
/// smallest retained rate is the first candidate for eviction.
struct HighestRates {
    /// Maximum number of entries retained in each heap.
    capacity: usize,
    /// Highest rates seen for directly connected clients.
    direct: BinaryHeap<Reverse<(u64, IpAddr)>>,
    /// Highest rates seen for clients arriving through a fullnode.
    proxied: BinaryHeap<Reverse<(u64, IpAddr)>>,
}
40
41pub struct TrafficSketch {
42    /// Circular buffer Count Min Sketches representing a sliding window
43    /// of traffic data. Note that the 32 in CountMinSketch32 represents
44    /// the number of bits used to represent the count in the sketch. Since
45    /// we only count on a sketch for a window of `update_interval`, we only
46    /// need enough precision to represent the max expected unique IP addresses
47    /// we may see in that window. For a 10 second period, we might conservatively
48    /// expect 100,000, which can be represented in 17 bits, but not 16. We can
49    /// potentially lower the memory consumption by using CountMinSketch16, which
50    /// will reliably support up to ~65,000 unique IP addresses in the window.
51    sketches: VecDeque<CountMinSketch32<SketchKey>>,
52    window_size: Duration,
53    update_interval: Duration,
54    last_reset_time: Instant,
55    current_sketch_index: usize,
56    /// Used for metrics collection and logging purposes,
57    /// as CountMinSketch does not provide this directly.
58    /// Note that this is an imperfect metric, since we preserve
59    /// the highest N rates (by unique IP) that we have seen,
60    /// but update rates (down or up) as they change so that
61    /// the metric is not monotonic and reflects recent traffic.
62    /// However, this should only lead to inaccuracy edge cases
63    /// with very low traffic.
64    highest_rates: HighestRates,
65}
66
67impl TrafficSketch {
68    pub fn new(
69        window_size: Duration,
70        update_interval: Duration,
71        sketch_capacity: usize,
72        sketch_probability: f64,
73        sketch_tolerance: f64,
74        highest_rates_capacity: usize,
75    ) -> Self {
76        // intentionally round down via integer division. We can't have a partial sketch
77        let num_sketches = window_size.as_secs() / update_interval.as_secs();
78        let new_window_size = Duration::from_secs(num_sketches * update_interval.as_secs());
79        if new_window_size != window_size {
80            info!(
81                "Rounding traffic sketch window size down to {} seconds to make it an integer multiple of update interval {} seconds.",
82                new_window_size.as_secs(),
83                update_interval.as_secs(),
84            );
85        }
86        let window_size = new_window_size;
87
88        assert!(
89            window_size < Duration::from_secs(600),
90            "window_size too large. Max 600 seconds"
91        );
92        assert!(
93            update_interval < window_size,
94            "Update interval may not be larger than window size"
95        );
96        assert!(
97            update_interval >= Duration::from_secs(1),
98            "Update interval too short, must be at least 1 second"
99        );
100        assert!(
101            num_sketches <= 10,
102            "Given parameters require too many sketches to be stored. Reduce window size or increase update interval."
103        );
104        let mem_estimate = (num_sketches as usize)
105            * CountMinSketch32::<IpAddr>::estimate_memory(
106                sketch_capacity,
107                sketch_probability,
108                sketch_tolerance,
109            )
110            .expect("Failed to estimate memory for CountMinSketch32");
111        assert!(
112            mem_estimate < 128_000_000,
113            "Memory estimate for traffic sketch exceeds 128MB. Reduce window size or increase update interval."
114        );
115
116        let mut sketches = VecDeque::with_capacity(num_sketches as usize);
117        for _ in 0..num_sketches {
118            sketches.push_back(
119                CountMinSketch32::<SketchKey>::new(
120                    sketch_capacity,
121                    sketch_probability,
122                    sketch_tolerance,
123                )
124                .expect("Failed to create CountMinSketch32"),
125            );
126        }
127        Self {
128            sketches,
129            window_size,
130            update_interval,
131            last_reset_time: Instant::now(),
132            current_sketch_index: 0,
133            highest_rates: HighestRates {
134                direct: BinaryHeap::with_capacity(highest_rates_capacity),
135                proxied: BinaryHeap::with_capacity(highest_rates_capacity),
136                capacity: highest_rates_capacity,
137            },
138        }
139    }
140
141    fn increment_count(&mut self, key: &SketchKey) {
142        // reset all expired intervals
143        let current_time = Instant::now();
144        let mut elapsed = current_time.duration_since(self.last_reset_time);
145        while elapsed >= self.update_interval {
146            self.rotate_window();
147            elapsed -= self.update_interval;
148        }
149        // Increment in the current active sketch
150        self.sketches[self.current_sketch_index].increment(key);
151    }
152
153    fn get_request_rate(&mut self, key: &SketchKey) -> f64 {
154        let count: u32 = self
155            .sketches
156            .iter()
157            .map(|sketch| sketch.estimate(key))
158            .sum();
159        let rate = count as f64 / self.window_size.as_secs() as f64;
160        self.update_highest_rates(key, rate);
161        rate
162    }
163
164    fn update_highest_rates(&mut self, key: &SketchKey, rate: f64) {
165        match key.client_type {
166            ClientType::Direct => {
167                Self::update_highest_rate(
168                    &mut self.highest_rates.direct,
169                    key.ip_addr,
170                    rate,
171                    self.highest_rates.capacity,
172                );
173            }
174            ClientType::ThroughFullnode => {
175                Self::update_highest_rate(
176                    &mut self.highest_rates.proxied,
177                    key.ip_addr,
178                    rate,
179                    self.highest_rates.capacity,
180                );
181            }
182        }
183    }
184
185    fn update_highest_rate(
186        rate_heap: &mut BinaryHeap<Reverse<(u64, IpAddr)>>,
187        ip_addr: IpAddr,
188        rate: f64,
189        capacity: usize,
190    ) {
191        // Remove previous instance of this IPAddr so that we
192        // can update with new rate
193        rate_heap.retain(|&Reverse((_, key))| key != ip_addr);
194
195        let rate = rate as u64;
196        if rate_heap.len() < capacity {
197            rate_heap.push(Reverse((rate, ip_addr)));
198        } else if let Some(&Reverse((smallest_score, _))) = rate_heap.peek()
199            && rate > smallest_score
200        {
201            rate_heap.pop();
202            rate_heap.push(Reverse((rate, ip_addr)));
203        }
204    }
205
206    pub fn highest_direct_rate(&self) -> Option<(u64, IpAddr)> {
207        self.highest_rates
208            .direct
209            .iter()
210            .map(|Reverse(v)| v)
211            .max_by(|a, b| a.0.partial_cmp(&b.0).expect("Failed to compare rates"))
212            .copied()
213    }
214
215    pub fn highest_proxied_rate(&self) -> Option<(u64, IpAddr)> {
216        self.highest_rates
217            .proxied
218            .iter()
219            .map(|Reverse(v)| v)
220            .max_by(|a, b| a.0.partial_cmp(&b.0).expect("Failed to compare rates"))
221            .copied()
222    }
223
224    fn rotate_window(&mut self) {
225        self.current_sketch_index = (self.current_sketch_index + 1) % self.sketches.len();
226        self.sketches[self.current_sketch_index].clear();
227        self.last_reset_time = Instant::now();
228    }
229}
230
231#[derive(Clone, Debug)]
232pub struct TrafficTally {
233    pub direct: Option<IpAddr>,
234    pub through_fullnode: Option<IpAddr>,
235    pub error_info: Option<(Weight, String)>,
236    pub spam_weight: Weight,
237    pub timestamp: SystemTime,
238    pub method: Option<String>,
239}
240
241impl TrafficTally {
242    pub fn new(
243        direct: Option<IpAddr>,
244        through_fullnode: Option<IpAddr>,
245        error_info: Option<(Weight, String)>,
246        spam_weight: Weight,
247    ) -> Self {
248        Self {
249            direct,
250            through_fullnode,
251            error_info,
252            spam_weight,
253            timestamp: SystemTime::now(),
254            method: None,
255        }
256    }
257
258    pub fn with_method(mut self, method: String) -> Self {
259        self.method = Some(method);
260        self
261    }
262}
263
/// Outcome of handling one tally: the addresses, if any, that the policy
/// wants blocked. The default value blocks nothing.
#[derive(Debug, Default, Clone)]
pub struct PolicyResponse {
    /// Direct client address to block, if any.
    pub block_client: Option<IpAddr>,
    /// Proxied client address to block, if any.
    pub block_proxied_client: Option<IpAddr>,
}
269
270pub trait Policy {
271    // returns, e.g. (true, false) if connection_ip should be added to blocklist
272    // and proxy_ip should not
273    fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse;
274    fn policy_config(&self) -> &PolicyConfig;
275}
276
277// Nonserializable representation, also note that inner types are
278// not object safe, so we can't use a trait object instead
279pub enum TrafficControlPolicy {
280    FreqThreshold(FreqThresholdPolicy),
281    NoOp(NoOpPolicy),
282    // Test policies below this point
283    TestNConnIP(TestNConnIPPolicy),
284    TestPanicOnInvocation(TestPanicOnInvocationPolicy),
285}
286
287impl Policy for TrafficControlPolicy {
288    fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse {
289        match self {
290            TrafficControlPolicy::NoOp(policy) => policy.handle_tally(tally),
291            TrafficControlPolicy::FreqThreshold(policy) => policy.handle_tally(tally),
292            TrafficControlPolicy::TestNConnIP(policy) => policy.handle_tally(tally),
293            TrafficControlPolicy::TestPanicOnInvocation(policy) => policy.handle_tally(tally),
294        }
295    }
296
297    fn policy_config(&self) -> &PolicyConfig {
298        match self {
299            TrafficControlPolicy::NoOp(policy) => policy.policy_config(),
300            TrafficControlPolicy::FreqThreshold(policy) => policy.policy_config(),
301            TrafficControlPolicy::TestNConnIP(policy) => policy.policy_config(),
302            TrafficControlPolicy::TestPanicOnInvocation(policy) => policy.policy_config(),
303        }
304    }
305}
306
307impl TrafficControlPolicy {
308    pub async fn from_spam_config(policy_config: PolicyConfig) -> Self {
309        Self::from_config(policy_config.clone().spam_policy_type, policy_config).await
310    }
311    pub async fn from_error_config(policy_config: PolicyConfig) -> Self {
312        Self::from_config(policy_config.clone().error_policy_type, policy_config).await
313    }
314    pub async fn from_config(policy_type: PolicyType, policy_config: PolicyConfig) -> Self {
315        match policy_type {
316            PolicyType::NoOp => Self::NoOp(NoOpPolicy::new(policy_config)),
317            PolicyType::FreqThreshold(freq_threshold_config) => Self::FreqThreshold(
318                FreqThresholdPolicy::new(policy_config, freq_threshold_config),
319            ),
320            PolicyType::TestNConnIP(n) => {
321                Self::TestNConnIP(TestNConnIPPolicy::new(policy_config, n).await)
322            }
323            PolicyType::TestPanicOnInvocation => {
324                Self::TestPanicOnInvocation(TestPanicOnInvocationPolicy::new(policy_config))
325            }
326        }
327    }
328}
329
330////////////// *** Policy definitions *** //////////////
331
332pub struct FreqThresholdPolicy {
333    pub config: PolicyConfig,
334    pub client_threshold: u64,
335    pub proxied_client_threshold: u64,
336    sketch: TrafficSketch,
337    /// Unique salt to be added to all keys in the sketch. This
338    /// ensures that false positives are not correlated across
339    /// all nodes at the same time. For Sui validators for example,
340    /// this means that false positives should not prevent the network
341    /// from achieving quorum.
342    salt: u64,
343}
344
345impl FreqThresholdPolicy {
346    pub fn new(
347        config: PolicyConfig,
348        FreqThresholdConfig {
349            client_threshold,
350            proxied_client_threshold,
351            window_size_secs,
352            update_interval_secs,
353            sketch_capacity,
354            sketch_probability,
355            sketch_tolerance,
356        }: FreqThresholdConfig,
357    ) -> Self {
358        let sketch = TrafficSketch::new(
359            Duration::from_secs(window_size_secs),
360            Duration::from_secs(update_interval_secs),
361            sketch_capacity,
362            sketch_probability,
363            sketch_tolerance,
364            HIGHEST_RATES_CAPACITY,
365        );
366        Self {
367            config,
368            sketch,
369            client_threshold,
370            proxied_client_threshold,
371            salt: rand::random(),
372        }
373    }
374
375    pub fn highest_direct_rate(&self) -> Option<(u64, IpAddr)> {
376        self.sketch.highest_direct_rate()
377    }
378
379    pub fn highest_proxied_rate(&self) -> Option<(u64, IpAddr)> {
380        self.sketch.highest_proxied_rate()
381    }
382
383    pub fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse {
384        let block_client = if let Some(source) = tally.direct {
385            let key = SketchKey {
386                salt: self.salt,
387                ip_addr: source,
388                client_type: ClientType::Direct,
389            };
390            self.sketch.increment_count(&key);
391            let req_rate = self.sketch.get_request_rate(&key);
392            trace!(
393                "FreqThresholdPolicy handling tally -- req_rate: {:?}, client_threshold: {:?}, client: {:?}",
394                req_rate, self.client_threshold, source,
395            );
396            if req_rate >= self.client_threshold as f64 {
397                Some(source)
398            } else {
399                None
400            }
401        } else {
402            None
403        };
404        let block_proxied_client = if let Some(source) = tally.through_fullnode {
405            let key = SketchKey {
406                salt: self.salt,
407                ip_addr: source,
408                client_type: ClientType::ThroughFullnode,
409            };
410            self.sketch.increment_count(&key);
411            if self.sketch.get_request_rate(&key) >= self.proxied_client_threshold as f64 {
412                Some(source)
413            } else {
414                None
415            }
416        } else {
417            None
418        };
419        PolicyResponse {
420            block_client,
421            block_proxied_client,
422        }
423    }
424
425    fn policy_config(&self) -> &PolicyConfig {
426        &self.config
427    }
428}
429
430////////////// *** Test policies below this point *** //////////////
431
432#[derive(Clone)]
433pub struct NoOpPolicy {
434    config: PolicyConfig,
435}
436
437impl NoOpPolicy {
438    pub fn new(config: PolicyConfig) -> Self {
439        Self { config }
440    }
441
442    fn handle_tally(&mut self, _tally: TrafficTally) -> PolicyResponse {
443        PolicyResponse::default()
444    }
445
446    fn policy_config(&self) -> &PolicyConfig {
447        &self.config
448    }
449}
450
451#[derive(Clone)]
452pub struct TestNConnIPPolicy {
453    pub threshold: u64,
454    pub config: PolicyConfig,
455    frequencies: Arc<RwLock<HashMap<IpAddr, u64>>>,
456}
457
458impl TestNConnIPPolicy {
459    pub async fn new(config: PolicyConfig, threshold: u64) -> Self {
460        let frequencies = Arc::new(RwLock::new(HashMap::new()));
461        let frequencies_clone = frequencies.clone();
462        spawn_monitored_task!(run_clear_frequencies(
463            frequencies_clone,
464            config.connection_blocklist_ttl_sec * 2,
465        ));
466        Self {
467            config,
468            frequencies,
469            threshold,
470        }
471    }
472
473    fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse {
474        let client = if let Some(client) = tally.direct {
475            client
476        } else {
477            return PolicyResponse::default();
478        };
479
480        // increment the count for the IP
481        let mut frequencies = self.frequencies.write();
482        let count = frequencies.entry(client).or_insert(0);
483        *count += 1;
484        PolicyResponse {
485            block_client: if *count >= self.threshold {
486                Some(client)
487            } else {
488                None
489            },
490            block_proxied_client: None,
491        }
492    }
493
494    fn policy_config(&self) -> &PolicyConfig {
495        &self.config
496    }
497}
498
499async fn run_clear_frequencies(frequencies: Arc<RwLock<HashMap<IpAddr, u64>>>, window_secs: u64) {
500    loop {
501        tokio::time::sleep(tokio::time::Duration::from_secs(window_secs)).await;
502        frequencies.write().clear();
503    }
504}
505
506#[derive(Clone)]
507pub struct TestPanicOnInvocationPolicy {
508    config: PolicyConfig,
509}
510
511impl TestPanicOnInvocationPolicy {
512    pub fn new(config: PolicyConfig) -> Self {
513        Self { config }
514    }
515
516    fn handle_tally(&mut self, _: TrafficTally) -> PolicyResponse {
517        panic!("Tally for this policy should never be invoked")
518    }
519
520    fn policy_config(&self) -> &PolicyConfig {
521        &self.config
522    }
523}
524
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};
    use sui_macros::sim_test;
    use sui_types::traffic_control::{
        DEFAULT_SKETCH_CAPACITY, DEFAULT_SKETCH_PROBABILITY, DEFAULT_SKETCH_TOLERANCE,
    };

    /// Suspends the test for `secs` seconds.
    async fn pause_secs(secs: u64) {
        tokio::time::sleep(tokio::time::Duration::from_secs(secs)).await;
    }

    #[sim_test]
    async fn test_freq_threshold_policy() {
        // Policy that blocks once the rate averaged over a 5 second window
        // reaches 2 requests per second for a proxied client, or 5 requests
        // per second for a direct connection.
        let freq_config = FreqThresholdConfig {
            client_threshold: 5,
            proxied_client_threshold: 2,
            window_size_secs: 5,
            update_interval_secs: 1,
            ..Default::default()
        };
        let mut policy = FreqThresholdPolicy::new(PolicyConfig::default(), freq_config);

        // alice, bob and charlie connect from different IPs through the same
        // fullnode, so they share the direct (connection) IP seen by the
        // validator but have distinct proxied IPs.
        let fullnode_ip = IpAddr::V4(Ipv4Addr::new(8, 7, 6, 5));
        let tally_from = |a: u8, b: u8, c: u8, d: u8| {
            TrafficTally::new(
                Some(fullnode_ip),
                Some(IpAddr::V4(Ipv4Addr::new(a, b, c, d))),
                None,
                Weight::one(),
            )
        };
        let alice = tally_from(1, 2, 3, 4);
        let bob = tally_from(4, 3, 2, 1);
        let charlie = tally_from(5, 6, 7, 8);

        // Two tallies from alice stay below both thresholds.
        for _ in 0..2 {
            let response = policy.handle_tally(alice.clone());
            assert_eq!(response.block_proxied_client, None);
            assert_eq!(response.block_client, None);
        }

        let (rate, ip_addr) = policy.highest_direct_rate().unwrap();
        assert_eq!(ip_addr, alice.direct.unwrap());
        assert!(rate < 1);
        let (rate, ip_addr) = policy.highest_proxied_rate().unwrap();
        assert_eq!(ip_addr, alice.through_fullnode.unwrap());
        assert!(rate < 1);

        // bob bursts 10 requests at once: the first 9 pass, the 10th gets
        // his proxied IP blocked. The shared direct IP is still not blocked.
        for _ in 0..9 {
            let response = policy.handle_tally(bob.clone());
            assert_eq!(response.block_client, None);
            assert_eq!(response.block_proxied_client, None);
        }
        let response = policy.handle_tally(bob.clone());
        assert_eq!(response.block_client, None);
        assert_eq!(response.block_proxied_client, bob.through_fullnode);

        // highest rates should now show bob
        assert_eq!(
            policy.highest_direct_rate(),
            Some((2, bob.direct.unwrap()))
        );
        assert_eq!(
            policy.highest_proxied_rate(),
            Some((2, bob.through_fullnode.unwrap()))
        );

        // Two more tallies from alice after a 2 second pause: neither her
        // proxied IP nor the shared direct IP reaches its threshold.
        pause_secs(2).await;
        for _ in 0..2 {
            let response = policy.handle_tally(alice.clone());
            assert_eq!(response.block_client, None);
            assert_eq!(response.block_proxied_client, None);
        }
        // bob's earlier burst is still inside the sliding window, so two
        // further requests keep his proxied IP blocked.
        let _ = policy.handle_tally(bob.clone());
        let response = policy.handle_tally(bob.clone());
        assert_eq!(response.block_client, None);
        assert_eq!(response.block_proxied_client, bob.through_fullnode);

        // The direct rate grew because everyone shares the same fullnode IP;
        // bob remains the top proxied client.
        assert_eq!(
            policy.highest_direct_rate(),
            Some((3, alice.direct.unwrap()))
        );
        assert_eq!(
            policy.highest_proxied_rate(),
            Some((2, bob.through_fullnode.unwrap()))
        );

        // close to threshold for alice, but still below
        pause_secs(1).await;
        for i in 0..5 {
            let response = policy.handle_tally(alice.clone());
            assert_eq!(response.block_client, None, "Blocked at i = {}", i);
            assert_eq!(response.block_proxied_client, None);
        }

        // One more tally gets alice's proxied IP blocked; the direct IP is
        // still below its threshold.
        pause_secs(1).await;
        let response = policy.handle_tally(alice.clone());
        assert_eq!(response.block_client, None);
        assert_eq!(response.block_proxied_client, alice.through_fullnode);

        assert_eq!(
            policy.highest_direct_rate(),
            Some((4, alice.direct.unwrap()))
        );
        assert_eq!(
            policy.highest_proxied_rate(),
            Some((2, bob.through_fullnode.unwrap()))
        );

        // spam through charlie to block connection
        for i in 0..2 {
            let response = policy.handle_tally(charlie.clone());
            assert_eq!(response.block_client, None, "Blocked at i = {}", i);
            assert_eq!(response.block_proxied_client, None);
        }
        // Now we block connection IP
        let response = policy.handle_tally(charlie.clone());
        assert_eq!(response.block_proxied_client, None);
        assert_eq!(response.block_client, charlie.direct);

        // Ensure that if we wait another second, we are no longer blocked
        // as the bursty first second has finally rotated out of the sliding
        // window
        pause_secs(1).await;
        for i in 0..3 {
            let response = policy.handle_tally(charlie.clone());
            assert_eq!(response.block_client, None, "Blocked at i = {}", i);
            assert_eq!(response.block_proxied_client, None);
        }

        // check that we revert back to previous highest rates when rates decrease
        pause_secs(5).await;
        // alice and bob rates are now decreased after 5 seconds of break, but charlie's
        // has not since they have not yet sent a new request
        let _ = policy.handle_tally(alice.clone());
        let _ = policy.handle_tally(bob.clone());
        assert_eq!(
            policy.highest_direct_rate(),
            Some((0, alice.direct.unwrap()))
        );
        assert_eq!(
            policy.highest_proxied_rate(),
            Some((1, charlie.through_fullnode.unwrap()))
        );
    }

    #[sim_test]
    async fn test_traffic_sketch_mem_estimate() {
        // Test for getting a rough estimate of memory usage for the traffic sketch
        // given certain parameters. Parameters below are the default.
        // With default parameters, memory estimate is 113 MB.
        let window_size = Duration::from_secs(30);
        let update_interval = Duration::from_secs(5);
        let per_sketch_estimate = CountMinSketch32::<IpAddr>::estimate_memory(
            DEFAULT_SKETCH_CAPACITY,
            DEFAULT_SKETCH_PROBABILITY,
            DEFAULT_SKETCH_TOLERANCE,
        )
        .unwrap();
        let num_sketches = (window_size.as_secs() / update_interval.as_secs()) as usize;
        let mem_estimate = per_sketch_estimate * num_sketches;
        assert!(
            mem_estimate < 128_000_000,
            "Memory estimate {mem_estimate} for traffic sketch exceeds 128MB."
        );
    }
}