mysten_common/
backoff.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::{iter::Iterator, time::Duration};
5
6use rand::Rng as _;
7
8/// Creates a generator which yields an approximately exponential series of durations, as back-off delays.
9/// Jitters are added to each delay by default to prevent thundering herd on retries.
10///
11/// The API is inspired by tokio-retry::strategy::ExponentialBackoff for ease of use.
12/// But bugs in the original implementation have been fixed.
13///
14/// ```rust,no_run
15/// use std::time::Duration;
16/// use mysten_common::backoff::ExponentialBackoff;
17///
18/// // Basic example:
19/// let mut backoff = ExponentialBackoff::new(Duration::from_millis(100), Duration::from_secs(10));
20/// for (attempt, delay) in backoff.enumerate() {
21///     println!("Attempt {attempt}: Delay: {:?}", delay);
22/// }
23///
24/// // Specifying backoff factor and maximum jitter:
25/// let mut backoff = ExponentialBackoff::new(Duration::from_secs(5), Duration::from_secs(60))
26///     .factor(2.0)
27///     .max_jitter(Duration::from_secs(1));
28/// loop {
29///     // next() should always return a Some(Duration).
30///     let delay = backoff.next().unwrap();
31///     println!("Delay: {:?}", delay);
32/// }
33/// ```
34#[derive(Debug, Clone)]
35pub struct ExponentialBackoff {
36    next: Duration,
37    factor: f64,
38    max_delay: Duration,
39    max_jitter: Duration,
40}
41
42impl ExponentialBackoff {
43    /// Constructs a new exponential backoff generator, specifying the maximum delay.
44    pub fn new(initial_delay: Duration, max_delay: Duration) -> ExponentialBackoff {
45        ExponentialBackoff {
46            next: initial_delay,
47            factor: 1.2,
48            max_delay,
49            max_jitter: initial_delay,
50        }
51    }
52
53    /// Sets the approximate ratio of consecutive backoff delays, before jitters are applied.
54    /// Setting this to Duration::ZERO disables jittering.
55    ///
56    /// Default is 1.2.
57    pub fn factor(mut self, factor: f64) -> ExponentialBackoff {
58        self.factor = factor;
59        self
60    }
61
62    /// Sets the maximum jitter per delay.
63    ///
64    /// Default is the initial delay.
65    pub fn max_jitter(mut self, max_jitter: Duration) -> ExponentialBackoff {
66        self.max_jitter = max_jitter;
67        self
68    }
69}
70
71impl Iterator for ExponentialBackoff {
72    type Item = Duration;
73
74    /// Yields backoff delays. Never terminates.
75    fn next(&mut self) -> Option<Duration> {
76        let current = self.next;
77
78        let jitter = if self.max_jitter.is_zero() {
79            Duration::ZERO
80        } else {
81            Duration::from_secs_f64(
82                rand::thread_rng().gen_range(0.0..self.max_jitter.as_secs_f64()),
83            )
84        };
85        self.next = current
86            .mul_f64(self.factor)
87            .min(self.max_delay)
88            .saturating_add(jitter);
89
90        Some(current)
91    }
92}
93
#[test]
fn test_exponential_backoff_default() {
    let backoff = ExponentialBackoff::new(Duration::from_millis(50), Duration::from_secs(10));

    // Bounds for the first two delays under the default factor (1.2) and default max jitter
    // (the initial delay, 50ms).
    let bounds = [
        (Duration::from_millis(50), Duration::from_millis(100)),
        (Duration::from_millis(60), Duration::from_millis(170)),
    ];
    // Zip with the iterator itself, not `backoff.next()`: zipping with an `Option` yields at most
    // one pair and silently skips the remaining bounds.
    let mut checked = 0;
    for (&(lower, upper), delay) in bounds.iter().zip(backoff) {
        assert!(
            delay >= lower && delay <= upper,
            "delay {delay:?} not in [{lower:?}, {upper:?}]"
        );
        checked += 1;
    }
    assert_eq!(checked, bounds.len());
}
106
#[test]
fn test_exponential_backoff_base_100_factor_2_no_jitter() {
    let backoff = ExponentialBackoff::new(Duration::from_millis(100), Duration::from_secs(10))
        .factor(2.0)
        .max_jitter(Duration::ZERO);

    // Without jitter and well below the cap, each delay is exactly double the previous one.
    let expected = [100, 200, 400, 800].map(Duration::from_millis);
    let actual: Vec<Duration> = backoff.take(expected.len()).collect();
    assert_eq!(actual, expected);
}
118
#[test]
fn test_exponential_backoff_max_delay() {
    let cap = Duration::from_secs(1);
    let mut backoff = ExponentialBackoff::new(Duration::from_millis(200), cap)
        .factor(3.0)
        .max_jitter(Duration::ZERO);

    // 200ms, then 600ms; the next step (1800ms) exceeds the cap.
    let ramp: Vec<Duration> = backoff.by_ref().take(2).collect();
    assert_eq!(
        ramp,
        [Duration::from_millis(200), Duration::from_millis(600)]
    );

    // From then on, every delay is pinned to the cap.
    for delay in backoff.take(10) {
        assert_eq!(delay, cap);
    }
}