1use std::{future::Future, time::Duration};
5
6use tokio::time::sleep;
7
8pub async fn with_slow_future_monitor<F, C>(
14 future: F,
15 threshold: Duration,
16 callback: C,
17) -> F::Output
18where
19 F: Future,
20 C: FnOnce(),
21{
22 tokio::pin!(future);
24
25 tokio::select! {
26 result = &mut future => {
27 return result;
29 }
30 _ = sleep(threshold) => {
31 callback();
33 }
34 }
35
36 future.await
39}
40
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
        time::Duration,
    };

    use tokio::time::{sleep, timeout};

    use super::*;

    /// Shared call counter that the monitor callback increments.
    ///
    /// `Relaxed` suffices: the count is only read after the monitored call has been awaited
    /// on the same task.
    struct Counter(Arc<AtomicUsize>);

    impl Counter {
        fn new() -> Self {
            Self(Arc::new(AtomicUsize::new(0)))
        }

        fn increment(&self) {
            self.0.fetch_add(1, Ordering::Relaxed);
        }

        fn count(&self) -> usize {
            self.0.load(Ordering::Relaxed)
        }
    }

    /// A future slower than the threshold fires the callback exactly once and still yields its
    /// output.
    #[tokio::test]
    async fn slow_monitor_callback_called_once_when_threshold_exceeded() {
        let c = Counter::new();

        let result = with_slow_future_monitor(
            async {
                sleep(Duration::from_millis(200)).await;
                42
            },
            Duration::from_millis(100),
            || c.increment(),
        )
        .await;

        assert_eq!(c.count(), 1);
        assert_eq!(result, 42);
    }

    /// A future faster than the threshold never fires the callback.
    #[tokio::test]
    async fn slow_monitor_callback_not_called_when_threshold_not_exceeded() {
        let c = Counter::new();

        let result = with_slow_future_monitor(
            async {
                sleep(Duration::from_millis(50)).await;
                42
            },
            Duration::from_millis(200),
            || c.increment(),
        )
        .await;

        assert_eq!(c.count(), 0);
        assert_eq!(result, 42);
    }

    /// An `Err` output from a slow future is returned unchanged after the callback fires.
    #[tokio::test]
    async fn slow_monitor_error_propagation() {
        let c = Counter::new();

        let result: Result<i32, &str> = with_slow_future_monitor(
            async {
                sleep(Duration::from_millis(150)).await;
                Err("Something went wrong")
            },
            Duration::from_millis(100),
            || c.increment(),
        )
        .await;

        assert_eq!(result, Err("Something went wrong"));
        assert_eq!(c.count(), 1);
    }

    /// An `Err` output from a fast future is returned unchanged without firing the callback.
    #[tokio::test]
    async fn slow_monitor_error_propagation_without_callback() {
        let c = Counter::new();

        let result: Result<i32, &str> = with_slow_future_monitor(
            async {
                sleep(Duration::from_millis(50)).await;
                Err("Quick error")
            },
            Duration::from_millis(200),
            || c.increment(),
        )
        .await;

        assert_eq!(result, Err("Quick error"));
        assert_eq!(c.count(), 0);
    }

    /// A future that never completes is still reported once; only an outer timeout ends it.
    #[tokio::test]
    async fn slow_monitor_stuck_future_detection() {
        let c = Counter::new();

        let monitored = with_slow_future_monitor(
            std::future::pending::<()>(),
            Duration::from_millis(200),
            || c.increment(),
        );

        // The monitor never cancels the inner future, so the outer timeout is what ends the
        // wait. `timeout` polls the monitored future before checking its own deadline, so the
        // 200ms threshold is observed even if the task is scheduled late; 500ms keeps the test
        // fast while leaving a wide margin.
        timeout(Duration::from_millis(500), monitored)
            .await
            .expect_err("a stuck future must not complete");
        assert_eq!(c.count(), 1);
    }
}