sui_futures/
future.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::{future::Future, time::Duration};
5
6use tokio::time::sleep;
7
8/// Wraps a future with slow/stuck detection using `tokio::select!`
9///
10/// This implementation races the future against a timer. If the timer expires first, the callback
11/// is executed (exactly once) but the future continues to run. This approach can detect stuck
12/// futures that never wake their waker.
13pub async fn with_slow_future_monitor<F, C>(
14    future: F,
15    threshold: Duration,
16    callback: C,
17) -> F::Output
18where
19    F: Future,
20    C: FnOnce(),
21{
22    // The select! macro needs to take a reference to the future, which requires it to be pinned
23    tokio::pin!(future);
24
25    tokio::select! {
26        result = &mut future => {
27            // Future completed before timeout
28            return result;
29        }
30        _ = sleep(threshold) => {
31            // Timeout elapsed - fire the warning
32            callback();
33        }
34    }
35
36    // If we get here, the timeout fired but the future is still running. Continue waiting for the
37    // future to complete
38    future.await
39}
40
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
        time::Duration,
    };

    use tokio::time::{sleep, timeout};

    use super::*;

    /// Cloneable handle to a shared atomic tally, used to count callback invocations.
    #[derive(Clone)]
    struct Counter(Arc<AtomicUsize>);

    impl Counter {
        /// Creates a counter whose tally starts at zero.
        fn new() -> Self {
            let zero = AtomicUsize::new(0);
            Self(Arc::new(zero))
        }

        /// Adds one to the tally.
        fn increment(&self) {
            let Self(tally) = self;
            tally.fetch_add(1, Ordering::Relaxed);
        }

        /// Reads the current tally.
        fn count(&self) -> usize {
            let Self(tally) = self;
            tally.load(Ordering::Relaxed)
        }
    }

    /// Shorthand for a millisecond `Duration`.
    fn millis(ms: u64) -> Duration {
        Duration::from_millis(ms)
    }

    /// Sleeps for `delay`, then resolves to `value`.
    async fn finish_after<T>(delay: Duration, value: T) -> T {
        sleep(delay).await;
        value
    }

    #[tokio::test]
    async fn slow_monitor_callback_called_once_when_threshold_exceeded() {
        let slow_hits = Counter::new();

        // The work (200ms) outlasts the threshold (100ms), so the callback must fire.
        let work = finish_after(millis(200), 42);
        let output = with_slow_future_monitor(work, millis(100), || slow_hits.increment()).await;

        assert_eq!(slow_hits.count(), 1);
        assert_eq!(output, 42);
    }

    #[tokio::test]
    async fn slow_monitor_callback_not_called_when_threshold_not_exceeded() {
        let slow_hits = Counter::new();

        // The work (50ms) finishes well inside the threshold (200ms), so no callback.
        let work = finish_after(millis(50), 42);
        let output = with_slow_future_monitor(work, millis(200), || slow_hits.increment()).await;

        assert_eq!(slow_hits.count(), 0);
        assert_eq!(output, 42);
    }

    #[tokio::test]
    async fn slow_monitor_error_propagation() {
        let slow_hits = Counter::new();

        // A slow future resolving to `Err` must surface that error unchanged.
        let work = finish_after(millis(150), Err("Something went wrong"));
        let outcome: Result<i32, &str> =
            with_slow_future_monitor(work, millis(100), || slow_hits.increment()).await;

        assert_eq!(outcome, Err("Something went wrong"));
        assert_eq!(slow_hits.count(), 1);
    }

    #[tokio::test]
    async fn slow_monitor_error_propagation_without_callback() {
        let slow_hits = Counter::new();

        // A fast future resolving to `Err` surfaces the error without firing the callback.
        let work = finish_after(millis(50), Err("Quick error"));
        let outcome: Result<i32, &str> =
            with_slow_future_monitor(work, millis(200), || slow_hits.increment()).await;

        assert_eq!(outcome, Err("Quick error"));
        assert_eq!(slow_hits.count(), 0);
    }

    #[tokio::test]
    async fn slow_monitor_stuck_future_detection() {
        use std::{
            future::Future,
            pin::Pin,
            task::{Context, Poll},
        };

        // Always reports `Pending` and never registers or wakes the waker.
        struct StuckFuture;

        impl Future for StuckFuture {
            type Output = ();

            fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
                Poll::Pending
            }
        }

        let slow_hits = Counter::new();

        // The monitor's own timer drives the wake-up, so the stuck future is still reported.
        let monitored =
            with_slow_future_monitor(StuckFuture, millis(200), || slow_hits.increment());

        // The monitored future can never finish; bound the wait so the test itself terminates.
        let bounded = timeout(Duration::from_secs(2), monitored).await;

        assert!(bounded.is_err());
        assert_eq!(slow_hits.count(), 1);
    }
}
173}