sui_futures/
stream.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::{future::Future, panic, pin::pin};
5
6use futures::stream::{Stream, StreamExt};
7use tokio::task::JoinSet;
8
9/// Extension trait introducing `try_for_each_spawned` to all streams.
10pub trait TrySpawnStreamExt: Stream {
11    /// Attempts to run this stream to completion, executing the provided asynchronous closure on
12    /// each element from the stream as elements become available.
13    ///
14    /// This is similar to [StreamExt::for_each_concurrent], but it may take advantage of any
15    /// parallelism available in the underlying runtime, because each unit of work is spawned as
16    /// its own tokio task.
17    ///
18    /// The first argument is an optional limit on the number of tasks to spawn concurrently.
19    /// Values of `0` and `None` are interpreted as no limit, and any other value will result in no
20    /// more than that many tasks being spawned at one time.
21    ///
22    /// ## Safety
23    ///
24    /// This function will panic if any of its futures panics, will return early with success if
25    /// the runtime it is running on is cancelled, and will return early with an error propagated
26    /// from any worker that produces an error.
27    fn try_for_each_spawned<Fut, F, E>(
28        self,
29        limit: impl Into<Option<usize>>,
30        f: F,
31    ) -> impl Future<Output = Result<(), E>>
32    where
33        Fut: Future<Output = Result<(), E>> + Send + 'static,
34        F: FnMut(Self::Item) -> Fut,
35        E: Send + 'static;
36}
37
/// Wrapper type for errors that lets the body of a `try_for_each_spawned` call signal one of two
/// things: it wants to return early out of the loop (`Break`), or it hit an error that should be
/// propagated (`Err(E)`).
///
/// Any `Err` returned by a worker makes `try_for_each_spawned` stop pulling new items and return
/// that error once the in-flight workers finish. Callers can therefore use [`Break::Break`] to
/// tell an intentional early exit apart from a genuine failure.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Break<E> {
    /// The loop body asked to return early out of the loop.
    Break,
    /// The loop body failed with an error of type `E` that should be propagated.
    Err(E),
}
44
45impl<S: Stream + Sized + 'static> TrySpawnStreamExt for S {
46    async fn try_for_each_spawned<Fut, F, E>(
47        self,
48        limit: impl Into<Option<usize>>,
49        mut f: F,
50    ) -> Result<(), E>
51    where
52        Fut: Future<Output = Result<(), E>> + Send + 'static,
53        F: FnMut(Self::Item) -> Fut,
54        E: Send + 'static,
55    {
56        // Maximum number of tasks to spawn concurrently.
57        let limit = match limit.into() {
58            Some(0) | None => usize::MAX,
59            Some(n) => n,
60        };
61
62        // Number of permits to spawn tasks left.
63        let mut permits = limit;
64        // Handles for already spawned tasks.
65        let mut join_set = JoinSet::new();
66        // Whether the worker pool has stopped accepting new items and is draining.
67        let mut draining = false;
68        // Error that occurred in one of the workers, to be propagated to the called on exit.
69        let mut error = None;
70
71        let mut self_ = pin!(self);
72
73        loop {
74            tokio::select! {
75                next = self_.next(), if !draining && permits > 0 => {
76                    if let Some(item) = next {
77                        permits -= 1;
78                        join_set.spawn(f(item));
79                    } else {
80                        // If the stream is empty, signal that the worker pool is going to
81                        // start draining now, so that once we get all our permits back, we
82                        // know we can wind down the pool.
83                        draining = true;
84                    }
85                }
86
87                Some(res) = join_set.join_next() => {
88                    match res {
89                        Ok(Err(e)) if error.is_none() => {
90                            error = Some(e);
91                            permits += 1;
92                            draining = true;
93                        }
94
95                        Ok(_) => permits += 1,
96
97                        // Worker panicked, propagate the panic.
98                        Err(e) if e.is_panic() => {
99                            panic::resume_unwind(e.into_panic())
100                        }
101
102                        // Worker was cancelled -- this can only happen if its join handle was
103                        // cancelled (not possible because that was created in this function),
104                        // or the runtime it was running in was wound down, in which case,
105                        // prepare the worker pool to drain.
106                        Err(e) => {
107                            assert!(e.is_cancelled());
108                            permits += 1;
109                            draining = true;
110                        }
111                    }
112                }
113
114                else => {
115                    // Not accepting any more items from the stream, and all our workers are
116                    // idle, so we stop.
117                    if permits == limit && draining {
118                        break;
119                    }
120                }
121            }
122        }
123
124        if let Some(e) = error { Err(e) } else { Ok(()) }
125    }
126}
127
128impl<E> From<E> for Break<E> {
129    fn from(e: E) -> Self {
130        Break::Err(e)
131    }
132}
133
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            Arc, Mutex,
            atomic::{AtomicUsize, Ordering},
        },
        time::Duration,
    };

    use futures::stream;

    use super::*;

    /// With a limit of 1, items are processed strictly in stream order, even though earlier items
    /// sleep longer than later ones.
    #[tokio::test]
    async fn for_each_explicit_sequential_iteration() {
        let actual = Arc::new(Mutex::new(vec![]));
        let result = stream::iter(0..20)
            .try_for_each_spawned(1, |i| {
                let actual = actual.clone();
                async move {
                    tokio::time::sleep(Duration::from_millis(20 - i)).await;
                    actual.lock().unwrap().push(i);
                    Ok::<(), ()>(())
                }
            })
            .await;

        assert!(result.is_ok());

        let actual = Arc::try_unwrap(actual).unwrap().into_inner().unwrap();
        let expect: Vec<_> = (0..20).collect();
        assert_eq!(expect, actual);
    }

    /// Every item is processed exactly once under a bounded limit.
    #[tokio::test]
    async fn for_each_concurrent_iteration() {
        let actual = Arc::new(AtomicUsize::new(0));
        let result = stream::iter(0..100)
            .try_for_each_spawned(16, |i| {
                let actual = actual.clone();
                async move {
                    actual.fetch_add(i, Ordering::Relaxed);
                    Ok::<(), ()>(())
                }
            })
            .await;

        assert!(result.is_ok());

        let actual = Arc::try_unwrap(actual).unwrap().into_inner();
        let expect = 99 * 100 / 2;
        assert_eq!(expect, actual);
    }

    /// `None` means "no limit".
    #[tokio::test]
    async fn for_each_implicit_unlimited_iteration() {
        let actual = Arc::new(AtomicUsize::new(0));
        let result = stream::iter(0..100)
            .try_for_each_spawned(None, |i| {
                let actual = actual.clone();
                async move {
                    actual.fetch_add(i, Ordering::Relaxed);
                    Ok::<(), ()>(())
                }
            })
            .await;

        assert!(result.is_ok());

        let actual = Arc::try_unwrap(actual).unwrap().into_inner();
        let expect = 99 * 100 / 2;
        assert_eq!(expect, actual);
    }

    /// `0` also means "no limit".
    #[tokio::test]
    async fn for_each_explicit_unlimited_iteration() {
        let actual = Arc::new(AtomicUsize::new(0));
        let result = stream::iter(0..100)
            .try_for_each_spawned(0, |i| {
                let actual = actual.clone();
                async move {
                    actual.fetch_add(i, Ordering::Relaxed);
                    Ok::<(), ()>(())
                }
            })
            .await;

        assert!(result.is_ok());

        let actual = Arc::try_unwrap(actual).unwrap().into_inner();
        let expect = 99 * 100 / 2;
        assert_eq!(expect, actual);
    }

    /// The pool never runs more than `limit` workers at once, and does actually run workers
    /// concurrently when there is enough queued work.
    #[tokio::test]
    async fn for_each_max_concurrency() {
        #[derive(Default, Debug)]
        struct Jobs {
            max: AtomicUsize,
            curr: AtomicUsize,
        }

        let jobs = Arc::new(Jobs::default());

        let result = stream::iter(0..32)
            .try_for_each_spawned(4, |_| {
                let jobs = jobs.clone();
                async move {
                    // Sample the level of concurrency as this job starts. The peak is always
                    // reached right after some job increments the counter, so this never
                    // under-counts it.
                    let running = jobs.curr.fetch_add(1, Ordering::Relaxed) + 1;
                    jobs.max.fetch_max(running, Ordering::Relaxed);
                    tokio::time::sleep(Duration::from_millis(100)).await;
                    jobs.curr.fetch_sub(1, Ordering::Relaxed);
                    Ok::<(), ()>(())
                }
            })
            .await;

        assert!(result.is_ok());

        let Jobs { max, curr } = Arc::try_unwrap(jobs).unwrap();
        assert_eq!(curr.into_inner(), 0);
        let max = max.into_inner();
        assert!(max <= 4, "ran {max} jobs concurrently, but the limit is 4");
        assert!(max > 1, "jobs never overlapped, expected them to run concurrently");
    }

    /// A failing worker stops new items from being pulled, and its error is returned.
    #[tokio::test]
    async fn for_each_error_propagation() {
        let actual = Arc::new(Mutex::new(vec![]));
        let result = stream::iter(0..100)
            .try_for_each_spawned(None, |i| {
                let actual = actual.clone();
                async move {
                    if i < 42 {
                        actual.lock().unwrap().push(i);
                        Ok(())
                    } else {
                        Err(())
                    }
                }
            })
            .await;

        assert!(result.is_err());

        let actual = Arc::try_unwrap(actual).unwrap().into_inner().unwrap();
        let expect: Vec<_> = (0..42).collect();
        assert_eq!(expect, actual);
    }

    /// With sequential execution, the error from the first failing item is the one returned,
    /// and nothing after it is processed.
    #[tokio::test]
    async fn for_each_first_error_is_returned() {
        let result = stream::iter(0..100)
            .try_for_each_spawned(1, |i| async move {
                if i < 42 { Ok(()) } else { Err(i) }
            })
            .await;

        assert_eq!(result, Err(42));
    }

    /// A panic inside a worker is resumed on the task driving the stream.
    #[tokio::test]
    #[should_panic(expected = "i < 42")]
    async fn for_each_panic_propagation() {
        let _ = stream::iter(0..100)
            .try_for_each_spawned(None, |i| async move {
                assert!(i < 42);
                Ok::<(), ()>(())
            })
            .await;
    }

    /// `?` can lift a plain error into `Break::Err`.
    #[test]
    fn break_from_error() {
        assert!(matches!(Break::from("boom"), Break::Err("boom")));
    }
}