1use std::future::Future;
5use std::future::poll_fn;
6use std::panic;
7use std::pin::pin;
8use std::sync::Arc;
9
10use futures::FutureExt;
11use futures::future::try_join_all;
12use futures::stream::Stream;
13use futures::try_join;
14use tokio::sync::mpsc;
15use tokio::task::JoinSet;
16
/// Bounds and tuning knobs for the adaptive concurrency limit used by
/// `try_for_each_send_spawned` / `try_for_each_broadcast_spawned`.
#[derive(Debug, Clone)]
pub struct ConcurrencyConfig {
    /// Concurrency limit to start with.
    pub initial: usize,
    /// Hard floor for the limit (must be >= 1; asserted by the driver).
    pub min: usize,
    /// Hard ceiling for the limit.
    pub max: usize,
    /// Channel-fill fraction below which the limit may be increased again
    /// (only after the limit was actually saturated).
    pub dead_band_low: f64,
    /// Channel-fill fraction at or above which the limit is decreased.
    pub dead_band_high: f64,
}
39
/// Snapshot handed to the `report` callback after each batch of task
/// completions.
#[derive(Debug, Clone, Copy)]
pub struct ConcurrencyStats {
    /// The concurrency limit at the time of the report.
    pub limit: usize,
    /// Number of tasks currently in flight.
    pub inflight: usize,
}
46
/// Error type distinguishing "stop early without a failure" from a genuine
/// error produced by a task.
#[derive(Debug)]
pub enum Break<E> {
    /// Stop processing with no underlying error (e.g. a downstream receiver
    /// was dropped, so sends can no longer succeed).
    Break,
    /// Stop processing because of the wrapped error.
    Err(E),
}
54
/// Extension methods for driving a stream by spawning its per-item work onto
/// the Tokio runtime.
pub trait TrySpawnStreamExt: Stream {
    /// Runs `f` on every stream item, spawning each returned future with at
    /// most `limit` in flight at once (`None` or `Some(0)` means unlimited).
    /// On the first error, stops consuming the stream, waits for in-flight
    /// tasks to finish, and returns that first error.
    fn try_for_each_spawned<Fut, F, E>(
        self,
        limit: impl Into<Option<usize>>,
        f: F,
    ) -> impl Future<Output = Result<(), E>>
    where
        Fut: Future<Output = Result<(), E>> + Send + 'static,
        F: FnMut(Self::Item) -> Fut,
        E: Send + 'static;

    /// Like [`Self::try_for_each_spawned`], but each task produces a value
    /// that is forwarded into `tx`, and the concurrency limit adapts to the
    /// channel's fill level according to `config`. `report` observes the
    /// limit and in-flight count after each completion batch.
    fn try_for_each_send_spawned<Fut, F, T, E, R>(
        self,
        config: ConcurrencyConfig,
        f: F,
        tx: mpsc::Sender<T>,
        report: R,
    ) -> impl Future<Output = Result<(), Break<E>>>
    where
        Fut: Future<Output = Result<T, Break<E>>> + Send + 'static,
        F: FnMut(Self::Item) -> Fut,
        T: Send + 'static,
        E: Send + 'static,
        R: Fn(ConcurrencyStats);

    /// Like [`Self::try_for_each_send_spawned`], but each produced value is
    /// cloned into every channel in `txs`; adaptation tracks the fullest
    /// channel.
    fn try_for_each_broadcast_spawned<Fut, F, T, E, R>(
        self,
        config: ConcurrencyConfig,
        f: F,
        txs: Vec<mpsc::Sender<T>>,
        report: R,
    ) -> impl Future<Output = Result<(), Break<E>>>
    where
        Fut: Future<Output = Result<T, Break<E>>> + Send + 'static,
        F: FnMut(Self::Item) -> Fut,
        T: Clone + Send + Sync + 'static,
        E: Send + 'static,
        R: Fn(ConcurrencyStats);
}
127
/// Internal abstraction over "where task outputs go", letting the adaptive
/// driver treat a single channel and a broadcast fan-out uniformly.
trait Sender: Clone + Send + Sync + 'static {
    type Value: Send + 'static;

    /// Sends `value` downstream; `Err(())` means a receiving side is gone.
    fn send(&self, value: Self::Value) -> impl Future<Output = Result<(), ()>> + Send;

    /// Fraction of downstream channel capacity currently occupied, in
    /// `[0.0, 1.0]` — the backpressure signal for the adaptive limit.
    fn fill(&self) -> f64;
}
143
/// `Sender` adapter over a single mpsc channel.
struct SingleSender<T>(mpsc::Sender<T>);
146
/// `Sender` adapter that clones each value into every channel in the list.
struct BroadcastSender<T>(Arc<Vec<mpsc::Sender<T>>>);
149
150impl ConcurrencyConfig {
151 pub fn fixed(n: usize) -> Self {
152 Self {
153 initial: n,
154 min: n,
155 max: n,
156 dead_band_low: 0.6,
157 dead_band_high: 0.85,
158 }
159 }
160
161 pub fn adaptive(initial: usize, min: usize, max: usize) -> Self {
162 Self {
163 initial,
164 min,
165 max,
166 dead_band_low: 0.6,
167 dead_band_high: 0.85,
168 }
169 }
170
171 pub fn with_dead_band(mut self, low: f64, high: f64) -> Self {
172 self.dead_band_low = low;
173 self.dead_band_high = high;
174 self
175 }
176}
177
178impl<E> From<E> for Break<E> {
179 fn from(e: E) -> Self {
180 Break::Err(e)
181 }
182}
183
impl<S: Stream + Sized + 'static> TrySpawnStreamExt for S {
    async fn try_for_each_spawned<Fut, F, E>(
        self,
        limit: impl Into<Option<usize>>,
        mut f: F,
    ) -> Result<(), E>
    where
        Fut: Future<Output = Result<(), E>> + Send + 'static,
        F: FnMut(Self::Item) -> Fut,
        E: Send + 'static,
    {
        // `None` and `Some(0)` both mean "unlimited".
        let limit = match limit.into() {
            Some(0) | None => usize::MAX,
            Some(n) => n,
        };

        // Number of additional tasks we may spawn right now; a permit is
        // consumed on spawn and returned when a task finishes.
        let mut permits = limit;
        let mut join_set = JoinSet::new();
        // Once draining, we stop pulling from the stream and only wait for
        // already-spawned tasks to complete.
        let mut draining = false;
        // First error observed; reported only after the drain completes.
        let mut error = None;

        let mut self_ = pin!(self);

        loop {
            // Opportunistically spawn every immediately-ready item (up to the
            // permit budget) without yielding to the scheduler.
            while !draining && permits > 0 {
                match poll_fn(|cx| self_.as_mut().poll_next(cx)).now_or_never() {
                    Some(Some(item)) => {
                        permits -= 1;
                        join_set.spawn(f(item));
                    }
                    Some(None) => {
                        draining = true;
                    }
                    // Stream not ready yet: fall through to the select below.
                    None => break,
                }
            }

            tokio::select! {
                // `biased`: prefer reaping finished tasks over spawning more.
                biased;

                Some(res) = join_set.join_next() => {
                    match res {
                        // Record only the first task error and start draining.
                        Ok(Err(e)) if error.is_none() => {
                            error = Some(e);
                            permits += 1;
                            draining = true;
                        }

                        Ok(_) => permits += 1,

                        // Re-raise task panics on the awaiting caller.
                        Err(e) if e.is_panic() => {
                            panic::resume_unwind(e.into_panic())
                        }

                        // A cancelled task also triggers a drain.
                        Err(e) => {
                            assert!(e.is_cancelled());
                            permits += 1;
                            draining = true;
                        }
                    }
                }

                next = poll_fn(|cx| self_.as_mut().poll_next(cx)),
                if !draining && permits > 0 => {
                    if let Some(item) = next {
                        permits -= 1;
                        join_set.spawn(f(item));
                    } else {
                        draining = true;
                    }
                }

                // Both arms disabled: finished once the stream is exhausted
                // and every permit has been returned.
                else => {
                    if permits == limit && draining {
                        break;
                    }
                }
            }
        }

        if let Some(e) = error { Err(e) } else { Ok(()) }
    }

    async fn try_for_each_send_spawned<Fut, F, T, E, R>(
        self,
        config: ConcurrencyConfig,
        f: F,
        tx: mpsc::Sender<T>,
        report: R,
    ) -> Result<(), Break<E>>
    where
        Fut: Future<Output = Result<T, Break<E>>> + Send + 'static,
        F: FnMut(Self::Item) -> Fut,
        T: Send + 'static,
        E: Send + 'static,
        R: Fn(ConcurrencyStats),
    {
        // Thin wrapper: the shared driver does all the work.
        adaptive_spawn_send(self, config, f, SingleSender(tx), report).await
    }

    async fn try_for_each_broadcast_spawned<Fut, F, T, E, R>(
        self,
        config: ConcurrencyConfig,
        f: F,
        txs: Vec<mpsc::Sender<T>>,
        report: R,
    ) -> Result<(), Break<E>>
    where
        Fut: Future<Output = Result<T, Break<E>>> + Send + 'static,
        F: FnMut(Self::Item) -> Fut,
        T: Clone + Send + Sync + 'static,
        E: Send + 'static,
        R: Fn(ConcurrencyStats),
    {
        // Arc so each spawned task can cheaply share the full channel list.
        adaptive_spawn_send(self, config, f, BroadcastSender(Arc::new(txs)), report).await
    }
}
316
317impl<T> Clone for SingleSender<T> {
318 fn clone(&self) -> Self {
319 Self(self.0.clone())
320 }
321}
322
323impl<T: Send + 'static> Sender for SingleSender<T> {
324 type Value = T;
325
326 async fn send(&self, value: T) -> Result<(), ()> {
327 self.0.send(value).await.map_err(|_| ())
328 }
329
330 fn fill(&self) -> f64 {
331 1.0 - (self.0.capacity() as f64 / self.0.max_capacity() as f64)
332 }
333}
334
335impl<T> Clone for BroadcastSender<T> {
336 fn clone(&self) -> Self {
337 Self(self.0.clone())
338 }
339}
340
341impl<T: Clone + Send + Sync + 'static> Sender for BroadcastSender<T> {
342 type Value = T;
343
344 async fn send(&self, value: T) -> Result<(), ()> {
345 let (last, rest) = self.0.split_last().ok_or(())?;
346 let rest_fut = try_join_all(rest.iter().map(|tx| {
347 let v = value.clone();
348 async move { tx.send(v).await.map_err(|_| ()) }
349 }));
350 let last_fut = last.send(value).map(|r| r.map_err(|_| ()));
351 try_join!(rest_fut, last_fut)?;
352 Ok(())
353 }
354
355 fn fill(&self) -> f64 {
356 self.0
357 .iter()
358 .map(|tx| 1.0 - (tx.capacity() as f64 / tx.max_capacity() as f64))
359 .fold(0.0f64, f64::max)
360 }
361}
362
/// Core driver shared by `try_for_each_send_spawned` and
/// `try_for_each_broadcast_spawned`: spawns one task per stream item, sends
/// each task's output through `sender`, and adapts the concurrency limit to
/// downstream backpressure (the channel fill level).
///
/// After every batch of completions, `report` is invoked with the current
/// limit and the number of in-flight tasks.
async fn adaptive_spawn_send<S, Fut, F, E, Tx, R>(
    stream: S,
    config: ConcurrencyConfig,
    mut f: F,
    sender: Tx,
    report: R,
) -> Result<(), Break<E>>
where
    S: Stream + 'static,
    Fut: Future<Output = Result<Tx::Value, Break<E>>> + Send + 'static,
    F: FnMut(S::Item) -> Fut,
    E: Send + 'static,
    Tx: Sender,
    R: Fn(ConcurrencyStats),
{
    assert!(config.min >= 1, "ConcurrencyConfig::min must be >= 1");
    // Current limit, kept within [config.min, config.max].
    let mut limit = config.initial;
    // Bumped on every decrease. Tasks spawned under an older epoch cannot
    // trigger another decrease, so a burst of completions that all observe a
    // full channel produces at most one cut (no stampede).
    let mut epoch: u64 = 0;
    // Only grow the limit if we actually hit it since the last change —
    // otherwise a larger limit would change nothing.
    let mut was_saturated = false;
    // Each task resolves to the epoch it was spawned under.
    let mut tasks: JoinSet<Result<u64, Break<E>>> = JoinSet::new();
    let mut stream_done = false;
    let mut error: Option<Break<E>> = None;

    let mut stream = pin!(stream);

    loop {
        if tasks.is_empty() && (stream_done || error.is_some()) {
            break;
        }

        // Opportunistically spawn every immediately-ready item up to the
        // limit, without yielding to the scheduler.
        while tasks.len() < limit && !stream_done && error.is_none() {
            match poll_fn(|cx| stream.as_mut().poll_next(cx)).now_or_never() {
                Some(Some(item)) => {
                    let fut = f(item);
                    let tx = sender.clone();
                    let spawn_epoch = epoch;
                    tasks.spawn(async move {
                        let value = fut.await?;
                        // A failed send means downstream is gone: Break.
                        tx.send(value).await.map_err(|_| Break::Break)?;
                        Ok(spawn_epoch)
                    });
                    if tasks.len() >= limit {
                        was_saturated = true;
                    }
                }
                Some(None) => stream_done = true,
                // Stream pending: fall through to the select below.
                None => break,
            }
        }

        let completed = tokio::select! {
            // Prefer reaping completions over spawning more work.
            biased;

            Some(r) = tasks.join_next(), if !tasks.is_empty() => Some(r),

            next = poll_fn(|cx| stream.as_mut().poll_next(cx)),
            if tasks.len() < limit && !stream_done && error.is_none() =>
            {
                if let Some(item) = next {
                    let fut = f(item);
                    let tx = sender.clone();
                    let spawn_epoch = epoch;
                    tasks.spawn(async move {
                        let value = fut.await?;
                        tx.send(value).await.map_err(|_| Break::Break)?;
                        Ok(spawn_epoch)
                    });
                    if tasks.len() >= limit {
                        was_saturated = true;
                    }
                } else {
                    stream_done = true;
                }
                None
            }

            else => {
                if tasks.is_empty() && (stream_done || error.is_some()) {
                    break;
                }
                None
            }
        };

        // Handle the completion from the select (if any) plus any further
        // already-finished tasks, without blocking.
        for join_result in completed.into_iter().chain(std::iter::from_fn(|| {
            tasks.join_next().now_or_never().flatten()
        })) {
            match join_result {
                Ok(Ok(spawn_epoch)) => {
                    let fill = sender.fill();
                    // Decrease: channel is above the high threshold AND this
                    // task was spawned in the current epoch (one cut/epoch).
                    if fill >= config.dead_band_high && spawn_epoch == epoch {
                        // Scale the cut by how far past the threshold we are:
                        // keep 80% when barely over, down to 50% at fill 1.0.
                        let severity =
                            (fill - config.dead_band_high) / (1.0 - config.dead_band_high);
                        let keep = 0.8 - 0.3 * severity;
                        let new_limit = ((limit as f64) * keep).ceil() as usize;
                        // Always drop by at least one, never below min.
                        limit = new_limit.min(limit.saturating_sub(1)).max(config.min);
                        limit = limit.clamp(config.min, config.max);
                        epoch += 1;
                        was_saturated = false;
                    } else if fill < config.dead_band_low && was_saturated {
                        // Increase gently: ~log10(limit) steps, at least 1.
                        let increment = ((limit as f64).log10().ceil() as usize).max(1);
                        limit = (limit + increment).min(config.max);
                        was_saturated = false;
                    }
                }
                // Keep only the first error; later ones are dropped.
                Ok(Err(e)) if error.is_none() => error = Some(e),
                Ok(Err(_)) => {}
                // Re-raise task panics on the awaiting caller.
                Err(e) if e.is_panic() => panic::resume_unwind(e.into_panic()),
                Err(e) => {
                    // NOTE(review): a cancelled task is treated as stream
                    // exhaustion (wind-down) — confirm this is intentional.
                    assert!(e.is_cancelled());
                    stream_done = true;
                }
            }
        }

        report(ConcurrencyStats {
            limit,
            inflight: tasks.len(),
        });
    }

    if let Some(e) = error { Err(e) } else { Ok(()) }
}
574
575#[cfg(test)]
576mod tests {
577 use std::{
578 sync::{
579 Arc, Mutex,
580 atomic::{AtomicUsize, Ordering},
581 },
582 time::Duration,
583 };
584
585 use futures::stream;
586
587 use super::*;
588
589 #[tokio::test]
590 async fn for_each_explicit_sequential_iteration() {
591 let actual = Arc::new(Mutex::new(vec![]));
592 let result = stream::iter(0..20)
593 .try_for_each_spawned(1, |i| {
594 let actual = actual.clone();
595 async move {
596 tokio::time::sleep(Duration::from_millis(20 - i)).await;
597 actual.lock().unwrap().push(i);
598 Ok::<(), ()>(())
599 }
600 })
601 .await;
602
603 assert!(result.is_ok());
604
605 let actual = Arc::try_unwrap(actual).unwrap().into_inner().unwrap();
606 let expect: Vec<_> = (0..20).collect();
607 assert_eq!(expect, actual);
608 }
609
610 #[tokio::test]
611 async fn for_each_concurrent_iteration() {
612 let actual = Arc::new(AtomicUsize::new(0));
613 let result = stream::iter(0..100)
614 .try_for_each_spawned(16, |i| {
615 let actual = actual.clone();
616 async move {
617 actual.fetch_add(i, Ordering::Relaxed);
618 Ok::<(), ()>(())
619 }
620 })
621 .await;
622
623 assert!(result.is_ok());
624
625 let actual = Arc::try_unwrap(actual).unwrap().into_inner();
626 let expect = 99 * 100 / 2;
627 assert_eq!(expect, actual);
628 }
629
630 #[tokio::test]
631 async fn for_each_implicit_unlimited_iteration() {
632 let actual = Arc::new(AtomicUsize::new(0));
633 let result = stream::iter(0..100)
634 .try_for_each_spawned(None, |i| {
635 let actual = actual.clone();
636 async move {
637 actual.fetch_add(i, Ordering::Relaxed);
638 Ok::<(), ()>(())
639 }
640 })
641 .await;
642
643 assert!(result.is_ok());
644
645 let actual = Arc::try_unwrap(actual).unwrap().into_inner();
646 let expect = 99 * 100 / 2;
647 assert_eq!(expect, actual);
648 }
649
650 #[tokio::test]
651 async fn for_each_explicit_unlimited_iteration() {
652 let actual = Arc::new(AtomicUsize::new(0));
653 let result = stream::iter(0..100)
654 .try_for_each_spawned(0, |i| {
655 let actual = actual.clone();
656 async move {
657 actual.fetch_add(i, Ordering::Relaxed);
658 Ok::<(), ()>(())
659 }
660 })
661 .await;
662
663 assert!(result.is_ok());
664
665 let actual = Arc::try_unwrap(actual).unwrap().into_inner();
666 let expect = 99 * 100 / 2;
667 assert_eq!(expect, actual);
668 }
669
670 #[tokio::test]
671 async fn for_each_max_concurrency() {
672 #[derive(Default, Debug)]
673 struct Jobs {
674 max: AtomicUsize,
675 curr: AtomicUsize,
676 }
677
678 let jobs = Arc::new(Jobs::default());
679
680 let result = stream::iter(0..32)
681 .try_for_each_spawned(4, |_| {
682 let jobs = jobs.clone();
683 async move {
684 jobs.curr.fetch_add(1, Ordering::Relaxed);
685 tokio::time::sleep(Duration::from_millis(100)).await;
686 let prev = jobs.curr.fetch_sub(1, Ordering::Relaxed);
687 jobs.max.fetch_max(prev, Ordering::Relaxed);
688 Ok::<(), ()>(())
689 }
690 })
691 .await;
692
693 assert!(result.is_ok());
694
695 let Jobs { max, curr } = Arc::try_unwrap(jobs).unwrap();
696 assert_eq!(curr.into_inner(), 0);
697 assert!(max.into_inner() <= 4);
698 }
699
700 #[tokio::test]
701 async fn for_each_error_propagation() {
702 let actual = Arc::new(Mutex::new(vec![]));
703 let result = stream::iter(0..100)
704 .try_for_each_spawned(None, |i| {
705 let actual = actual.clone();
706 async move {
707 if i < 42 {
708 actual.lock().unwrap().push(i);
709 Ok(())
710 } else {
711 Err(())
712 }
713 }
714 })
715 .await;
716
717 assert!(result.is_err());
718
719 let actual = Arc::try_unwrap(actual).unwrap().into_inner().unwrap();
720 let expect: Vec<_> = (0..42).collect();
721 assert_eq!(expect, actual);
722 }
723
    /// A panic inside a spawned task must be resumed on the awaiting caller,
    /// panicking this test.
    #[tokio::test]
    #[should_panic]
    async fn for_each_panic_propagation() {
        let _ = stream::iter(0..100)
            .try_for_each_spawned(None, |i| async move {
                // Panics once i reaches 42.
                assert!(i < 42);
                Ok::<(), ()>(())
            })
            .await;
    }
734
735 #[tokio::test]
736 async fn send_spawned_basic() {
737 let (tx, mut rx) = mpsc::channel(100);
738 let result = stream::iter(0..10u64)
739 .try_for_each_send_spawned(
740 ConcurrencyConfig::fixed(4),
741 |i| async move { Ok::<_, Break<()>>(i * 2) },
742 tx,
743 |_| {},
744 )
745 .await;
746
747 assert!(result.is_ok());
748
749 let mut values = Vec::new();
750 while let Ok(v) = rx.try_recv() {
751 values.push(v);
752 }
753 values.sort();
754 let expected: Vec<u64> = (0..10).map(|i| i * 2).collect();
755 assert_eq!(values, expected);
756 }
757
758 #[tokio::test]
759 async fn send_spawned_error_propagation() {
760 let (tx, _rx) = mpsc::channel(100);
761 let result: Result<(), Break<String>> = stream::iter(0..10u64)
762 .try_for_each_send_spawned(
763 ConcurrencyConfig::fixed(1),
764 |i| async move {
765 if i < 3 {
766 Ok(i)
767 } else {
768 Err(Break::Err("fail".to_string()))
769 }
770 },
771 tx,
772 |_| {},
773 )
774 .await;
775
776 assert!(matches!(result, Err(Break::Err(ref s)) if s == "fail"));
777 }
778
    /// Dropping the receiver before driving the stream makes every send fail,
    /// which the driver reports as `Break::Break`.
    #[tokio::test]
    async fn send_spawned_channel_closed() {
        let (tx, rx) = mpsc::channel(1);
        drop(rx);

        let result: Result<(), Break<()>> = stream::iter(0..10u64)
            .try_for_each_send_spawned(
                ConcurrencyConfig::fixed(1),
                |i| async move { Ok(i) },
                tx,
                |_| {},
            )
            .await;

        assert!(matches!(result, Err(Break::Break)));
    }
795
796 #[tokio::test]
797 async fn send_spawned_reports_stats() {
798 let reported: Arc<Mutex<Vec<ConcurrencyStats>>> = Arc::new(Mutex::new(Vec::new()));
799 let (tx, _rx) = mpsc::channel(100);
800
801 let reported2 = reported.clone();
802 let _ = stream::iter(0..5u64)
803 .try_for_each_send_spawned(
804 ConcurrencyConfig::fixed(2),
805 |i| async move { Ok::<_, Break<()>>(i) },
806 tx,
807 move |stats| {
808 reported2.lock().unwrap().push(stats);
809 },
810 )
811 .await;
812
813 let reports = reported.lock().unwrap();
814 for stats in reports.iter() {
815 assert_eq!(stats.limit, 2);
816 }
817 }
818
819 #[tokio::test]
820 async fn broadcast_spawned_basic() {
821 let (tx1, mut rx1) = mpsc::channel(100);
822 let (tx2, mut rx2) = mpsc::channel(100);
823 let txs = vec![tx1, tx2];
824
825 let result = stream::iter(0..5u64)
826 .try_for_each_broadcast_spawned(
827 ConcurrencyConfig::fixed(2),
828 |i| async move { Ok::<_, Break<()>>(i * 3) },
829 txs,
830 |_| {},
831 )
832 .await;
833
834 assert!(result.is_ok());
835
836 let mut v1 = Vec::new();
837 while let Ok(v) = rx1.try_recv() {
838 v1.push(v);
839 }
840 let mut v2 = Vec::new();
841 while let Ok(v) = rx2.try_recv() {
842 v2.push(v);
843 }
844 v1.sort();
845 v2.sort();
846 let expected: Vec<u64> = (0..5).map(|i| i * 3).collect();
847 assert_eq!(v1, expected);
848 assert_eq!(v2, expected);
849 }
850
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn send_spawned_adaptive_decreases_limit() {
        // A tiny channel (capacity 4) plus a slow consumer keeps the channel
        // near-full, which should push the adaptive limit below its initial
        // value of 10.
        let (tx, mut rx) = mpsc::channel(4);
        let limits: Arc<Mutex<Vec<usize>>> = Arc::new(Mutex::new(Vec::new()));

        let limits2 = limits.clone();
        let handle = tokio::spawn(async move {
            stream::iter(0..100u64)
                .try_for_each_send_spawned(
                    ConcurrencyConfig::adaptive(10, 1, 20),
                    |i| async move {
                        tokio::time::sleep(Duration::from_millis(5)).await;
                        Ok::<_, Break<()>>(i)
                    },
                    tx,
                    move |stats| {
                        limits2.lock().unwrap().push(stats.limit);
                    },
                )
                .await
        });

        // Drain slowly (one item per 20ms) so producers outpace the consumer.
        let mut received = Vec::new();
        loop {
            tokio::time::sleep(Duration::from_millis(20)).await;
            match rx.try_recv() {
                Ok(v) => received.push(v),
                Err(mpsc::error::TryRecvError::Empty) => {
                    if handle.is_finished() {
                        // Producer done: drain whatever is left, then stop.
                        while let Ok(v) = rx.try_recv() {
                            received.push(v);
                        }
                        break;
                    }
                }
                Err(mpsc::error::TryRecvError::Disconnected) => {
                    while let Ok(v) = rx.try_recv() {
                        received.push(v);
                    }
                    break;
                }
            }
        }

        handle.await.unwrap().unwrap();

        let limits = limits.lock().unwrap();
        let min_limit = limits.iter().copied().min().unwrap_or(10);
        assert!(
            min_limit < 10,
            "Limit should have decreased from initial=10, min observed: {min_limit}"
        );
    }
908
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn send_spawned_adaptive_recovers_after_decrease() {
        // Phase 1: drain slowly so the limit is forced down; phase 2: drain
        // at full speed and check the limit climbs back above the low-water
        // mark from phase 1.
        let (tx, mut rx) = mpsc::channel(4);
        let limits: Arc<Mutex<Vec<usize>>> = Arc::new(Mutex::new(Vec::new()));

        let limits2 = limits.clone();
        let handle = tokio::spawn(async move {
            stream::iter(0..200u64)
                .try_for_each_send_spawned(
                    ConcurrencyConfig::adaptive(10, 1, 20),
                    |i| async move {
                        tokio::time::sleep(Duration::from_millis(5)).await;
                        Ok::<_, Break<()>>(i)
                    },
                    tx,
                    move |stats| {
                        limits2.lock().unwrap().push(stats.limit);
                    },
                )
                .await
        });

        // Slow drain: one recv attempt per 20ms for ~1.2s.
        for _ in 0..60 {
            tokio::time::sleep(Duration::from_millis(20)).await;
            let _ = rx.try_recv();
        }

        let low_water = {
            let lims = limits.lock().unwrap();
            lims.iter().copied().min().unwrap_or(10)
        };
        assert!(
            low_water < 10,
            "Limit should have decreased, min={low_water}"
        );

        // Fast drain until the producer finishes.
        while (rx.recv().await).is_some() {
            if handle.is_finished() {
                while rx.try_recv().is_ok() {}
                break;
            }
        }

        handle.await.unwrap().unwrap();

        // Only the last 30 reports count as "after recovery".
        let limits = limits.lock().unwrap();
        let recovered_max = limits.iter().copied().rev().take(30).max().unwrap_or(0);
        assert!(
            recovered_max > low_water,
            "Limit should have recovered above {low_water}, best late limit: {recovered_max}"
        );
    }
964
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn send_spawned_adaptive_respects_min() {
        // Even under heavy backpressure (channel of 2, 50ms drain interval),
        // the limit must never be reported below the configured min of 5.
        let (tx, mut rx) = mpsc::channel(2);
        let limits: Arc<Mutex<Vec<usize>>> = Arc::new(Mutex::new(Vec::new()));

        let limits2 = limits.clone();
        let handle = tokio::spawn(async move {
            stream::iter(0..100u64)
                .try_for_each_send_spawned(
                    ConcurrencyConfig::adaptive(10, 5, 20),
                    |i| async move {
                        tokio::time::sleep(Duration::from_millis(10)).await;
                        Ok::<_, Break<()>>(i)
                    },
                    tx,
                    move |stats| {
                        limits2.lock().unwrap().push(stats.limit);
                    },
                )
                .await
        });

        // Very slow drain to maximize downward pressure on the limit.
        loop {
            tokio::time::sleep(Duration::from_millis(50)).await;
            match rx.try_recv() {
                Ok(_) => {}
                Err(mpsc::error::TryRecvError::Empty) => {
                    if handle.is_finished() {
                        while rx.try_recv().is_ok() {}
                        break;
                    }
                }
                Err(mpsc::error::TryRecvError::Disconnected) => {
                    while rx.try_recv().is_ok() {}
                    break;
                }
            }
        }

        handle.await.unwrap().unwrap();

        let limits = limits.lock().unwrap();
        let min_limit = limits.iter().copied().min().unwrap_or(10);
        assert!(
            min_limit >= 5,
            "Limit should never drop below min=5, observed: {min_limit}"
        );
    }
1014
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn send_spawned_adaptive_respects_max() {
        // A huge channel (no backpressure) lets the limit grow freely; it
        // must still be capped at the configured max of 8.
        let (tx, mut rx) = mpsc::channel(1000);
        let limits: Arc<Mutex<Vec<usize>>> = Arc::new(Mutex::new(Vec::new()));

        let limits2 = limits.clone();
        let handle = tokio::spawn(async move {
            stream::iter(0..200u64)
                .try_for_each_send_spawned(
                    ConcurrencyConfig::adaptive(2, 1, 8),
                    |i| async move {
                        tokio::time::sleep(Duration::from_millis(5)).await;
                        Ok::<_, Break<()>>(i)
                    },
                    tx,
                    move |stats| {
                        limits2.lock().unwrap().push(stats.limit);
                    },
                )
                .await
        });

        // Drain as fast as values arrive.
        while rx.recv().await.is_some() {}

        handle.await.unwrap().unwrap();

        let limits = limits.lock().unwrap();
        let max_limit = limits.iter().copied().max().unwrap_or(0);
        assert!(
            max_limit <= 8,
            "Limit should never exceed max=8, observed: {max_limit}"
        );
    }
1049
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn send_spawned_epoch_prevents_stampede() {
        // With 20 tasks all hitting a tiny channel at once, many completions
        // observe a full channel simultaneously. The epoch mechanism must
        // limit them to one decrease per epoch: no single step may cut the
        // limit below half (the driver's minimum keep factor is 0.5).
        let (tx, mut rx) = mpsc::channel(2);
        let limits: Arc<Mutex<Vec<usize>>> = Arc::new(Mutex::new(Vec::new()));

        let limits2 = limits.clone();
        let handle = tokio::spawn(async move {
            stream::iter(0..60u64)
                .try_for_each_send_spawned(
                    ConcurrencyConfig::adaptive(20, 1, 20),
                    |i| async move {
                        tokio::time::sleep(Duration::from_millis(10)).await;
                        Ok::<_, Break<()>>(i)
                    },
                    tx,
                    move |stats| {
                        limits2.lock().unwrap().push(stats.limit);
                    },
                )
                .await
        });

        // Let backpressure build before draining.
        tokio::time::sleep(Duration::from_millis(300)).await;

        while rx.recv().await.is_some() {}
        handle.await.unwrap().unwrap();

        let limits = limits.lock().unwrap();
        // Distinct limit values, used for diagnostics if the check fails.
        let transitions: Vec<usize> = limits
            .iter()
            .copied()
            .collect::<Vec<_>>()
            .windows(2)
            .filter_map(|w| if w[0] != w[1] { Some(w[1]) } else { None })
            .collect();

        for pair in limits.iter().copied().collect::<Vec<_>>().windows(2) {
            let (old, new) = (pair[0], pair[1]);
            if new < old {
                let min_allowed = ((old as f64) * 0.5).ceil() as usize;
                assert!(
                    new >= min_allowed,
                    "Stampede detected: limit dropped from {old} to {new}, \
                    minimum allowed single-step is {min_allowed}. Transitions: {transitions:?}"
                );
            }
        }
    }
1104
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn broadcast_spawned_slow_receiver_triggers_decrease() {
        // Broadcast fill is the max across channels, so a single slow
        // receiver (small channel, slow drain) should force the limit down
        // even though the other receiver keeps up.
        let (tx_fast, mut rx_fast) = mpsc::channel(100);
        let (tx_slow, mut rx_slow) = mpsc::channel(4);
        let txs = vec![tx_fast, tx_slow];
        let limits: Arc<Mutex<Vec<usize>>> = Arc::new(Mutex::new(Vec::new()));

        let limits2 = limits.clone();
        let handle = tokio::spawn(async move {
            stream::iter(0..100u64)
                .try_for_each_broadcast_spawned(
                    ConcurrencyConfig::adaptive(10, 1, 20),
                    |i| async move {
                        tokio::time::sleep(Duration::from_millis(5)).await;
                        Ok::<_, Break<()>>(i)
                    },
                    txs,
                    move |stats| {
                        limits2.lock().unwrap().push(stats.limit);
                    },
                )
                .await
        });

        // The fast receiver drains at full speed on its own task.
        let fast_drain = tokio::spawn(async move { while rx_fast.recv().await.is_some() {} });

        // The slow receiver drains one item per 20ms.
        loop {
            tokio::time::sleep(Duration::from_millis(20)).await;
            match rx_slow.try_recv() {
                Ok(_) => {}
                Err(mpsc::error::TryRecvError::Empty) => {
                    if handle.is_finished() {
                        while rx_slow.try_recv().is_ok() {}
                        break;
                    }
                }
                Err(mpsc::error::TryRecvError::Disconnected) => {
                    while rx_slow.try_recv().is_ok() {}
                    break;
                }
            }
        }

        handle.await.unwrap().unwrap();
        fast_drain.await.unwrap();

        let limits = limits.lock().unwrap();
        let min_limit = limits.iter().copied().min().unwrap_or(10);
        assert!(
            min_limit < 10,
            "Limit should have decreased due to slow receiver, min observed: {min_limit}"
        );
    }
1160
    /// If any one of the broadcast receivers is gone, the whole pipeline must
    /// stop with `Break::Break`.
    #[tokio::test]
    async fn broadcast_spawned_channel_closed() {
        let (tx1, _rx1) = mpsc::channel(100);
        let (tx2, rx2) = mpsc::channel(100);
        drop(rx2);

        let result: Result<(), Break<()>> = stream::iter(0..10u64)
            .try_for_each_broadcast_spawned(
                ConcurrencyConfig::fixed(2),
                |i| async move { Ok(i) },
                vec![tx1, tx2],
                |_| {},
            )
            .await;

        assert!(matches!(result, Err(Break::Break)));
    }
1178
1179 #[tokio::test]
1180 async fn fixed_concurrency_limit_never_changes() {
1181 let limits: Arc<Mutex<Vec<usize>>> = Arc::new(Mutex::new(Vec::new()));
1182 let (tx, mut rx) = mpsc::channel(2);
1183
1184 let limits2 = limits.clone();
1185 let handle = tokio::spawn(async move {
1186 stream::iter(0..20u64)
1187 .try_for_each_send_spawned(
1188 ConcurrencyConfig::fixed(5),
1189 |i| async move { Ok::<_, Break<()>>(i) },
1190 tx,
1191 move |stats| {
1192 limits2.lock().unwrap().push(stats.limit);
1193 },
1194 )
1195 .await
1196 });
1197
1198 while rx.recv().await.is_some() {}
1200
1201 handle.await.unwrap().unwrap();
1202
1203 let limits = limits.lock().unwrap();
1204 for &g in limits.iter() {
1205 assert_eq!(g, 5, "Fixed limit should never change");
1206 }
1207 }
1208}