sui_indexer_alt_framework/
task.rs1use std::{future::Future, panic, pin::pin, time::Duration};
5
6use futures::stream::{Stream, StreamExt};
7use tokio::{task::JoinSet, time::sleep};
8
9pub trait TrySpawnStreamExt: Stream {
11 fn try_for_each_spawned<Fut, F, E>(
28 self,
29 limit: impl Into<Option<usize>>,
30 f: F,
31 ) -> impl Future<Output = Result<(), E>>
32 where
33 Fut: Future<Output = Result<(), E>> + Send + 'static,
34 F: FnMut(Self::Item) -> Fut,
35 E: Send + 'static;
36}
37
38impl<S: Stream + Sized + 'static> TrySpawnStreamExt for S {
39 async fn try_for_each_spawned<Fut, F, E>(
40 self,
41 limit: impl Into<Option<usize>>,
42 mut f: F,
43 ) -> Result<(), E>
44 where
45 Fut: Future<Output = Result<(), E>> + Send + 'static,
46 F: FnMut(Self::Item) -> Fut,
47 E: Send + 'static,
48 {
49 let limit = match limit.into() {
51 Some(0) | None => usize::MAX,
52 Some(n) => n,
53 };
54
55 let mut permits = limit;
57 let mut join_set = JoinSet::new();
59 let mut draining = false;
61 let mut error = None;
63
64 let mut self_ = pin!(self);
65
66 loop {
67 tokio::select! {
68 next = self_.next(), if !draining && permits > 0 => {
69 if let Some(item) = next {
70 permits -= 1;
71 join_set.spawn(f(item));
72 } else {
73 draining = true;
77 }
78 }
79
80 Some(res) = join_set.join_next() => {
81 match res {
82 Ok(Err(e)) if error.is_none() => {
83 error = Some(e);
84 permits += 1;
85 draining = true;
86 }
87
88 Ok(_) => permits += 1,
89
90 Err(e) if e.is_panic() => {
92 panic::resume_unwind(e.into_panic())
93 }
94
95 Err(e) => {
100 assert!(e.is_cancelled());
101 permits += 1;
102 draining = true;
103 }
104 }
105 }
106
107 else => {
108 if permits == limit && draining {
111 break;
112 }
113 }
114 }
115 }
116
117 if let Some(e) = error { Err(e) } else { Ok(()) }
118 }
119}
120
121pub async fn with_slow_future_monitor<F, C>(
127 future: F,
128 threshold: Duration,
129 callback: C,
130) -> F::Output
131where
132 F: Future,
133 C: FnOnce(),
134{
135 tokio::pin!(future);
137
138 tokio::select! {
139 result = &mut future => {
140 return result;
142 }
143 _ = sleep(threshold) => {
144 callback();
146 }
147 }
148
149 future.await
152}
153
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            Arc, Mutex,
            atomic::{AtomicUsize, Ordering},
        },
        time::Duration,
    };

    use futures::stream;
    use tokio::time::timeout;

    use super::*;

    /// Cloneable shared counter, used to observe how many times a closure or callback ran.
    #[derive(Clone)]
    struct Counter(Arc<AtomicUsize>);

    impl Counter {
        fn new() -> Self {
            Self(Arc::new(AtomicUsize::new(0)))
        }

        fn increment(&self) {
            self.0.fetch_add(1, Ordering::Relaxed);
        }

        fn count(&self) -> usize {
            self.0.load(Ordering::Relaxed)
        }
    }

    #[tokio::test]
    async fn for_each_explicit_sequential_iteration() {
        let actual = Arc::new(Mutex::new(vec![]));
        let result = stream::iter(0..20)
            .try_for_each_spawned(1, |i| {
                let actual = actual.clone();
                async move {
                    // Later items sleep for less time, so any overlap would reorder the output.
                    tokio::time::sleep(Duration::from_millis(20 - i)).await;
                    actual.lock().unwrap().push(i);
                    Ok::<(), ()>(())
                }
            })
            .await;

        assert!(result.is_ok());

        let actual = Arc::try_unwrap(actual).unwrap().into_inner().unwrap();
        let expect: Vec<_> = (0..20).collect();
        assert_eq!(expect, actual);
    }

    #[tokio::test]
    async fn for_each_concurrent_iteration() {
        let actual = Arc::new(AtomicUsize::new(0));
        let result = stream::iter(0..100)
            .try_for_each_spawned(16, |i| {
                let actual = actual.clone();
                async move {
                    actual.fetch_add(i, Ordering::Relaxed);
                    Ok::<(), ()>(())
                }
            })
            .await;

        assert!(result.is_ok());

        let actual = Arc::try_unwrap(actual).unwrap().into_inner();
        let expect = 99 * 100 / 2;
        assert_eq!(expect, actual);
    }

    #[tokio::test]
    async fn for_each_implicit_unlimited_iteration() {
        let actual = Arc::new(AtomicUsize::new(0));
        let result = stream::iter(0..100)
            .try_for_each_spawned(None, |i| {
                let actual = actual.clone();
                async move {
                    actual.fetch_add(i, Ordering::Relaxed);
                    Ok::<(), ()>(())
                }
            })
            .await;

        assert!(result.is_ok());

        let actual = Arc::try_unwrap(actual).unwrap().into_inner();
        let expect = 99 * 100 / 2;
        assert_eq!(expect, actual);
    }

    #[tokio::test]
    async fn for_each_explicit_unlimited_iteration() {
        let actual = Arc::new(AtomicUsize::new(0));
        let result = stream::iter(0..100)
            .try_for_each_spawned(0, |i| {
                let actual = actual.clone();
                async move {
                    actual.fetch_add(i, Ordering::Relaxed);
                    Ok::<(), ()>(())
                }
            })
            .await;

        assert!(result.is_ok());

        let actual = Arc::try_unwrap(actual).unwrap().into_inner();
        let expect = 99 * 100 / 2;
        assert_eq!(expect, actual);
    }

    #[tokio::test]
    async fn for_each_empty_stream() {
        // An empty stream must finish immediately with success and never invoke the closure.
        let c = Counter::new();
        let result = stream::iter(0..0)
            .try_for_each_spawned(4, |_| {
                let c = c.clone();
                async move {
                    c.increment();
                    Ok::<(), ()>(())
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(c.count(), 0);
    }

    #[tokio::test]
    async fn for_each_max_concurrency() {
        #[derive(Default, Debug)]
        struct Jobs {
            max: AtomicUsize,
            curr: AtomicUsize,
        }

        let jobs = Arc::new(Jobs::default());

        let result = stream::iter(0..32)
            .try_for_each_spawned(4, |_| {
                let jobs = jobs.clone();
                async move {
                    jobs.curr.fetch_add(1, Ordering::Relaxed);
                    tokio::time::sleep(Duration::from_millis(100)).await;
                    // The count just before this job leaves includes itself. The first exit
                    // after any peak therefore records that peak.
                    let prev = jobs.curr.fetch_sub(1, Ordering::Relaxed);
                    jobs.max.fetch_max(prev, Ordering::Relaxed);
                    Ok::<(), ()>(())
                }
            })
            .await;

        assert!(result.is_ok());

        let Jobs { max, curr } = Arc::try_unwrap(jobs).unwrap();
        assert_eq!(curr.into_inner(), 0);
        // The cap must be respected *and* actually reached. A regression that serialized the
        // work would still satisfy `max <= 4`.
        assert_eq!(max.into_inner(), 4);
    }

    #[tokio::test]
    async fn for_each_error_propagation() {
        let actual = Arc::new(Mutex::new(vec![]));
        let result = stream::iter(0..100)
            .try_for_each_spawned(None, |i| {
                let actual = actual.clone();
                async move {
                    if i < 42 {
                        actual.lock().unwrap().push(i);
                        Ok(())
                    } else {
                        Err(())
                    }
                }
            })
            .await;

        assert_eq!(result, Err(()));

        let actual = Arc::try_unwrap(actual).unwrap().into_inner().unwrap();
        let expect: Vec<_> = (0..42).collect();
        assert_eq!(expect, actual);
    }

    #[tokio::test]
    #[should_panic]
    async fn for_each_panic_propagation() {
        let _ = stream::iter(0..100)
            .try_for_each_spawned(None, |i| async move {
                assert!(i < 42);
                Ok::<(), ()>(())
            })
            .await;
    }

    #[tokio::test]
    async fn slow_monitor_callback_called_once_when_threshold_exceeded() {
        let c = Counter::new();

        let result = with_slow_future_monitor(
            async {
                sleep(Duration::from_millis(200)).await;
                42
            },
            Duration::from_millis(100),
            || c.increment(),
        )
        .await;

        assert_eq!(c.count(), 1);
        assert_eq!(result, 42);
    }

    #[tokio::test]
    async fn slow_monitor_callback_not_called_when_threshold_not_exceeded() {
        let c = Counter::new();

        let result = with_slow_future_monitor(
            async {
                sleep(Duration::from_millis(50)).await;
                42
            },
            Duration::from_millis(200),
            || c.increment(),
        )
        .await;

        assert_eq!(c.count(), 0);
        assert_eq!(result, 42);
    }

    #[tokio::test]
    async fn slow_monitor_error_propagation() {
        let c = Counter::new();

        let result: Result<i32, &str> = with_slow_future_monitor(
            async {
                sleep(Duration::from_millis(150)).await;
                Err("Something went wrong")
            },
            Duration::from_millis(100),
            || c.increment(),
        )
        .await;

        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Something went wrong");
        assert_eq!(c.count(), 1);
    }

    #[tokio::test]
    async fn slow_monitor_error_propagation_without_callback() {
        let c = Counter::new();

        let result: Result<i32, &str> = with_slow_future_monitor(
            async {
                sleep(Duration::from_millis(50)).await;
                Err("Quick error")
            },
            Duration::from_millis(200),
            || c.increment(),
        )
        .await;

        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Quick error");
        assert_eq!(c.count(), 0);
    }

    #[tokio::test]
    async fn slow_monitor_stuck_future_detection() {
        use std::future::Future;
        use std::pin::Pin;
        use std::task::{Context, Poll};

        /// A future that never completes and never arranges to be woken.
        struct StuckFuture;

        impl Future for StuckFuture {
            type Output = ();
            fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
                Poll::Pending
            }
        }

        let c = Counter::new();

        let monitored =
            with_slow_future_monitor(StuckFuture, Duration::from_millis(200), || c.increment());

        // The monitored future can never finish, so the outer timeout must expire; the
        // slow-future callback should still have fired exactly once.
        timeout(Duration::from_secs(2), monitored)
            .await
            .unwrap_err();
        assert_eq!(c.count(), 1);
    }
}