1use std::{iter::Iterator, time::Duration};
5
6use rand::Rng as _;
7
/// An endless iterator of retry delays that grow geometrically.
///
/// Each call to [`Iterator::next`] multiplies the previous delay by `factor`,
/// caps the result at `max_delay`, and then adds a random jitter strictly
/// below `max_jitter`. Because jitter is added after the cap, a yielded delay
/// can exceed `max_delay` by less than `max_jitter`. `next` always returns
/// `Some`.
///
/// Construct with [`ExponentialBackoff::new`] and adjust with the `base`,
/// `factor`, and `max_jitter` builder methods.
#[derive(Debug, Clone)]
pub struct ExponentialBackoff {
    /// Delay the next step is computed from: the base delay initially, then
    /// the most recently yielded delay (jitter included).
    current: Duration,
    /// Multiplier applied to `current` on every step.
    factor: f64,
    /// Cap applied to the scaled delay before jitter is added.
    max_delay: Duration,
    /// Exclusive upper bound on the random jitter added to each delay;
    /// `Duration::ZERO` disables jitter.
    max_jitter: Duration,
}
42
43impl ExponentialBackoff {
44 pub fn new(max_delay: Duration) -> ExponentialBackoff {
46 ExponentialBackoff {
47 current: Duration::from_millis(50),
48 factor: 1.5,
49 max_delay,
50 max_jitter: Duration::from_millis(50),
51 }
52 }
53
54 pub fn base(mut self, base: Duration) -> ExponentialBackoff {
58 self.current = base;
59 self
60 }
61
62 pub fn factor(mut self, factor: f64) -> ExponentialBackoff {
67 self.factor = factor;
68 self
69 }
70
71 pub fn max_jitter(mut self, max_jitter: Duration) -> ExponentialBackoff {
74 self.max_jitter = max_jitter;
75 self
76 }
77}
78
79impl Iterator for ExponentialBackoff {
80 type Item = Duration;
81
82 fn next(&mut self) -> Option<Duration> {
84 let jitter = if self.max_jitter.is_zero() {
85 Duration::ZERO
86 } else {
87 Duration::from_secs_f64(
88 rand::thread_rng().gen_range(0.0..self.max_jitter.as_secs_f64()),
89 )
90 };
91 let next = self
92 .current
93 .mul_f64(self.factor)
94 .min(self.max_delay)
95 .saturating_add(jitter);
96
97 self.current = next;
98
99 Some(next)
100 }
101}
102
#[test]
fn test_exponential_backoff_default() {
    let backoff = ExponentialBackoff::new(Duration::from_secs(10));

    // Inclusive bounds for the first two delays under the defaults
    // (base 50 ms, factor 1.5, jitter below 50 ms, jitter carried forward).
    let bounds = [
        (Duration::from_millis(75), Duration::from_millis(125)),
        (Duration::from_millis(110), Duration::from_millis(250)),
    ];

    // Zip with the iterator itself rather than `backoff.next()`: an `Option`
    // yields at most one item, which silently skipped every bound after the first.
    let mut checked = 0;
    for (&(lower, upper), delay) in bounds.iter().zip(backoff) {
        assert!(
            lower <= delay && delay <= upper,
            "delay {:?} outside [{:?}, {:?}]",
            delay,
            lower,
            upper
        );
        checked += 1;
    }
    assert_eq!(checked, bounds.len());
}
115
#[test]
fn test_exponential_backoff_base_100_factor_2_no_jitter() {
    // With jitter disabled the sequence is exact: 100 ms doubled each step.
    let backoff = ExponentialBackoff::new(Duration::from_secs(10))
        .base(Duration::from_millis(100))
        .factor(2.0)
        .max_jitter(Duration::ZERO);

    let expected: Vec<Duration> = [200, 400, 800]
        .iter()
        .map(|&ms| Duration::from_millis(ms))
        .collect();
    let actual: Vec<Duration> = backoff.take(expected.len()).collect();

    assert_eq!(actual, expected);
}
127
#[test]
fn test_exponential_backoff_max_delay() {
    let cap = Duration::from_secs(1);
    let mut backoff = ExponentialBackoff::new(cap)
        .base(Duration::from_millis(200))
        .factor(3.0)
        .max_jitter(Duration::ZERO);

    // 200 ms * 3 is still below the cap...
    assert_eq!(backoff.next(), Some(Duration::from_millis(600)));

    // ...but 600 ms * 3 is not, so every later delay is pinned to the cap.
    let tail: Vec<Duration> = backoff.take(10).collect();
    assert_eq!(tail, vec![cap; 10]);
}