// sui_futures/stream.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::future::Future;
5use std::future::poll_fn;
6use std::panic;
7use std::pin::pin;
8use std::sync::Arc;
9
10use futures::FutureExt;
11use futures::future::try_join_all;
12use futures::stream::Stream;
13use futures::try_join;
14use tokio::sync::mpsc;
15use tokio::task::JoinSet;
16
17/// Runtime configuration for adaptive concurrency control.
18///
19/// For fixed concurrency, use [`ConcurrencyConfig::fixed`].
20/// For adaptive concurrency, use [`ConcurrencyConfig::adaptive`] which requires the three core
21/// parameters (initial, min, max). The dead-band thresholds have sensible defaults and can be
22/// overridden with a chainable setter:
23///
24/// ```ignore
25/// ConcurrencyConfig::adaptive(10, 1, 20)
26///     .with_dead_band(0.5, 0.9)
27/// ```
#[derive(Debug, Clone)]
pub struct ConcurrencyConfig {
    /// Concurrency limit the controller starts from.
    pub initial: usize,
    /// Lower bound on the concurrency limit. Must be >= 1 (asserted at runtime).
    pub min: usize,
    /// Upper bound on the concurrency limit. When `min == max` the limit is
    /// effectively fixed and the adaptive controller is a no-op.
    pub max: usize,
    /// Fill fraction below which the controller may increase the limit (if saturated).
    /// Default: 0.6.
    pub dead_band_low: f64,
    /// Fill fraction at or above which the controller decreases the limit. Default: 0.85.
    pub dead_band_high: f64,
}
39
/// Snapshot of concurrency stats passed to the `report` callback.
#[derive(Debug, Clone, Copy)]
pub struct ConcurrencyStats {
    /// Concurrency limit currently chosen by the controller.
    pub limit: usize,
    /// Number of tasks currently spawned and not yet joined.
    pub inflight: usize,
}
46
/// Wrapper type for errors to allow the body of a `try_for_each_spawned` call to signal that it
/// either wants to return early (`Break`) out of the loop, or propagate an error (`Err(E)`).
#[derive(Debug)]
pub enum Break<E> {
    /// Stop processing early without an error (also produced internally when a
    /// downstream channel closes).
    Break,
    /// Stop processing and propagate `E` to the caller.
    Err(E),
}
54
/// Extension trait introducing `try_for_each_spawned` to all streams.
pub trait TrySpawnStreamExt: Stream {
    /// Attempts to run this stream to completion, executing the provided asynchronous closure on
    /// each element from the stream as elements become available.
    ///
    /// This is similar to [`futures::stream::StreamExt::for_each_concurrent`], but it may take advantage of any
    /// parallelism available in the underlying runtime, because each unit of work is spawned as
    /// its own tokio task.
    ///
    /// The first argument is an optional limit on the number of tasks to spawn concurrently.
    /// Values of `0` and `None` are interpreted as no limit, and any other value will result in no
    /// more than that many tasks being spawned at one time.
    ///
    /// ## Safety
    ///
    /// This function will panic if any of its futures panics, will return early with success if
    /// the runtime it is running on is cancelled, and will return early with an error propagated
    /// from any worker that produces an error.
    fn try_for_each_spawned<Fut, F, E>(
        self,
        limit: impl Into<Option<usize>>,
        f: F,
    ) -> impl Future<Output = Result<(), E>>
    where
        Fut: Future<Output = Result<(), E>> + Send + 'static,
        F: FnMut(Self::Item) -> Fut,
        E: Send + 'static;

    /// Process each stream item through a spawned task, sending results to a single channel.
    ///
    /// Each item is passed to `f` which returns a future producing `Result<T, Break<E>>`. The
    /// resulting `T` is sent to `tx`. Concurrency is controlled by `config`: for fixed configs,
    /// the limit never changes; for adaptive configs, the limit adjusts based on the fill fraction
    /// of the output channel.
    ///
    /// Returns `Err(Break::Break)` if `tx` is closed before all results are sent.
    ///
    /// Unlike [`try_for_each_broadcast_spawned`](TrySpawnStreamExt::try_for_each_broadcast_spawned),
    /// `T` does not need to be `Clone` since there is only a single receiver.
    ///
    /// The `report` callback is invoked each iteration with concurrency stats for metrics.
    fn try_for_each_send_spawned<Fut, F, T, E, R>(
        self,
        config: ConcurrencyConfig,
        f: F,
        tx: mpsc::Sender<T>,
        report: R,
    ) -> impl Future<Output = Result<(), Break<E>>>
    where
        Fut: Future<Output = Result<T, Break<E>>> + Send + 'static,
        F: FnMut(Self::Item) -> Fut,
        T: Send + 'static,
        E: Send + 'static,
        R: Fn(ConcurrencyStats);

    /// Process each stream item through a spawned task, broadcasting results to multiple channels.
    ///
    /// Same as [`try_for_each_send_spawned`](TrySpawnStreamExt::try_for_each_send_spawned) but
    /// sends a clone of each result to every channel in `txs`. Fill fraction is measured as the
    /// maximum across all channels. Requires `T: Clone` since values are cloned to each receiver.
    fn try_for_each_broadcast_spawned<Fut, F, T, E, R>(
        self,
        config: ConcurrencyConfig,
        f: F,
        txs: Vec<mpsc::Sender<T>>,
        report: R,
    ) -> impl Future<Output = Result<(), Break<E>>>
    where
        Fut: Future<Output = Result<T, Break<E>>> + Send + 'static,
        F: FnMut(Self::Item) -> Fut,
        T: Clone + Send + Sync + 'static,
        E: Send + 'static,
        R: Fn(ConcurrencyStats);
}
127
/// Abstraction over single-channel and broadcast sending so that
/// `adaptive_spawn_send` can be generic over both. This avoids requiring `T: Clone`
/// on the single-sender path (`SingleSender` moves the value without cloning).
///
/// Implementors must be cheaply cloneable (they are cloned for each spawned task).
trait Sender: Clone + Send + Sync + 'static {
    /// The type of values being sent downstream.
    type Value: Send + 'static;

    /// Send a value downstream. Returns `Err(())` if the channel(s) are closed.
    fn send(&self, value: Self::Value) -> impl Future<Output = Result<(), ()>> + Send;

    /// Measure the fill fraction of the downstream channel(s).
    /// Returns a value in [0.0, 1.0] where 1.0 means completely full.
    fn fill(&self) -> f64;
}
143
/// Single-channel sender: values are moved into the one channel, so `T` need not be `Clone`.
struct SingleSender<T>(mpsc::Sender<T>);
146
/// Broadcast sender that clones the value to all channels. The channel list is
/// behind an `Arc` so the sender itself stays cheap to clone per spawned task.
struct BroadcastSender<T>(Arc<Vec<mpsc::Sender<T>>>);
149
150impl ConcurrencyConfig {
151    pub fn fixed(n: usize) -> Self {
152        Self {
153            initial: n,
154            min: n,
155            max: n,
156            dead_band_low: 0.6,
157            dead_band_high: 0.85,
158        }
159    }
160
161    pub fn adaptive(initial: usize, min: usize, max: usize) -> Self {
162        Self {
163            initial,
164            min,
165            max,
166            dead_band_low: 0.6,
167            dead_band_high: 0.85,
168        }
169    }
170
171    pub fn with_dead_band(mut self, low: f64, high: f64) -> Self {
172        self.dead_band_low = low;
173        self.dead_band_high = high;
174        self
175    }
176}
177
178impl<E> From<E> for Break<E> {
179    fn from(e: E) -> Self {
180        Break::Err(e)
181    }
182}
183
impl<S: Stream + Sized + 'static> TrySpawnStreamExt for S {
    async fn try_for_each_spawned<Fut, F, E>(
        self,
        limit: impl Into<Option<usize>>,
        mut f: F,
    ) -> Result<(), E>
    where
        Fut: Future<Output = Result<(), E>> + Send + 'static,
        F: FnMut(Self::Item) -> Fut,
        E: Send + 'static,
    {
        // Maximum number of tasks to spawn concurrently. `Some(0)` and `None`
        // both mean unlimited, represented as `usize::MAX` permits.
        let limit = match limit.into() {
            Some(0) | None => usize::MAX,
            Some(n) => n,
        };

        // Number of permits to spawn tasks left.
        let mut permits = limit;
        // Handles for already spawned tasks.
        let mut join_set = JoinSet::new();
        // Whether the worker pool has stopped accepting new items and is draining.
        let mut draining = false;
        // Error that occurred in one of the workers, to be propagated to the caller on exit.
        let mut error = None;

        let mut self_ = pin!(self);

        loop {
            // Eager inner loop: spawn tasks while permits allow and items are ready,
            // avoiding select! overhead when items are immediately available.
            while !draining && permits > 0 {
                match poll_fn(|cx| self_.as_mut().poll_next(cx)).now_or_never() {
                    Some(Some(item)) => {
                        permits -= 1;
                        join_set.spawn(f(item));
                    }
                    Some(None) => {
                        // If the stream is empty, signal that the worker pool is going to
                        // start draining now, so that once we get all our permits back, we
                        // know we can wind down the pool.
                        draining = true;
                    }
                    // Stream is pending; fall through to the select! below to wait.
                    None => break,
                }
            }

            tokio::select! {
                biased;

                Some(res) = join_set.join_next() => {
                    match res {
                        // First worker error: record it and stop accepting new items;
                        // already-spawned workers are left to run to completion.
                        Ok(Err(e)) if error.is_none() => {
                            error = Some(e);
                            permits += 1;
                            draining = true;
                        }

                        Ok(_) => permits += 1,

                        // Worker panicked, propagate the panic.
                        Err(e) if e.is_panic() => {
                            panic::resume_unwind(e.into_panic())
                        }

                        // Worker was cancelled -- this can only happen if its join handle was
                        // cancelled (not possible because that was created in this function),
                        // or the runtime it was running in was wound down, in which case,
                        // prepare the worker pool to drain.
                        Err(e) => {
                            assert!(e.is_cancelled());
                            permits += 1;
                            draining = true;
                        }
                    }
                }

                next = poll_fn(|cx| self_.as_mut().poll_next(cx)),
                    if !draining && permits > 0 => {
                    if let Some(item) = next {
                        permits -= 1;
                        join_set.spawn(f(item));
                    } else {
                        draining = true;
                    }
                }

                // Both branches disabled: nothing in flight and nothing left to
                // spawn. Exit once every permit has been returned while draining.
                else => {
                    if permits == limit && draining {
                        break;
                    }
                }
            }
        }

        if let Some(e) = error { Err(e) } else { Ok(()) }
    }

    async fn try_for_each_send_spawned<Fut, F, T, E, R>(
        self,
        config: ConcurrencyConfig,
        f: F,
        tx: mpsc::Sender<T>,
        report: R,
    ) -> Result<(), Break<E>>
    where
        Fut: Future<Output = Result<T, Break<E>>> + Send + 'static,
        F: FnMut(Self::Item) -> Fut,
        T: Send + 'static,
        E: Send + 'static,
        R: Fn(ConcurrencyStats),
    {
        // Single-receiver path: `SingleSender` moves values, so no `T: Clone`.
        adaptive_spawn_send(self, config, f, SingleSender(tx), report).await
    }

    async fn try_for_each_broadcast_spawned<Fut, F, T, E, R>(
        self,
        config: ConcurrencyConfig,
        f: F,
        txs: Vec<mpsc::Sender<T>>,
        report: R,
    ) -> Result<(), Break<E>>
    where
        Fut: Future<Output = Result<T, Break<E>>> + Send + 'static,
        F: FnMut(Self::Item) -> Fut,
        T: Clone + Send + Sync + 'static,
        E: Send + 'static,
        R: Fn(ConcurrencyStats),
    {
        // The sender is cloned into every spawned task, so the channel list is
        // shared once via `Arc` rather than cloning the whole Vec per task.
        adaptive_spawn_send(self, config, f, BroadcastSender(Arc::new(txs)), report).await
    }
}
316
// Manual impl rather than `#[derive(Clone)]`: deriving would add a spurious
// `T: Clone` bound, while `mpsc::Sender<T>` is cloneable for any `T`.
impl<T> Clone for SingleSender<T> {
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}
322
323impl<T: Send + 'static> Sender for SingleSender<T> {
324    type Value = T;
325
326    async fn send(&self, value: T) -> Result<(), ()> {
327        self.0.send(value).await.map_err(|_| ())
328    }
329
330    fn fill(&self) -> f64 {
331        1.0 - (self.0.capacity() as f64 / self.0.max_capacity() as f64)
332    }
333}
334
// Manual impl rather than `#[derive(Clone)]`: deriving would require `T: Clone`
// just to clone the struct, but cloning the `Arc` never touches `T`.
impl<T> Clone for BroadcastSender<T> {
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}
340
impl<T: Clone + Send + Sync + 'static> Sender for BroadcastSender<T> {
    type Value = T;

    /// Send a clone of `value` to every channel, concurrently. The last channel
    /// receives the original value, saving one clone. Fails with `Err(())` if
    /// any channel is closed — or if the channel list is empty (`split_last`
    /// yields `None`), which is treated the same as all channels being closed.
    async fn send(&self, value: T) -> Result<(), ()> {
        let (last, rest) = self.0.split_last().ok_or(())?;
        let rest_fut = try_join_all(rest.iter().map(|tx| {
            let v = value.clone();
            async move { tx.send(v).await.map_err(|_| ()) }
        }));
        let last_fut = last.send(value).map(|r| r.map_err(|_| ()));
        try_join!(rest_fut, last_fut)?;
        Ok(())
    }

    /// Worst-case (maximum) fill fraction across all channels; 0.0 when the
    /// channel list is empty.
    fn fill(&self) -> f64 {
        self.0
            .iter()
            .map(|tx| 1.0 - (tx.capacity() as f64 / tx.max_capacity() as f64))
            .fold(0.0f64, f64::max)
    }
}
362
363/// Shared adaptive concurrency loop used by both `try_for_each_send_spawned` and
364/// `try_for_each_broadcast_spawned`.
365///
366/// The algorithm is a result of weeks of trial and error guided by lots of Claude Research queries.
367/// I tried to document where each of the ideas came from in the footnotes of this comment.
368///
369/// # Adaptive concurrency control
370///
371/// When `config.min < config.max`, this function dynamically adjusts how many tasks run
372/// concurrently based on how full the downstream channel(s) are. The goal is to find and
373/// hold a stable concurrency limit that keeps downstream fed without using excessive memory.
374///
375/// ## The core problem
376///
377/// N tasks share one set of output channels. If we naively cut the limit every time any task
378/// sees a full channel, all N tasks observe the same fullness at once and each triggers a
379/// cut — collapsing the limit exponentially (`limit * ratio^N`) instead of once. This causes
380/// wild oscillation: the limit crashes to near-zero then slowly climbs back, wasting ~50%
381/// of throughput.
382///
383/// ## How this controller avoids that
384///
385/// **Epoch-gated reductions** — Each task captures the current `epoch` when spawned. A
386/// reduction increments the epoch. If a completing task's epoch is stale (from before the
387/// last reduction), its congestion signal is ignored. This ensures only one reduction fires
388/// per congestion event, no matter how many tasks see it. [1]
389///
390/// **Severity-scaled cuts** — Instead of "full → cut by fixed ratio", a `severity` score
391/// derived from the fill fraction drives the size of the cut linearly:
392/// `keep = 0.8 - 0.3 * severity` where `severity = (fill - 0.85) / 0.15`.
393/// At fill = 0.85 (threshold), severity = 0 and the controller keeps 80% (gentle 20%
394/// cut). At fill = 1.0 (hard saturation), severity = 1 and the controller keeps 50%
395/// (aggressive halving). [2]
396///
397/// **Guaranteed progress** — After computing the proportional cut, the new limit is capped
398/// at `limit - 1` so that `ceil()` rounding cannot stall decreases at small values.
399///
400/// **Dead band between increase and decrease** — Three zones based on fill fraction:
401/// - `fill >= 0.85`: decrease (proportionally, epoch-gated)
402/// - `fill < 0.60`: increase (if the limit was actually being used)
403/// - `0.60–0.85`: do nothing
404///
405/// The gap between the increase and decrease thresholds prevents the controller from
406/// endlessly flip-flopping at the boundary — it finds a stable operating point and stays
407/// there. [3]
408///
409/// **Log10-scaled increase** — The limit grows by `ceil(log10(limit))` instead of +1. This
410/// scales sub-linearly: recovery isn't painfully slow at high limits (e.g. +3 at limit=200
411/// vs +1), but it's slower than proportional growth so it doesn't overshoot. [4]
412///
413/// **Saturation guard** — The limit only increases when inflight actually reached the limit
414/// (`was_saturated`). Without this, the limit would grow unboundedly during low-load periods
415/// since fill stays low and every completion triggers an increase — even though the extra
416/// concurrency was never actually used or tested against real backpressure. When load
417/// eventually does arrive, the inflated limit allows a burst that overwhelms the channel.
418/// Crucially, `was_saturated` resets on *both* increase and decrease. The decrease-side
419/// reset acts as a cooling-off period: after cutting the limit, the controller must see
420/// the new lower limit fully utilized before raising it again. Without this, the old
421/// saturation proof carries over, and the increase branch fires on the very next low-fill
422/// completion — undoing the decrease and causing oscillation.
423///
424/// ---
425/// [1] Same idea as TCP NewReno's `recover` variable (RFC 6582): record a high-water mark
426///     when congestion is detected, suppress further reductions until tasks past that mark
427///     start completing.
428/// [2] Analogous to DCTCP's proportional window reduction (RFC 8257):
429///     `cwnd = cwnd * (1 - α/2)` where α is the marked fraction.
430/// [3] TCP Vegas uses the same approach with alpha/beta thresholds around estimated queue
431///     depth to avoid oscillation at the operating point.
432/// [4] Netflix's `concurrency-limits` VegasLimit uses the same `log10(limit)` additive
433///     increase with a `max(1)` floor.
434async fn adaptive_spawn_send<S, Fut, F, E, Tx, R>(
435    stream: S,
436    config: ConcurrencyConfig,
437    mut f: F,
438    sender: Tx,
439    report: R,
440) -> Result<(), Break<E>>
441where
442    S: Stream + 'static,
443    Fut: Future<Output = Result<Tx::Value, Break<E>>> + Send + 'static,
444    F: FnMut(S::Item) -> Fut,
445    E: Send + 'static,
446    Tx: Sender,
447    R: Fn(ConcurrencyStats),
448{
449    assert!(config.min >= 1, "ConcurrencyConfig::min must be >= 1");
450    let mut limit = config.initial;
451    let mut epoch: u64 = 0;
452    let mut was_saturated = false;
453    let mut tasks: JoinSet<Result<u64, Break<E>>> = JoinSet::new();
454    let mut stream_done = false;
455    let mut error: Option<Break<E>> = None;
456
457    let mut stream = pin!(stream);
458
459    loop {
460        if tasks.is_empty() && (stream_done || error.is_some()) {
461            break;
462        }
463
464        // Eager inner loop: spawn tasks while under the limit and items are ready.
465        while tasks.len() < limit && !stream_done && error.is_none() {
466            match poll_fn(|cx| stream.as_mut().poll_next(cx)).now_or_never() {
467                Some(Some(item)) => {
468                    let fut = f(item);
469                    let tx = sender.clone();
470                    let spawn_epoch = epoch;
471                    tasks.spawn(async move {
472                        let value = fut.await?;
473                        tx.send(value).await.map_err(|_| Break::Break)?;
474                        Ok(spawn_epoch)
475                    });
476                    if tasks.len() >= limit {
477                        was_saturated = true;
478                    }
479                }
480                Some(None) => stream_done = true,
481                None => break,
482            }
483        }
484
485        let completed = tokio::select! {
486            biased;
487
488            Some(r) = tasks.join_next(), if !tasks.is_empty() => Some(r),
489
490            next = poll_fn(|cx| stream.as_mut().poll_next(cx)),
491                if tasks.len() < limit && !stream_done && error.is_none() =>
492            {
493                if let Some(item) = next {
494                    let fut = f(item);
495                    let tx = sender.clone();
496                    let spawn_epoch = epoch;
497                    tasks.spawn(async move {
498                        let value = fut.await?;
499                        tx.send(value).await.map_err(|_| Break::Break)?;
500                        Ok(spawn_epoch)
501                    });
502                    if tasks.len() >= limit {
503                        was_saturated = true;
504                    }
505                } else {
506                    stream_done = true;
507                }
508                None
509            }
510
511            else => {
512                if tasks.is_empty() && (stream_done || error.is_some()) {
513                    break;
514                }
515                None
516            }
517        };
518
519        // Handle all completions: the one from select (if any) + drain ready ones.
520        for join_result in completed.into_iter().chain(std::iter::from_fn(|| {
521            tasks.join_next().now_or_never().flatten()
522        })) {
523            match join_result {
524                Ok(Ok(spawn_epoch)) => {
525                    // Adjust concurrency limit based on channel fill fraction.
526                    // - fill >= dead_band_high: severity-scaled decrease (epoch-gated)
527                    // - fill < dead_band_low and limit was saturated: log10-scaled increase
528                    // - fill in [dead_band_low, dead_band_high): hold steady
529                    // When config is fixed (min == max), clamp/min keep limit unchanged.
530                    let fill = sender.fill();
531                    if fill >= config.dead_band_high && spawn_epoch == epoch {
532                        // Proportional cut that gets aggressive near saturation.
533                        // At dead_band_high: keep 80% (cut 20%)
534                        // At midpoint:       keep 65% (cut 35%)
535                        // At 1.00:           keep 50% (cut 50%)
536                        let severity =
537                            (fill - config.dead_band_high) / (1.0 - config.dead_band_high);
538                        let keep = 0.8 - 0.3 * severity;
539                        let new_limit = ((limit as f64) * keep).ceil() as usize;
540                        limit = new_limit.min(limit.saturating_sub(1)).max(config.min);
541                        limit = limit.clamp(config.min, config.max);
542                        epoch += 1;
543                        // Reset was_saturated so the controller must re-prove the new
544                        // lower limit is fully utilized before increasing again. This
545                        // cooling-off period prevents decrease→increase oscillation:
546                        // without it, the old saturation proof carries over and the
547                        // increase branch fires immediately on the next low-fill
548                        // completion, undoing the decrease.
549                        was_saturated = false;
550                    } else if fill < config.dead_band_low && was_saturated {
551                        let increment = ((limit as f64).log10().ceil() as usize).max(1);
552                        limit = (limit + increment).min(config.max);
553                        was_saturated = false;
554                    }
555                }
556                Ok(Err(e)) if error.is_none() => error = Some(e),
557                Ok(Err(_)) => {}
558                Err(e) if e.is_panic() => panic::resume_unwind(e.into_panic()),
559                Err(e) => {
560                    assert!(e.is_cancelled());
561                    stream_done = true;
562                }
563            }
564        }
565
566        report(ConcurrencyStats {
567            limit,
568            inflight: tasks.len(),
569        });
570    }
571
572    if let Some(e) = error { Err(e) } else { Ok(()) }
573}
574
575#[cfg(test)]
576mod tests {
577    use std::{
578        sync::{
579            Arc, Mutex,
580            atomic::{AtomicUsize, Ordering},
581        },
582        time::Duration,
583    };
584
585    use futures::stream;
586
587    use super::*;
588
589    #[tokio::test]
590    async fn for_each_explicit_sequential_iteration() {
591        let actual = Arc::new(Mutex::new(vec![]));
592        let result = stream::iter(0..20)
593            .try_for_each_spawned(1, |i| {
594                let actual = actual.clone();
595                async move {
596                    tokio::time::sleep(Duration::from_millis(20 - i)).await;
597                    actual.lock().unwrap().push(i);
598                    Ok::<(), ()>(())
599                }
600            })
601            .await;
602
603        assert!(result.is_ok());
604
605        let actual = Arc::try_unwrap(actual).unwrap().into_inner().unwrap();
606        let expect: Vec<_> = (0..20).collect();
607        assert_eq!(expect, actual);
608    }
609
610    #[tokio::test]
611    async fn for_each_concurrent_iteration() {
612        let actual = Arc::new(AtomicUsize::new(0));
613        let result = stream::iter(0..100)
614            .try_for_each_spawned(16, |i| {
615                let actual = actual.clone();
616                async move {
617                    actual.fetch_add(i, Ordering::Relaxed);
618                    Ok::<(), ()>(())
619                }
620            })
621            .await;
622
623        assert!(result.is_ok());
624
625        let actual = Arc::try_unwrap(actual).unwrap().into_inner();
626        let expect = 99 * 100 / 2;
627        assert_eq!(expect, actual);
628    }
629
630    #[tokio::test]
631    async fn for_each_implicit_unlimited_iteration() {
632        let actual = Arc::new(AtomicUsize::new(0));
633        let result = stream::iter(0..100)
634            .try_for_each_spawned(None, |i| {
635                let actual = actual.clone();
636                async move {
637                    actual.fetch_add(i, Ordering::Relaxed);
638                    Ok::<(), ()>(())
639                }
640            })
641            .await;
642
643        assert!(result.is_ok());
644
645        let actual = Arc::try_unwrap(actual).unwrap().into_inner();
646        let expect = 99 * 100 / 2;
647        assert_eq!(expect, actual);
648    }
649
650    #[tokio::test]
651    async fn for_each_explicit_unlimited_iteration() {
652        let actual = Arc::new(AtomicUsize::new(0));
653        let result = stream::iter(0..100)
654            .try_for_each_spawned(0, |i| {
655                let actual = actual.clone();
656                async move {
657                    actual.fetch_add(i, Ordering::Relaxed);
658                    Ok::<(), ()>(())
659                }
660            })
661            .await;
662
663        assert!(result.is_ok());
664
665        let actual = Arc::try_unwrap(actual).unwrap().into_inner();
666        let expect = 99 * 100 / 2;
667        assert_eq!(expect, actual);
668    }
669
670    #[tokio::test]
671    async fn for_each_max_concurrency() {
672        #[derive(Default, Debug)]
673        struct Jobs {
674            max: AtomicUsize,
675            curr: AtomicUsize,
676        }
677
678        let jobs = Arc::new(Jobs::default());
679
680        let result = stream::iter(0..32)
681            .try_for_each_spawned(4, |_| {
682                let jobs = jobs.clone();
683                async move {
684                    jobs.curr.fetch_add(1, Ordering::Relaxed);
685                    tokio::time::sleep(Duration::from_millis(100)).await;
686                    let prev = jobs.curr.fetch_sub(1, Ordering::Relaxed);
687                    jobs.max.fetch_max(prev, Ordering::Relaxed);
688                    Ok::<(), ()>(())
689                }
690            })
691            .await;
692
693        assert!(result.is_ok());
694
695        let Jobs { max, curr } = Arc::try_unwrap(jobs).unwrap();
696        assert_eq!(curr.into_inner(), 0);
697        assert!(max.into_inner() <= 4);
698    }
699
700    #[tokio::test]
701    async fn for_each_error_propagation() {
702        let actual = Arc::new(Mutex::new(vec![]));
703        let result = stream::iter(0..100)
704            .try_for_each_spawned(None, |i| {
705                let actual = actual.clone();
706                async move {
707                    if i < 42 {
708                        actual.lock().unwrap().push(i);
709                        Ok(())
710                    } else {
711                        Err(())
712                    }
713                }
714            })
715            .await;
716
717        assert!(result.is_err());
718
719        let actual = Arc::try_unwrap(actual).unwrap().into_inner().unwrap();
720        let expect: Vec<_> = (0..42).collect();
721        assert_eq!(expect, actual);
722    }
723
724    #[tokio::test]
725    #[should_panic]
726    async fn for_each_panic_propagation() {
727        let _ = stream::iter(0..100)
728            .try_for_each_spawned(None, |i| async move {
729                assert!(i < 42);
730                Ok::<(), ()>(())
731            })
732            .await;
733    }
734
735    #[tokio::test]
736    async fn send_spawned_basic() {
737        let (tx, mut rx) = mpsc::channel(100);
738        let result = stream::iter(0..10u64)
739            .try_for_each_send_spawned(
740                ConcurrencyConfig::fixed(4),
741                |i| async move { Ok::<_, Break<()>>(i * 2) },
742                tx,
743                |_| {},
744            )
745            .await;
746
747        assert!(result.is_ok());
748
749        let mut values = Vec::new();
750        while let Ok(v) = rx.try_recv() {
751            values.push(v);
752        }
753        values.sort();
754        let expected: Vec<u64> = (0..10).map(|i| i * 2).collect();
755        assert_eq!(values, expected);
756    }
757
758    #[tokio::test]
759    async fn send_spawned_error_propagation() {
760        let (tx, _rx) = mpsc::channel(100);
761        let result: Result<(), Break<String>> = stream::iter(0..10u64)
762            .try_for_each_send_spawned(
763                ConcurrencyConfig::fixed(1),
764                |i| async move {
765                    if i < 3 {
766                        Ok(i)
767                    } else {
768                        Err(Break::Err("fail".to_string()))
769                    }
770                },
771                tx,
772                |_| {},
773            )
774            .await;
775
776        assert!(matches!(result, Err(Break::Err(ref s)) if s == "fail"));
777    }
778
779    #[tokio::test]
780    async fn send_spawned_channel_closed() {
781        let (tx, rx) = mpsc::channel(1);
782        drop(rx);
783
784        let result: Result<(), Break<()>> = stream::iter(0..10u64)
785            .try_for_each_send_spawned(
786                ConcurrencyConfig::fixed(1),
787                |i| async move { Ok(i) },
788                tx,
789                |_| {},
790            )
791            .await;
792
793        assert!(matches!(result, Err(Break::Break)));
794    }
795
796    #[tokio::test]
797    async fn send_spawned_reports_stats() {
798        let reported: Arc<Mutex<Vec<ConcurrencyStats>>> = Arc::new(Mutex::new(Vec::new()));
799        let (tx, _rx) = mpsc::channel(100);
800
801        let reported2 = reported.clone();
802        let _ = stream::iter(0..5u64)
803            .try_for_each_send_spawned(
804                ConcurrencyConfig::fixed(2),
805                |i| async move { Ok::<_, Break<()>>(i) },
806                tx,
807                move |stats| {
808                    reported2.lock().unwrap().push(stats);
809                },
810            )
811            .await;
812
813        let reports = reported.lock().unwrap();
814        for stats in reports.iter() {
815            assert_eq!(stats.limit, 2);
816        }
817    }
818
819    #[tokio::test]
820    async fn broadcast_spawned_basic() {
821        let (tx1, mut rx1) = mpsc::channel(100);
822        let (tx2, mut rx2) = mpsc::channel(100);
823        let txs = vec![tx1, tx2];
824
825        let result = stream::iter(0..5u64)
826            .try_for_each_broadcast_spawned(
827                ConcurrencyConfig::fixed(2),
828                |i| async move { Ok::<_, Break<()>>(i * 3) },
829                txs,
830                |_| {},
831            )
832            .await;
833
834        assert!(result.is_ok());
835
836        let mut v1 = Vec::new();
837        while let Ok(v) = rx1.try_recv() {
838            v1.push(v);
839        }
840        let mut v2 = Vec::new();
841        while let Ok(v) = rx2.try_recv() {
842            v2.push(v);
843        }
844        v1.sort();
845        v2.sort();
846        let expected: Vec<u64> = (0..5).map(|i| i * 3).collect();
847        assert_eq!(v1, expected);
848        assert_eq!(v2, expected);
849    }
850
851    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
852    async fn send_spawned_adaptive_decreases_limit() {
853        // Use a small channel to force fill to rise quickly.
854        let (tx, mut rx) = mpsc::channel(4);
855        let limits: Arc<Mutex<Vec<usize>>> = Arc::new(Mutex::new(Vec::new()));
856
857        let limits2 = limits.clone();
858        let handle = tokio::spawn(async move {
859            stream::iter(0..100u64)
860                .try_for_each_send_spawned(
861                    ConcurrencyConfig::adaptive(10, 1, 20),
862                    |i| async move {
863                        // Simulate work so tasks take time
864                        tokio::time::sleep(Duration::from_millis(5)).await;
865                        Ok::<_, Break<()>>(i)
866                    },
867                    tx,
868                    move |stats| {
869                        limits2.lock().unwrap().push(stats.limit);
870                    },
871                )
872                .await
873        });
874
875        // Drain slowly so the channel fills up
876        let mut received = Vec::new();
877        loop {
878            tokio::time::sleep(Duration::from_millis(20)).await;
879            match rx.try_recv() {
880                Ok(v) => received.push(v),
881                Err(mpsc::error::TryRecvError::Empty) => {
882                    if handle.is_finished() {
883                        // Drain remaining
884                        while let Ok(v) = rx.try_recv() {
885                            received.push(v);
886                        }
887                        break;
888                    }
889                }
890                Err(mpsc::error::TryRecvError::Disconnected) => {
891                    while let Ok(v) = rx.try_recv() {
892                        received.push(v);
893                    }
894                    break;
895                }
896            }
897        }
898
899        handle.await.unwrap().unwrap();
900
901        let limits = limits.lock().unwrap();
902        let min_limit = limits.iter().copied().min().unwrap_or(10);
903        assert!(
904            min_limit < 10,
905            "Limit should have decreased from initial=10, min observed: {min_limit}"
906        );
907    }
908
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn send_spawned_adaptive_recovers_after_decrease() {
        // End-to-end adaptive cycle: congestion first drives the limit down,
        // then eager draining relieves back-pressure and the limit climbs
        // back up. A tiny channel (capacity 4) makes fill respond quickly.
        let (tx, mut rx) = mpsc::channel(4);
        let limits: Arc<Mutex<Vec<usize>>> = Arc::new(Mutex::new(Vec::new()));

        let limits2 = limits.clone();
        let handle = tokio::spawn(async move {
            stream::iter(0..200u64)
                .try_for_each_send_spawned(
                    ConcurrencyConfig::adaptive(10, 1, 20),
                    |i| async move {
                        // Simulated per-item work so in-flight tasks overlap.
                        tokio::time::sleep(Duration::from_millis(5)).await;
                        Ok::<_, Break<()>>(i)
                    },
                    tx,
                    move |stats| {
                        // Record every reported limit for later inspection.
                        limits2.lock().unwrap().push(stats.limit);
                    },
                )
                .await
        });

        // Phase 1: drain slowly so the channel fills and the limit decreases.
        // 60 iterations at 20ms each ≈ 1.2s of throttled consumption.
        for _ in 0..60 {
            tokio::time::sleep(Duration::from_millis(20)).await;
            let _ = rx.try_recv();
        }

        // Record the lowest limit observed so far.
        let low_water = {
            let lims = limits.lock().unwrap();
            lims.iter().copied().min().unwrap_or(10)
        };
        assert!(
            low_water < 10,
            "Limit should have decreased, min={low_water}"
        );

        // Phase 2: drain eagerly so fill drops and the limit recovers.
        // Bail out once the producer is finished so we don't keep waiting on
        // recv() after the last value.
        while (rx.recv().await).is_some() {
            if handle.is_finished() {
                while rx.try_recv().is_ok() {}
                break;
            }
        }

        handle.await.unwrap().unwrap();

        let limits = limits.lock().unwrap();
        // Only the tail of the report log matters: recovery should show up in
        // the 30 most recent limit values.
        let recovered_max = limits.iter().copied().rev().take(30).max().unwrap_or(0);
        assert!(
            recovered_max > low_water,
            "Limit should have recovered above {low_water}, best late limit: {recovered_max}"
        );
    }
964
965    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
966    async fn send_spawned_adaptive_respects_min() {
967        let (tx, mut rx) = mpsc::channel(2);
968        let limits: Arc<Mutex<Vec<usize>>> = Arc::new(Mutex::new(Vec::new()));
969
970        let limits2 = limits.clone();
971        let handle = tokio::spawn(async move {
972            stream::iter(0..100u64)
973                .try_for_each_send_spawned(
974                    ConcurrencyConfig::adaptive(10, 5, 20),
975                    |i| async move {
976                        tokio::time::sleep(Duration::from_millis(10)).await;
977                        Ok::<_, Break<()>>(i)
978                    },
979                    tx,
980                    move |stats| {
981                        limits2.lock().unwrap().push(stats.limit);
982                    },
983                )
984                .await
985        });
986
987        // Drain very slowly to force congestion.
988        loop {
989            tokio::time::sleep(Duration::from_millis(50)).await;
990            match rx.try_recv() {
991                Ok(_) => {}
992                Err(mpsc::error::TryRecvError::Empty) => {
993                    if handle.is_finished() {
994                        while rx.try_recv().is_ok() {}
995                        break;
996                    }
997                }
998                Err(mpsc::error::TryRecvError::Disconnected) => {
999                    while rx.try_recv().is_ok() {}
1000                    break;
1001                }
1002            }
1003        }
1004
1005        handle.await.unwrap().unwrap();
1006
1007        let limits = limits.lock().unwrap();
1008        let min_limit = limits.iter().copied().min().unwrap_or(10);
1009        assert!(
1010            min_limit >= 5,
1011            "Limit should never drop below min=5, observed: {min_limit}"
1012        );
1013    }
1014
1015    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1016    async fn send_spawned_adaptive_respects_max() {
1017        let (tx, mut rx) = mpsc::channel(1000);
1018        let limits: Arc<Mutex<Vec<usize>>> = Arc::new(Mutex::new(Vec::new()));
1019
1020        let limits2 = limits.clone();
1021        let handle = tokio::spawn(async move {
1022            stream::iter(0..200u64)
1023                .try_for_each_send_spawned(
1024                    ConcurrencyConfig::adaptive(2, 1, 8),
1025                    |i| async move {
1026                        tokio::time::sleep(Duration::from_millis(5)).await;
1027                        Ok::<_, Break<()>>(i)
1028                    },
1029                    tx,
1030                    move |stats| {
1031                        limits2.lock().unwrap().push(stats.limit);
1032                    },
1033                )
1034                .await
1035        });
1036
1037        // Drain eagerly so fill stays low and the limit keeps trying to increase.
1038        while rx.recv().await.is_some() {}
1039
1040        handle.await.unwrap().unwrap();
1041
1042        let limits = limits.lock().unwrap();
1043        let max_limit = limits.iter().copied().max().unwrap_or(0);
1044        assert!(
1045            max_limit <= 8,
1046            "Limit should never exceed max=8, observed: {max_limit}"
1047        );
1048    }
1049
1050    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1051    async fn send_spawned_epoch_prevents_stampede() {
1052        let (tx, mut rx) = mpsc::channel(2);
1053        let limits: Arc<Mutex<Vec<usize>>> = Arc::new(Mutex::new(Vec::new()));
1054
1055        let limits2 = limits.clone();
1056        let handle = tokio::spawn(async move {
1057            stream::iter(0..60u64)
1058                .try_for_each_send_spawned(
1059                    ConcurrencyConfig::adaptive(20, 1, 20),
1060                    |i| async move {
1061                        tokio::time::sleep(Duration::from_millis(10)).await;
1062                        Ok::<_, Break<()>>(i)
1063                    },
1064                    tx,
1065                    move |stats| {
1066                        limits2.lock().unwrap().push(stats.limit);
1067                    },
1068                )
1069                .await
1070        });
1071
1072        // Don't drain initially so the channel fills up. Many tasks will complete while
1073        // the channel is full, exercising the epoch guard.
1074        tokio::time::sleep(Duration::from_millis(300)).await;
1075
1076        // Now drain to let the producer finish.
1077        while rx.recv().await.is_some() {}
1078        handle.await.unwrap().unwrap();
1079
1080        let limits = limits.lock().unwrap();
1081        // Deduplicate consecutive equal values to get actual transitions.
1082        let transitions: Vec<usize> = limits
1083            .iter()
1084            .copied()
1085            .collect::<Vec<_>>()
1086            .windows(2)
1087            .filter_map(|w| if w[0] != w[1] { Some(w[1]) } else { None })
1088            .collect();
1089
1090        // For every decrease, verify it's at most a single proportional cut.
1091        // The maximum single-step cut at fill=1.0 is: new = ceil(old * 0.5).
1092        for pair in limits.iter().copied().collect::<Vec<_>>().windows(2) {
1093            let (old, new) = (pair[0], pair[1]);
1094            if new < old {
1095                let min_allowed = ((old as f64) * 0.5).ceil() as usize;
1096                assert!(
1097                    new >= min_allowed,
1098                    "Stampede detected: limit dropped from {old} to {new}, \
1099                     minimum allowed single-step is {min_allowed}. Transitions: {transitions:?}"
1100                );
1101            }
1102        }
1103    }
1104
1105    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1106    async fn broadcast_spawned_slow_receiver_triggers_decrease() {
1107        let (tx_fast, mut rx_fast) = mpsc::channel(100);
1108        let (tx_slow, mut rx_slow) = mpsc::channel(4);
1109        let txs = vec![tx_fast, tx_slow];
1110        let limits: Arc<Mutex<Vec<usize>>> = Arc::new(Mutex::new(Vec::new()));
1111
1112        let limits2 = limits.clone();
1113        let handle = tokio::spawn(async move {
1114            stream::iter(0..100u64)
1115                .try_for_each_broadcast_spawned(
1116                    ConcurrencyConfig::adaptive(10, 1, 20),
1117                    |i| async move {
1118                        tokio::time::sleep(Duration::from_millis(5)).await;
1119                        Ok::<_, Break<()>>(i)
1120                    },
1121                    txs,
1122                    move |stats| {
1123                        limits2.lock().unwrap().push(stats.limit);
1124                    },
1125                )
1126                .await
1127        });
1128
1129        // Drain the fast channel eagerly.
1130        let fast_drain = tokio::spawn(async move { while rx_fast.recv().await.is_some() {} });
1131
1132        // Drain the slow channel slowly.
1133        loop {
1134            tokio::time::sleep(Duration::from_millis(20)).await;
1135            match rx_slow.try_recv() {
1136                Ok(_) => {}
1137                Err(mpsc::error::TryRecvError::Empty) => {
1138                    if handle.is_finished() {
1139                        while rx_slow.try_recv().is_ok() {}
1140                        break;
1141                    }
1142                }
1143                Err(mpsc::error::TryRecvError::Disconnected) => {
1144                    while rx_slow.try_recv().is_ok() {}
1145                    break;
1146                }
1147            }
1148        }
1149
1150        handle.await.unwrap().unwrap();
1151        fast_drain.await.unwrap();
1152
1153        let limits = limits.lock().unwrap();
1154        let min_limit = limits.iter().copied().min().unwrap_or(10);
1155        assert!(
1156            min_limit < 10,
1157            "Limit should have decreased due to slow receiver, min observed: {min_limit}"
1158        );
1159    }
1160
1161    #[tokio::test]
1162    async fn broadcast_spawned_channel_closed() {
1163        let (tx1, _rx1) = mpsc::channel(100);
1164        let (tx2, rx2) = mpsc::channel(100);
1165        drop(rx2);
1166
1167        let result: Result<(), Break<()>> = stream::iter(0..10u64)
1168            .try_for_each_broadcast_spawned(
1169                ConcurrencyConfig::fixed(2),
1170                |i| async move { Ok(i) },
1171                vec![tx1, tx2],
1172                |_| {},
1173            )
1174            .await;
1175
1176        assert!(matches!(result, Err(Break::Break)));
1177    }
1178
1179    #[tokio::test]
1180    async fn fixed_concurrency_limit_never_changes() {
1181        let limits: Arc<Mutex<Vec<usize>>> = Arc::new(Mutex::new(Vec::new()));
1182        let (tx, mut rx) = mpsc::channel(2);
1183
1184        let limits2 = limits.clone();
1185        let handle = tokio::spawn(async move {
1186            stream::iter(0..20u64)
1187                .try_for_each_send_spawned(
1188                    ConcurrencyConfig::fixed(5),
1189                    |i| async move { Ok::<_, Break<()>>(i) },
1190                    tx,
1191                    move |stats| {
1192                        limits2.lock().unwrap().push(stats.limit);
1193                    },
1194                )
1195                .await
1196        });
1197
1198        // Drain the receiver so sends don't block.
1199        while rx.recv().await.is_some() {}
1200
1201        handle.await.unwrap().unwrap();
1202
1203        let limits = limits.lock().unwrap();
1204        for &g in limits.iter() {
1205            assert_eq!(g, 5, "Fixed limit should never change");
1206        }
1207    }
1208}