1use std::{iter::Iterator, time::Duration};
5
6use rand::Rng as _;
7
/// Infinite iterator of retry delays that grow geometrically up to a cap,
/// optionally perturbed by random jitter.
///
/// Construct with `ExponentialBackoff::new` and tune with the `factor` and
/// `max_jitter` builder methods; every call to `next` yields the delay to wait
/// before the following attempt.
#[derive(Clone, Debug)]
pub struct ExponentialBackoff {
    /// Delay that the upcoming call to `next` will yield.
    next: Duration,
    /// Multiplier applied to the delay after each step.
    factor: f64,
    /// Cap applied to the scaled delay (before jitter is added).
    max_delay: Duration,
    /// Exclusive upper bound of the random jitter added per step; zero disables jitter.
    max_jitter: Duration,
}
41
42impl ExponentialBackoff {
43 pub fn new(initial_delay: Duration, max_delay: Duration) -> ExponentialBackoff {
45 ExponentialBackoff {
46 next: initial_delay,
47 factor: 1.2,
48 max_delay,
49 max_jitter: initial_delay,
50 }
51 }
52
53 pub fn factor(mut self, factor: f64) -> ExponentialBackoff {
58 self.factor = factor;
59 self
60 }
61
62 pub fn max_jitter(mut self, max_jitter: Duration) -> ExponentialBackoff {
66 self.max_jitter = max_jitter;
67 self
68 }
69}
70
71impl Iterator for ExponentialBackoff {
72 type Item = Duration;
73
74 fn next(&mut self) -> Option<Duration> {
76 let current = self.next;
77
78 let jitter = if self.max_jitter.is_zero() {
79 Duration::ZERO
80 } else {
81 Duration::from_secs_f64(
82 rand::thread_rng().gen_range(0.0..self.max_jitter.as_secs_f64()),
83 )
84 };
85 self.next = current
86 .mul_f64(self.factor)
87 .min(self.max_delay)
88 .saturating_add(jitter);
89
90 Some(current)
91 }
92}
93
#[test]
fn test_exponential_backoff_default() {
    let mut backoff = ExponentialBackoff::new(Duration::from_millis(50), Duration::from_secs(10));

    // With the defaults (factor 1.2, max_jitter equal to the 50ms initial delay)
    // the delays are randomized, so each one is checked against an inclusive window
    // rather than an exact value.
    let bounds = [
        (Duration::from_millis(50), Duration::from_millis(100)),
        (Duration::from_millis(60), Duration::from_millis(170)),
    ];
    for (step, &(lower, upper)) in bounds.iter().enumerate() {
        let delay = backoff.next().unwrap();
        assert!(
            (lower..=upper).contains(&delay),
            "step {}: delay {:?} outside [{:?}, {:?}]",
            step,
            delay,
            lower,
            upper
        );
    }
}
107
#[test]
fn test_exponential_backoff_base_100_factor_2_no_jitter() {
    // Jitter disabled: the schedule is deterministic, starting at 100ms and doubling.
    let mut backoff = ExponentialBackoff::new(Duration::from_millis(100), Duration::from_secs(10))
        .factor(2.0)
        .max_jitter(Duration::ZERO);

    for expected_ms in [100, 200, 400, 800] {
        assert_eq!(backoff.next(), Some(Duration::from_millis(expected_ms)));
    }
}
119
#[test]
fn test_exponential_backoff_max_delay() {
    // Factor 3 from 200ms gives 200ms, then 600ms. The next step would be 1.8s, so
    // the 1s cap takes over and every later delay stays pinned there (no jitter).
    let backoff = ExponentialBackoff::new(Duration::from_millis(200), Duration::from_secs(1))
        .factor(3.0)
        .max_jitter(Duration::ZERO);

    let mut expected = vec![Duration::from_millis(200), Duration::from_millis(600)];
    expected.resize(12, Duration::from_secs(1));

    let observed: Vec<Duration> = backoff.take(expected.len()).collect();
    assert_eq!(observed, expected);
}