1use std::{future::Future, panic, pin::pin};
5
6use futures::stream::{Stream, StreamExt};
7use tokio::task::JoinSet;
8
/// Extension methods for [`Stream`]s whose per-item work is run on spawned tokio tasks.
pub trait TrySpawnStreamExt: Stream {
    /// Drives this stream to completion, calling `f` on each item and spawning the future it
    /// returns as its own tokio task (via a `JoinSet`).
    ///
    /// `f` itself is invoked by the driving loop; only the future it returns is spawned, which
    /// is why that future must be `Send + 'static` while `f` need not be.
    ///
    /// `limit` caps how many spawned tasks may be in flight at once. `None` and `Some(0)` both
    /// mean "no limit"; any other value `n` means at most `n` tasks run concurrently.
    ///
    /// # Errors
    ///
    /// If a task returns `Err`, no further items are pulled from the stream. Tasks that were
    /// already spawned are allowed to finish, and then the *first* error observed is returned.
    /// Errors from later tasks are discarded.
    ///
    /// # Panics
    ///
    /// If any spawned task panics, that panic is resumed on the caller.
    ///
    /// # Cancellation
    ///
    /// If a spawned task is reported as cancelled, no further items are pulled. The call then
    /// returns once the remaining tasks finish, with `Ok(())` unless a task had already failed.
    /// NOTE(review): the join handles never leave the implementation, so a cancellation
    /// presumably means the runtime is shutting down — confirm.
    fn try_for_each_spawned<Fut, F, E>(
        self,
        limit: impl Into<Option<usize>>,
        f: F,
    ) -> impl Future<Output = Result<(), E>>
    where
        Fut: Future<Output = Result<(), E>> + Send + 'static,
        F: FnMut(Self::Item) -> Fut,
        E: Send + 'static;
}
37
/// Either a payload-free early exit ([`Break::Break`]) or a wrapped error ([`Break::Err`]).
///
/// Any `E` converts into `Break::Err` through `From`, so `?` on a `Result<_, E>` works inside a
/// function returning `Result<_, Break<E>>`.
/// NOTE(review): the name suggests `Break::Break` requests a non-error early exit from a loop
/// body; no caller is visible here — confirm against usage.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Break<E> {
    /// Stop early without reporting an error.
    Break,
    /// Stop and propagate this error.
    Err(E),
}
44
45impl<S: Stream + Sized + 'static> TrySpawnStreamExt for S {
46 async fn try_for_each_spawned<Fut, F, E>(
47 self,
48 limit: impl Into<Option<usize>>,
49 mut f: F,
50 ) -> Result<(), E>
51 where
52 Fut: Future<Output = Result<(), E>> + Send + 'static,
53 F: FnMut(Self::Item) -> Fut,
54 E: Send + 'static,
55 {
56 let limit = match limit.into() {
58 Some(0) | None => usize::MAX,
59 Some(n) => n,
60 };
61
62 let mut permits = limit;
64 let mut join_set = JoinSet::new();
66 let mut draining = false;
68 let mut error = None;
70
71 let mut self_ = pin!(self);
72
73 loop {
74 tokio::select! {
75 next = self_.next(), if !draining && permits > 0 => {
76 if let Some(item) = next {
77 permits -= 1;
78 join_set.spawn(f(item));
79 } else {
80 draining = true;
84 }
85 }
86
87 Some(res) = join_set.join_next() => {
88 match res {
89 Ok(Err(e)) if error.is_none() => {
90 error = Some(e);
91 permits += 1;
92 draining = true;
93 }
94
95 Ok(_) => permits += 1,
96
97 Err(e) if e.is_panic() => {
99 panic::resume_unwind(e.into_panic())
100 }
101
102 Err(e) => {
107 assert!(e.is_cancelled());
108 permits += 1;
109 draining = true;
110 }
111 }
112 }
113
114 else => {
115 if permits == limit && draining {
118 break;
119 }
120 }
121 }
122 }
123
124 if let Some(e) = error { Err(e) } else { Ok(()) }
125 }
126}
127
128impl<E> From<E> for Break<E> {
129 fn from(e: E) -> Self {
130 Break::Err(e)
131 }
132}
133
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            Arc, Mutex,
            atomic::{AtomicUsize, Ordering},
        },
        time::Duration,
    };

    use futures::stream;

    use super::*;

    /// Sums `0..100` using one spawned task per element under the given concurrency `limit`,
    /// asserts that the run succeeded, and returns the total.
    async fn sum_with_limit(limit: impl Into<Option<usize>>) -> usize {
        let total = Arc::new(AtomicUsize::new(0));
        let result = stream::iter(0..100)
            .try_for_each_spawned(limit, |i| {
                let total = total.clone();
                async move {
                    total.fetch_add(i, Ordering::Relaxed);
                    Ok::<(), ()>(())
                }
            })
            .await;

        assert!(result.is_ok());

        // Every task has been joined (and its future dropped), so this is the last handle.
        Arc::try_unwrap(total).unwrap().into_inner()
    }

    #[tokio::test]
    async fn for_each_explicit_sequential_iteration() {
        let actual = Arc::new(Mutex::new(vec![]));
        let result = stream::iter(0..20)
            .try_for_each_spawned(1, |i| {
                let actual = actual.clone();
                async move {
                    // Later items sleep less, so only the limit of 1 keeps them in order.
                    tokio::time::sleep(Duration::from_millis(20 - i)).await;
                    actual.lock().unwrap().push(i);
                    Ok::<(), ()>(())
                }
            })
            .await;

        assert!(result.is_ok());

        let actual = Arc::try_unwrap(actual).unwrap().into_inner().unwrap();
        let expect: Vec<_> = (0..20).collect();
        assert_eq!(expect, actual);
    }

    #[tokio::test]
    async fn for_each_concurrent_iteration() {
        assert_eq!(99 * 100 / 2, sum_with_limit(16).await);
    }

    #[tokio::test]
    async fn for_each_implicit_unlimited_iteration() {
        assert_eq!(99 * 100 / 2, sum_with_limit(None).await);
    }

    #[tokio::test]
    async fn for_each_explicit_unlimited_iteration() {
        assert_eq!(99 * 100 / 2, sum_with_limit(0).await);
    }

    #[tokio::test]
    async fn for_each_max_concurrency() {
        #[derive(Default, Debug)]
        struct Jobs {
            max: AtomicUsize,
            curr: AtomicUsize,
        }

        let jobs = Arc::new(Jobs::default());

        let result = stream::iter(0..32)
            .try_for_each_spawned(4, |_| {
                let jobs = jobs.clone();
                async move {
                    jobs.curr.fetch_add(1, Ordering::Relaxed);
                    tokio::time::sleep(Duration::from_millis(100)).await;
                    // Sampling on exit still sees the peak: no job can enter between the
                    // moment the peak is reached and the next exit without exceeding it.
                    let prev = jobs.curr.fetch_sub(1, Ordering::Relaxed);
                    jobs.max.fetch_max(prev, Ordering::Relaxed);
                    Ok::<(), ()>(())
                }
            })
            .await;

        assert!(result.is_ok());

        let Jobs { max, curr } = Arc::try_unwrap(jobs).unwrap();
        assert_eq!(curr.into_inner(), 0);
        // All four permitted jobs start before any can finish its sleep, so the limit must be
        // reached exactly; `<= 4` alone would also accept a purely sequential pool.
        assert_eq!(max.into_inner(), 4);
    }

    #[tokio::test]
    async fn for_each_error_propagation() {
        let actual = Arc::new(Mutex::new(vec![]));
        let result = stream::iter(0..100)
            .try_for_each_spawned(None, |i| {
                let actual = actual.clone();
                async move {
                    if i < 42 {
                        actual.lock().unwrap().push(i);
                        Ok(())
                    } else {
                        Err(())
                    }
                }
            })
            .await;

        assert!(result.is_err());

        let actual = Arc::try_unwrap(actual).unwrap().into_inner().unwrap();
        let expect: Vec<_> = (0..42).collect();
        assert_eq!(expect, actual);
    }

    #[tokio::test]
    #[should_panic]
    async fn for_each_panic_propagation() {
        let _ = stream::iter(0..100)
            .try_for_each_spawned(None, |i| async move {
                assert!(i < 42);
                Ok::<(), ()>(())
            })
            .await;
    }
}