sui_core/traffic_controller/
policies.rs

1// Copyright (c) 2021, Facebook, Inc. and its affiliates
2// Copyright (c) Mysten Labs, Inc.
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{collections::HashMap, net::IpAddr, sync::Arc};
6
7use count_min_sketch::CountMinSketch32;
8use mysten_metrics::spawn_monitored_task;
9use parking_lot::RwLock;
10use std::cmp::Reverse;
11use std::collections::{BinaryHeap, VecDeque};
12use std::fmt::Debug;
13use std::hash::Hash;
14use std::time::Duration;
15use std::time::{Instant, SystemTime};
16use sui_types::traffic_control::{FreqThresholdConfig, PolicyConfig, PolicyType, Weight};
17use tracing::{info, trace};
18
/// How many highest-rate entries `TrafficSketch` keeps for each client type
/// (direct and proxied); used for metrics and logging.
const HIGHEST_RATES_CAPACITY: usize = 20;
20
/// How the IP address being counted relates to the request.
#[derive(Debug, PartialEq, Eq, Hash)]
enum ClientType {
    /// The IP of the connection that delivered the request (which may itself be a
    /// fullnode relaying traffic).
    Direct,
    /// The IP of a client whose request was relayed through a fullnode.
    ThroughFullnode,
}
27
28#[derive(Hash, Eq, PartialEq, Debug)]
29struct SketchKey {
30    salt: u64,
31    ip_addr: IpAddr,
32    client_type: ClientType,
33}
34
/// Bounded heaps of the highest `(rate, ip)` pairs seen, kept separately for direct
/// and proxied clients.
///
/// Entries are wrapped in [`Reverse`], so each heap's top (`peek`) is the *smallest*
/// tracked rate: the entry to evict when a higher rate arrives and the heap is full.
struct HighestRates {
    /// Rates keyed by direct (connection) IP.
    direct: BinaryHeap<Reverse<(u64, IpAddr)>>,
    /// Rates keyed by proxied (through-fullnode) client IP.
    proxied: BinaryHeap<Reverse<(u64, IpAddr)>>,
    /// Maximum number of entries kept in each heap.
    capacity: usize,
}
40
41pub struct TrafficSketch {
42    /// Circular buffer Count Min Sketches representing a sliding window
43    /// of traffic data. Note that the 32 in CountMinSketch32 represents
44    /// the number of bits used to represent the count in the sketch. Since
45    /// we only count on a sketch for a window of `update_interval`, we only
46    /// need enough precision to represent the max expected unique IP addresses
47    /// we may see in that window. For a 10 second period, we might conservatively
48    /// expect 100,000, which can be represented in 17 bits, but not 16. We can
49    /// potentially lower the memory consumption by using CountMinSketch16, which
50    /// will reliably support up to ~65,000 unique IP addresses in the window.
51    sketches: VecDeque<CountMinSketch32<SketchKey>>,
52    window_size: Duration,
53    update_interval: Duration,
54    last_reset_time: Instant,
55    current_sketch_index: usize,
56    /// Used for metrics collection and logging purposes,
57    /// as CountMinSketch does not provide this directly.
58    /// Note that this is an imperfect metric, since we preserve
59    /// the highest N rates (by unique IP) that we have seen,
60    /// but update rates (down or up) as they change so that
61    /// the metric is not monotonic and reflects recent traffic.
62    /// However, this should only lead to inaccuracy edge cases
63    /// with very low traffic.
64    highest_rates: HighestRates,
65}
66
67impl TrafficSketch {
68    pub fn new(
69        window_size: Duration,
70        update_interval: Duration,
71        sketch_capacity: usize,
72        sketch_probability: f64,
73        sketch_tolerance: f64,
74        highest_rates_capacity: usize,
75    ) -> Self {
76        // intentionally round down via integer division. We can't have a partial sketch
77        let num_sketches = window_size.as_secs() / update_interval.as_secs();
78        let new_window_size = Duration::from_secs(num_sketches * update_interval.as_secs());
79        if new_window_size != window_size {
80            info!(
81                "Rounding traffic sketch window size down to {} seconds to make it an integer multiple of update interval {} seconds.",
82                new_window_size.as_secs(),
83                update_interval.as_secs(),
84            );
85        }
86        let window_size = new_window_size;
87
88        assert!(
89            window_size < Duration::from_secs(600),
90            "window_size too large. Max 600 seconds"
91        );
92        assert!(
93            update_interval < window_size,
94            "Update interval may not be larger than window size"
95        );
96        assert!(
97            update_interval >= Duration::from_secs(1),
98            "Update interval too short, must be at least 1 second"
99        );
100        assert!(num_sketches <= 10, "Given parameters require too many sketches to be stored. Reduce window size or increase update interval.");
101        let mem_estimate = (num_sketches as usize)
102            * CountMinSketch32::<IpAddr>::estimate_memory(
103                sketch_capacity,
104                sketch_probability,
105                sketch_tolerance,
106            )
107            .expect("Failed to estimate memory for CountMinSketch32");
108        assert!(mem_estimate < 128_000_000, "Memory estimate for traffic sketch exceeds 128MB. Reduce window size or increase update interval.");
109
110        let mut sketches = VecDeque::with_capacity(num_sketches as usize);
111        for _ in 0..num_sketches {
112            sketches.push_back(
113                CountMinSketch32::<SketchKey>::new(
114                    sketch_capacity,
115                    sketch_probability,
116                    sketch_tolerance,
117                )
118                .expect("Failed to create CountMinSketch32"),
119            );
120        }
121        Self {
122            sketches,
123            window_size,
124            update_interval,
125            last_reset_time: Instant::now(),
126            current_sketch_index: 0,
127            highest_rates: HighestRates {
128                direct: BinaryHeap::with_capacity(highest_rates_capacity),
129                proxied: BinaryHeap::with_capacity(highest_rates_capacity),
130                capacity: highest_rates_capacity,
131            },
132        }
133    }
134
135    fn increment_count(&mut self, key: &SketchKey) {
136        // reset all expired intervals
137        let current_time = Instant::now();
138        let mut elapsed = current_time.duration_since(self.last_reset_time);
139        while elapsed >= self.update_interval {
140            self.rotate_window();
141            elapsed -= self.update_interval;
142        }
143        // Increment in the current active sketch
144        self.sketches[self.current_sketch_index].increment(key);
145    }
146
147    fn get_request_rate(&mut self, key: &SketchKey) -> f64 {
148        let count: u32 = self
149            .sketches
150            .iter()
151            .map(|sketch| sketch.estimate(key))
152            .sum();
153        let rate = count as f64 / self.window_size.as_secs() as f64;
154        self.update_highest_rates(key, rate);
155        rate
156    }
157
158    fn update_highest_rates(&mut self, key: &SketchKey, rate: f64) {
159        match key.client_type {
160            ClientType::Direct => {
161                Self::update_highest_rate(
162                    &mut self.highest_rates.direct,
163                    key.ip_addr,
164                    rate,
165                    self.highest_rates.capacity,
166                );
167            }
168            ClientType::ThroughFullnode => {
169                Self::update_highest_rate(
170                    &mut self.highest_rates.proxied,
171                    key.ip_addr,
172                    rate,
173                    self.highest_rates.capacity,
174                );
175            }
176        }
177    }
178
179    fn update_highest_rate(
180        rate_heap: &mut BinaryHeap<Reverse<(u64, IpAddr)>>,
181        ip_addr: IpAddr,
182        rate: f64,
183        capacity: usize,
184    ) {
185        // Remove previous instance of this IPAddr so that we
186        // can update with new rate
187        rate_heap.retain(|&Reverse((_, key))| key != ip_addr);
188
189        let rate = rate as u64;
190        if rate_heap.len() < capacity {
191            rate_heap.push(Reverse((rate, ip_addr)));
192        } else if let Some(&Reverse((smallest_score, _))) = rate_heap.peek() {
193            if rate > smallest_score {
194                rate_heap.pop();
195                rate_heap.push(Reverse((rate, ip_addr)));
196            }
197        }
198    }
199
200    pub fn highest_direct_rate(&self) -> Option<(u64, IpAddr)> {
201        self.highest_rates
202            .direct
203            .iter()
204            .map(|Reverse(v)| v)
205            .max_by(|a, b| a.0.partial_cmp(&b.0).expect("Failed to compare rates"))
206            .copied()
207    }
208
209    pub fn highest_proxied_rate(&self) -> Option<(u64, IpAddr)> {
210        self.highest_rates
211            .proxied
212            .iter()
213            .map(|Reverse(v)| v)
214            .max_by(|a, b| a.0.partial_cmp(&b.0).expect("Failed to compare rates"))
215            .copied()
216    }
217
218    fn rotate_window(&mut self) {
219        self.current_sketch_index = (self.current_sketch_index + 1) % self.sketches.len();
220        self.sketches[self.current_sketch_index].clear();
221        self.last_reset_time = Instant::now();
222    }
223}
224
225#[derive(Clone, Debug)]
226pub struct TrafficTally {
227    pub direct: Option<IpAddr>,
228    pub through_fullnode: Option<IpAddr>,
229    pub error_info: Option<(Weight, String)>,
230    pub spam_weight: Weight,
231    pub timestamp: SystemTime,
232}
233
234impl TrafficTally {
235    pub fn new(
236        direct: Option<IpAddr>,
237        through_fullnode: Option<IpAddr>,
238        error_info: Option<(Weight, String)>,
239        spam_weight: Weight,
240    ) -> Self {
241        Self {
242            direct,
243            through_fullnode,
244            error_info,
245            spam_weight,
246            timestamp: SystemTime::now(),
247        }
248    }
249}
250
/// A policy's verdict on one tally: which addresses, if any, should be blocked.
/// The default blocks nothing.
#[derive(Debug, Default, Clone)]
pub struct PolicyResponse {
    /// Direct (connection) IP to block, if any.
    pub block_client: Option<IpAddr>,
    /// Proxied (through-fullnode) client IP to block, if any.
    pub block_proxied_client: Option<IpAddr>,
}
256
257pub trait Policy {
258    // returns, e.g. (true, false) if connection_ip should be added to blocklist
259    // and proxy_ip should not
260    fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse;
261    fn policy_config(&self) -> &PolicyConfig;
262}
263
264// Nonserializable representation, also note that inner types are
265// not object safe, so we can't use a trait object instead
266pub enum TrafficControlPolicy {
267    FreqThreshold(FreqThresholdPolicy),
268    NoOp(NoOpPolicy),
269    // Test policies below this point
270    TestNConnIP(TestNConnIPPolicy),
271    TestPanicOnInvocation(TestPanicOnInvocationPolicy),
272}
273
274impl Policy for TrafficControlPolicy {
275    fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse {
276        match self {
277            TrafficControlPolicy::NoOp(policy) => policy.handle_tally(tally),
278            TrafficControlPolicy::FreqThreshold(policy) => policy.handle_tally(tally),
279            TrafficControlPolicy::TestNConnIP(policy) => policy.handle_tally(tally),
280            TrafficControlPolicy::TestPanicOnInvocation(policy) => policy.handle_tally(tally),
281        }
282    }
283
284    fn policy_config(&self) -> &PolicyConfig {
285        match self {
286            TrafficControlPolicy::NoOp(policy) => policy.policy_config(),
287            TrafficControlPolicy::FreqThreshold(policy) => policy.policy_config(),
288            TrafficControlPolicy::TestNConnIP(policy) => policy.policy_config(),
289            TrafficControlPolicy::TestPanicOnInvocation(policy) => policy.policy_config(),
290        }
291    }
292}
293
294impl TrafficControlPolicy {
295    pub async fn from_spam_config(policy_config: PolicyConfig) -> Self {
296        Self::from_config(policy_config.clone().spam_policy_type, policy_config).await
297    }
298    pub async fn from_error_config(policy_config: PolicyConfig) -> Self {
299        Self::from_config(policy_config.clone().error_policy_type, policy_config).await
300    }
301    pub async fn from_config(policy_type: PolicyType, policy_config: PolicyConfig) -> Self {
302        match policy_type {
303            PolicyType::NoOp => Self::NoOp(NoOpPolicy::new(policy_config)),
304            PolicyType::FreqThreshold(freq_threshold_config) => Self::FreqThreshold(
305                FreqThresholdPolicy::new(policy_config, freq_threshold_config),
306            ),
307            PolicyType::TestNConnIP(n) => {
308                Self::TestNConnIP(TestNConnIPPolicy::new(policy_config, n).await)
309            }
310            PolicyType::TestPanicOnInvocation => {
311                Self::TestPanicOnInvocation(TestPanicOnInvocationPolicy::new(policy_config))
312            }
313        }
314    }
315}
316
317////////////// *** Policy definitions *** //////////////
318
319pub struct FreqThresholdPolicy {
320    pub config: PolicyConfig,
321    pub client_threshold: u64,
322    pub proxied_client_threshold: u64,
323    sketch: TrafficSketch,
324    /// Unique salt to be added to all keys in the sketch. This
325    /// ensures that false positives are not correlated across
326    /// all nodes at the same time. For Sui validators for example,
327    /// this means that false positives should not prevent the network
328    /// from achieving quorum.
329    salt: u64,
330}
331
332impl FreqThresholdPolicy {
333    pub fn new(
334        config: PolicyConfig,
335        FreqThresholdConfig {
336            client_threshold,
337            proxied_client_threshold,
338            window_size_secs,
339            update_interval_secs,
340            sketch_capacity,
341            sketch_probability,
342            sketch_tolerance,
343        }: FreqThresholdConfig,
344    ) -> Self {
345        let sketch = TrafficSketch::new(
346            Duration::from_secs(window_size_secs),
347            Duration::from_secs(update_interval_secs),
348            sketch_capacity,
349            sketch_probability,
350            sketch_tolerance,
351            HIGHEST_RATES_CAPACITY,
352        );
353        Self {
354            config,
355            sketch,
356            client_threshold,
357            proxied_client_threshold,
358            salt: rand::random(),
359        }
360    }
361
362    pub fn highest_direct_rate(&self) -> Option<(u64, IpAddr)> {
363        self.sketch.highest_direct_rate()
364    }
365
366    pub fn highest_proxied_rate(&self) -> Option<(u64, IpAddr)> {
367        self.sketch.highest_proxied_rate()
368    }
369
370    pub fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse {
371        let block_client = if let Some(source) = tally.direct {
372            let key = SketchKey {
373                salt: self.salt,
374                ip_addr: source,
375                client_type: ClientType::Direct,
376            };
377            self.sketch.increment_count(&key);
378            let req_rate = self.sketch.get_request_rate(&key);
379            trace!(
380                "FreqThresholdPolicy handling tally -- req_rate: {:?}, client_threshold: {:?}, client: {:?}",
381                req_rate,
382                self.client_threshold,
383                source,
384            );
385            if req_rate >= self.client_threshold as f64 {
386                Some(source)
387            } else {
388                None
389            }
390        } else {
391            None
392        };
393        let block_proxied_client = if let Some(source) = tally.through_fullnode {
394            let key = SketchKey {
395                salt: self.salt,
396                ip_addr: source,
397                client_type: ClientType::ThroughFullnode,
398            };
399            self.sketch.increment_count(&key);
400            if self.sketch.get_request_rate(&key) >= self.proxied_client_threshold as f64 {
401                Some(source)
402            } else {
403                None
404            }
405        } else {
406            None
407        };
408        PolicyResponse {
409            block_client,
410            block_proxied_client,
411        }
412    }
413
414    fn policy_config(&self) -> &PolicyConfig {
415        &self.config
416    }
417}
418
////////////// *** No-op and test-only policies below this point *** //////////////
420
421#[derive(Clone)]
422pub struct NoOpPolicy {
423    config: PolicyConfig,
424}
425
426impl NoOpPolicy {
427    pub fn new(config: PolicyConfig) -> Self {
428        Self { config }
429    }
430
431    fn handle_tally(&mut self, _tally: TrafficTally) -> PolicyResponse {
432        PolicyResponse::default()
433    }
434
435    fn policy_config(&self) -> &PolicyConfig {
436        &self.config
437    }
438}
439
440#[derive(Clone)]
441pub struct TestNConnIPPolicy {
442    pub threshold: u64,
443    pub config: PolicyConfig,
444    frequencies: Arc<RwLock<HashMap<IpAddr, u64>>>,
445}
446
447impl TestNConnIPPolicy {
448    pub async fn new(config: PolicyConfig, threshold: u64) -> Self {
449        let frequencies = Arc::new(RwLock::new(HashMap::new()));
450        let frequencies_clone = frequencies.clone();
451        spawn_monitored_task!(run_clear_frequencies(
452            frequencies_clone,
453            config.connection_blocklist_ttl_sec * 2,
454        ));
455        Self {
456            config,
457            frequencies,
458            threshold,
459        }
460    }
461
462    fn handle_tally(&mut self, tally: TrafficTally) -> PolicyResponse {
463        let client = if let Some(client) = tally.direct {
464            client
465        } else {
466            return PolicyResponse::default();
467        };
468
469        // increment the count for the IP
470        let mut frequencies = self.frequencies.write();
471        let count = frequencies.entry(client).or_insert(0);
472        *count += 1;
473        PolicyResponse {
474            block_client: if *count >= self.threshold {
475                Some(client)
476            } else {
477                None
478            },
479            block_proxied_client: None,
480        }
481    }
482
483    fn policy_config(&self) -> &PolicyConfig {
484        &self.config
485    }
486}
487
488async fn run_clear_frequencies(frequencies: Arc<RwLock<HashMap<IpAddr, u64>>>, window_secs: u64) {
489    loop {
490        tokio::time::sleep(tokio::time::Duration::from_secs(window_secs)).await;
491        frequencies.write().clear();
492    }
493}
494
495#[derive(Clone)]
496pub struct TestPanicOnInvocationPolicy {
497    config: PolicyConfig,
498}
499
500impl TestPanicOnInvocationPolicy {
501    pub fn new(config: PolicyConfig) -> Self {
502        Self { config }
503    }
504
505    fn handle_tally(&mut self, _: TrafficTally) -> PolicyResponse {
506        panic!("Tally for this policy should never be invoked")
507    }
508
509    fn policy_config(&self) -> &PolicyConfig {
510        &self.config
511    }
512}
513
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};
    use sui_macros::sim_test;
    use sui_types::traffic_control::{
        DEFAULT_SKETCH_CAPACITY, DEFAULT_SKETCH_PROBABILITY, DEFAULT_SKETCH_TOLERANCE,
    };

    /// Builds a unit-weight tally for a request that arrived on connection `direct`
    /// on behalf of the proxied client `proxied`.
    fn tally_via(direct: [u8; 4], proxied: [u8; 4]) -> TrafficTally {
        TrafficTally {
            direct: Some(IpAddr::V4(Ipv4Addr::from(direct))),
            through_fullnode: Some(IpAddr::V4(Ipv4Addr::from(proxied))),
            error_info: None,
            spam_weight: Weight::one(),
            timestamp: SystemTime::now(),
        }
    }

    /// Returns the current highest `(rate, ip)` for direct and proxied traffic,
    /// panicking if either has not been populated yet.
    fn highest_rates(policy: &FreqThresholdPolicy) -> ((u64, IpAddr), (u64, IpAddr)) {
        (
            policy.highest_direct_rate().unwrap(),
            policy.highest_proxied_rate().unwrap(),
        )
    }

    #[sim_test]
    async fn test_freq_threshold_policy() {
        // Block proxied clients averaging at least 2 requests per second and direct
        // connections averaging at least 5 requests per second. Both rates are measured
        // over a 5 second window that advances every second.
        let mut policy = FreqThresholdPolicy::new(
            PolicyConfig::default(),
            FreqThresholdConfig {
                client_threshold: 5,
                proxied_client_threshold: 2,
                window_size_secs: 5,
                update_interval_secs: 1,
                ..Default::default()
            },
        );
        // Alice, Bob and Charlie connect from different IPs through the same fullnode.
        // They therefore share a connection (direct) IP on the validator but have
        // different proxied IPs.
        let fullnode = [8, 7, 6, 5];
        let alice = tally_via(fullnode, [1, 2, 3, 4]);
        let bob = tally_via(fullnode, [4, 3, 2, 1]);
        let charlie = tally_via(fullnode, [5, 6, 7, 8]);

        // Alice's first two requests are far below both thresholds.
        for _ in 0..2 {
            let response = policy.handle_tally(alice.clone());
            assert_eq!(response.block_proxied_client, None);
            assert_eq!(response.block_client, None);
        }
        let ((direct_rate, direct_ip), (proxied_rate, proxied_ip)) = highest_rates(&policy);
        assert_eq!(direct_ip, alice.direct.unwrap());
        assert!(direct_rate < 1);
        assert_eq!(proxied_ip, alice.through_fullnode.unwrap());
        assert!(proxied_rate < 1);

        // Meanwhile Bob bursts 10 requests at once. Only the 10th brings his proxied
        // rate (10 requests / 5s = 2/s) up to the threshold.
        for _ in 0..9 {
            let response = policy.handle_tally(bob.clone());
            assert_eq!(response.block_client, None);
            assert_eq!(response.block_proxied_client, None);
        }
        let response = policy.handle_tally(bob.clone());
        assert_eq!(response.block_client, None);
        assert_eq!(response.block_proxied_client, bob.through_fullnode);

        // Bob now holds the highest rates.
        let ((direct_rate, direct_ip), (proxied_rate, proxied_ip)) = highest_rates(&policy);
        assert_eq!(direct_ip, bob.direct.unwrap());
        assert_eq!(direct_rate, 2);
        assert_eq!(proxied_ip, bob.through_fullnode.unwrap());
        assert_eq!(proxied_rate, 2);

        // Two seconds later Alice sends two more requests. Averaged over the 5 second
        // window she is still below the threshold, so she is not blocked.
        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
        for _ in 0..2 {
            let response = policy.handle_tally(alice.clone());
            assert_eq!(response.block_client, None);
            assert_eq!(response.block_proxied_client, None);
        }
        // Bob's burst is still inside the sliding window, so his further requests keep
        // him blocked.
        let _ = policy.handle_tally(bob.clone());
        let response = policy.handle_tally(bob.clone());
        assert_eq!(response.block_client, None);
        assert_eq!(response.block_proxied_client, bob.through_fullnode);

        let ((direct_rate, direct_ip), (proxied_rate, proxied_ip)) = highest_rates(&policy);
        // The shared fullnode IP's direct rate keeps climbing with every client's traffic.
        assert_eq!(direct_ip, alice.direct.unwrap());
        assert_eq!(direct_rate, 3);
        // Bob still has the highest proxied rate.
        assert_eq!(proxied_ip, bob.through_fullnode.unwrap());
        assert_eq!(proxied_rate, 2);

        // One second later, five more requests bring Alice close to, but still below,
        // her threshold.
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        for i in 0..5 {
            let response = policy.handle_tally(alice.clone());
            assert_eq!(response.block_client, None, "Blocked at i = {}", i);
            assert_eq!(response.block_proxied_client, None);
        }

        // After another second, her next request reaches the proxied threshold.
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        let response = policy.handle_tally(alice.clone());
        assert_eq!(response.block_client, None);
        assert_eq!(response.block_proxied_client, alice.through_fullnode);

        let ((direct_rate, direct_ip), (proxied_rate, proxied_ip)) = highest_rates(&policy);
        assert_eq!(direct_ip, alice.direct.unwrap());
        assert_eq!(direct_rate, 4);
        // Alice's truncated proxied rate has caught up with Bob's; Bob is still reported.
        assert_eq!(proxied_ip, bob.through_fullnode.unwrap());
        assert_eq!(proxied_rate, 2);

        // Charlie's requests through the same fullnode push the shared connection IP up
        // to the direct threshold on his third request.
        for i in 0..2 {
            let response = policy.handle_tally(charlie.clone());
            assert_eq!(response.block_client, None, "Blocked at i = {}", i);
            assert_eq!(response.block_proxied_client, None);
        }
        let response = policy.handle_tally(charlie.clone());
        assert_eq!(response.block_proxied_client, None);
        assert_eq!(response.block_client, charlie.direct);

        // One second later the bursty first second has finally rotated out of the
        // sliding window, so the connection IP is no longer blocked.
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        for i in 0..3 {
            let response = policy.handle_tally(charlie.clone());
            assert_eq!(response.block_client, None, "Blocked at i = {}", i);
            assert_eq!(response.block_proxied_client, None);
        }

        // After a 5 second pause every earlier interval has expired. Fresh requests from
        // Alice and Bob report their new, low rates. Charlie's last-reported rate is kept
        // because he has not sent anything since.
        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
        let _ = policy.handle_tally(alice.clone());
        let _ = policy.handle_tally(bob.clone());
        let ((direct_rate, direct_ip), (proxied_rate, proxied_ip)) = highest_rates(&policy);
        assert_eq!(direct_ip, alice.direct.unwrap());
        assert_eq!(direct_rate, 0);
        assert_eq!(proxied_ip, charlie.through_fullnode.unwrap());
        assert_eq!(proxied_rate, 1);
    }

    #[sim_test]
    async fn test_traffic_sketch_mem_estimate() {
        // Rough memory estimate for the traffic sketch with the default parameters
        // below (about 113 MB), which must stay under the 128MB limit.
        let window_size = Duration::from_secs(30);
        let update_interval = Duration::from_secs(5);
        let num_sketches = (window_size.as_secs() / update_interval.as_secs()) as usize;
        let per_sketch = CountMinSketch32::<IpAddr>::estimate_memory(
            DEFAULT_SKETCH_CAPACITY,
            DEFAULT_SKETCH_PROBABILITY,
            DEFAULT_SKETCH_TOLERANCE,
        )
        .unwrap();
        let mem_estimate = per_sketch * num_sketches;
        assert!(
            mem_estimate < 128_000_000,
            "Memory estimate {mem_estimate} for traffic sketch exceeds 128MB."
        );
    }
}