sui_indexer_alt_framework/
task.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::{future::Future, panic, pin::pin, time::Duration};
5
6use futures::stream::{Stream, StreamExt};
7use tokio::{task::JoinSet, time::sleep};
8
9/// Extension trait introducing `try_for_each_spawned` to all streams.
10pub trait TrySpawnStreamExt: Stream {
11    /// Attempts to run this stream to completion, executing the provided asynchronous closure on
12    /// each element from the stream as elements become available.
13    ///
14    /// This is similar to [StreamExt::for_each_concurrent], but it may take advantage of any
15    /// parallelism available in the underlying runtime, because each unit of work is spawned as
16    /// its own tokio task.
17    ///
18    /// The first argument is an optional limit on the number of tasks to spawn concurrently.
19    /// Values of `0` and `None` are interpreted as no limit, and any other value will result in no
20    /// more than that many tasks being spawned at one time.
21    ///
22    /// ## Safety
23    ///
24    /// This function will panic if any of its futures panics, will return early with success if
25    /// the runtime it is running on is cancelled, and will return early with an error propagated
26    /// from any worker that produces an error.
27    fn try_for_each_spawned<Fut, F, E>(
28        self,
29        limit: impl Into<Option<usize>>,
30        f: F,
31    ) -> impl Future<Output = Result<(), E>>
32    where
33        Fut: Future<Output = Result<(), E>> + Send + 'static,
34        F: FnMut(Self::Item) -> Fut,
35        E: Send + 'static;
36}
37
38impl<S: Stream + Sized + 'static> TrySpawnStreamExt for S {
39    async fn try_for_each_spawned<Fut, F, E>(
40        self,
41        limit: impl Into<Option<usize>>,
42        mut f: F,
43    ) -> Result<(), E>
44    where
45        Fut: Future<Output = Result<(), E>> + Send + 'static,
46        F: FnMut(Self::Item) -> Fut,
47        E: Send + 'static,
48    {
49        // Maximum number of tasks to spawn concurrently.
50        let limit = match limit.into() {
51            Some(0) | None => usize::MAX,
52            Some(n) => n,
53        };
54
55        // Number of permits to spawn tasks left.
56        let mut permits = limit;
57        // Handles for already spawned tasks.
58        let mut join_set = JoinSet::new();
59        // Whether the worker pool has stopped accepting new items and is draining.
60        let mut draining = false;
61        // Error that occurred in one of the workers, to be propagated to the called on exit.
62        let mut error = None;
63
64        let mut self_ = pin!(self);
65
66        loop {
67            tokio::select! {
68                next = self_.next(), if !draining && permits > 0 => {
69                    if let Some(item) = next {
70                        permits -= 1;
71                        join_set.spawn(f(item));
72                    } else {
73                        // If the stream is empty, signal that the worker pool is going to
74                        // start draining now, so that once we get all our permits back, we
75                        // know we can wind down the pool.
76                        draining = true;
77                    }
78                }
79
80                Some(res) = join_set.join_next() => {
81                    match res {
82                        Ok(Err(e)) if error.is_none() => {
83                            error = Some(e);
84                            permits += 1;
85                            draining = true;
86                        }
87
88                        Ok(_) => permits += 1,
89
90                        // Worker panicked, propagate the panic.
91                        Err(e) if e.is_panic() => {
92                            panic::resume_unwind(e.into_panic())
93                        }
94
95                        // Worker was cancelled -- this can only happen if its join handle was
96                        // cancelled (not possible because that was created in this function),
97                        // or the runtime it was running in was wound down, in which case,
98                        // prepare the worker pool to drain.
99                        Err(e) => {
100                            assert!(e.is_cancelled());
101                            permits += 1;
102                            draining = true;
103                        }
104                    }
105                }
106
107                else => {
108                    // Not accepting any more items from the stream, and all our workers are
109                    // idle, so we stop.
110                    if permits == limit && draining {
111                        break;
112                    }
113                }
114            }
115        }
116
117        if let Some(e) = error { Err(e) } else { Ok(()) }
118    }
119}
120
121/// Wraps a future with slow/stuck detection using `tokio::select!`
122///
123/// This implementation races the future against a timer. If the timer expires first, the callback
124/// is executed (exactly once) but the future continues to run. This approach can detect stuck
125/// futures that never wake their waker.
126pub async fn with_slow_future_monitor<F, C>(
127    future: F,
128    threshold: Duration,
129    callback: C,
130) -> F::Output
131where
132    F: Future,
133    C: FnOnce(),
134{
135    // The select! macro needs to take a reference to the future, which requires it to be pinned
136    tokio::pin!(future);
137
138    tokio::select! {
139        result = &mut future => {
140            // Future completed before timeout
141            return result;
142        }
143        _ = sleep(threshold) => {
144            // Timeout elapsed - fire the warning
145            callback();
146        }
147    }
148
149    // If we get here, the timeout fired but the future is still running. Continue waiting for the
150    // future to complete
151    future.await
152}
153
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            Arc, Mutex,
            atomic::{AtomicUsize, Ordering},
        },
        time::Duration,
    };

    use futures::stream;
    use tokio::time::timeout;

    use super::*;

    /// Cheaply cloneable counter used to observe how often a callback was invoked.
    #[derive(Clone)]
    struct Counter(Arc<AtomicUsize>);

    impl Counter {
        fn new() -> Self {
            Self(Arc::new(AtomicUsize::new(0)))
        }

        fn increment(&self) {
            self.0.fetch_add(1, Ordering::Relaxed);
        }

        fn count(&self) -> usize {
            self.0.load(Ordering::Relaxed)
        }
    }

    /// Adds up `0..100` with one spawned task per element under the given concurrency `limit`,
    /// asserting that the run succeeded, and returns the accumulated total.
    async fn sum_spawned(limit: impl Into<Option<usize>>) -> usize {
        let total = Arc::new(AtomicUsize::new(0));
        let outcome = stream::iter(0..100)
            .try_for_each_spawned(limit, |i| {
                let total = total.clone();
                async move {
                    total.fetch_add(i, Ordering::Relaxed);
                    Ok::<(), ()>(())
                }
            })
            .await;

        assert!(outcome.is_ok());

        Arc::try_unwrap(total).unwrap().into_inner()
    }

    #[tokio::test]
    async fn for_each_explicit_sequential_iteration() {
        let seen = Arc::new(Mutex::new(vec![]));
        let outcome = stream::iter(0..20)
            .try_for_each_spawned(1, |i| {
                let seen = seen.clone();
                async move {
                    // Earlier items sleep longer, so the order is only preserved if items are
                    // processed one at a time.
                    tokio::time::sleep(Duration::from_millis(20 - i)).await;
                    seen.lock().unwrap().push(i);
                    Ok::<(), ()>(())
                }
            })
            .await;

        assert!(outcome.is_ok());

        let seen = Arc::try_unwrap(seen).unwrap().into_inner().unwrap();
        let expect: Vec<_> = (0..20).collect();
        assert_eq!(expect, seen);
    }

    #[tokio::test]
    async fn for_each_concurrent_iteration() {
        let expect = 99 * 100 / 2;
        assert_eq!(expect, sum_spawned(16).await);
    }

    #[tokio::test]
    async fn for_each_implicit_unlimited_iteration() {
        let expect = 99 * 100 / 2;
        assert_eq!(expect, sum_spawned(None).await);
    }

    #[tokio::test]
    async fn for_each_explicit_unlimited_iteration() {
        let expect = 99 * 100 / 2;
        assert_eq!(expect, sum_spawned(0).await);
    }

    #[tokio::test]
    async fn for_each_max_concurrency() {
        /// Tracks the number of jobs currently running and the highest value observed.
        #[derive(Default, Debug)]
        struct Jobs {
            max: AtomicUsize,
            curr: AtomicUsize,
        }

        let jobs = Arc::new(Jobs::default());

        let outcome = stream::iter(0..32)
            .try_for_each_spawned(4, |_| {
                let jobs = jobs.clone();
                async move {
                    jobs.curr.fetch_add(1, Ordering::Relaxed);
                    tokio::time::sleep(Duration::from_millis(100)).await;
                    // `fetch_sub` yields the count from just before this job left, i.e. the
                    // number of jobs that were running alongside (and including) it.
                    let running = jobs.curr.fetch_sub(1, Ordering::Relaxed);
                    jobs.max.fetch_max(running, Ordering::Relaxed);
                    Ok::<(), ()>(())
                }
            })
            .await;

        assert!(outcome.is_ok());

        let Jobs { max, curr } = Arc::try_unwrap(jobs).unwrap();
        assert_eq!(curr.into_inner(), 0);
        assert!(max.into_inner() <= 4);
    }

    #[tokio::test]
    async fn for_each_error_propagation() {
        let seen = Arc::new(Mutex::new(vec![]));
        let outcome = stream::iter(0..100)
            .try_for_each_spawned(None, |i| {
                let seen = seen.clone();
                async move {
                    if i >= 42 {
                        return Err(());
                    }

                    seen.lock().unwrap().push(i);
                    Ok(())
                }
            })
            .await;

        assert!(outcome.is_err());

        let seen = Arc::try_unwrap(seen).unwrap().into_inner().unwrap();
        let expect: Vec<_> = (0..42).collect();
        assert_eq!(expect, seen);
    }

    #[tokio::test]
    #[should_panic]
    async fn for_each_panic_propagation() {
        let _ = stream::iter(0..100)
            .try_for_each_spawned(None, |i| async move {
                assert!(i < 42);
                Ok::<(), ()>(())
            })
            .await;
    }

    #[tokio::test]
    async fn slow_monitor_callback_called_once_when_threshold_exceeded() {
        let calls = Counter::new();

        let output = with_slow_future_monitor(
            async {
                sleep(Duration::from_millis(200)).await;
                42 // Returned so that completion can be verified
            },
            Duration::from_millis(100),
            || calls.increment(),
        )
        .await;

        assert_eq!(calls.count(), 1);
        assert_eq!(output, 42);
    }

    #[tokio::test]
    async fn slow_monitor_callback_not_called_when_threshold_not_exceeded() {
        let calls = Counter::new();

        let output = with_slow_future_monitor(
            async {
                sleep(Duration::from_millis(50)).await;
                42 // Returned so that completion can be verified
            },
            Duration::from_millis(200),
            || calls.increment(),
        )
        .await;

        assert_eq!(calls.count(), 0);
        assert_eq!(output, 42);
    }

    #[tokio::test]
    async fn slow_monitor_error_propagation() {
        let calls = Counter::new();

        let output: Result<i32, &str> = with_slow_future_monitor(
            async {
                sleep(Duration::from_millis(150)).await;
                Err("Something went wrong")
            },
            Duration::from_millis(100),
            || calls.increment(),
        )
        .await;

        assert!(output.is_err());
        assert_eq!(output.unwrap_err(), "Something went wrong");
        assert_eq!(calls.count(), 1);
    }

    #[tokio::test]
    async fn slow_monitor_error_propagation_without_callback() {
        let calls = Counter::new();

        let output: Result<i32, &str> = with_slow_future_monitor(
            async {
                sleep(Duration::from_millis(50)).await;
                Err("Quick error")
            },
            Duration::from_millis(200),
            || calls.increment(),
        )
        .await;

        assert!(output.is_err());
        assert_eq!(output.unwrap_err(), "Quick error");
        assert_eq!(calls.count(), 0);
    }

    #[tokio::test]
    async fn slow_monitor_stuck_future_detection() {
        use std::{
            future::Future,
            pin::Pin,
            task::{Context, Poll},
        };

        // A future that always reports `Pending` and never wakes its waker.
        struct StuckFuture;

        impl Future for StuckFuture {
            type Output = ();

            fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
                Poll::Pending
            }
        }

        let calls = Counter::new();

        // The monitor's timer still fires even though the wrapped future never wakes.
        let monitored = with_slow_future_monitor(StuckFuture, Duration::from_millis(200), || {
            calls.increment()
        });

        // The monitored future can never finish, so bound the wait to keep the test from hanging.
        timeout(Duration::from_secs(2), monitored)
            .await
            .unwrap_err();
        assert_eq!(calls.count(), 1);
    }
}