1use crate::ZipDebugEqIteratorExt;
5use crate::debug_fatal;
6
7use futures::future::{Either, join_all};
8use mysten_metrics::spawn_monitored_task;
9use parking_lot::Mutex;
10use parking_lot::MutexGuard;
11use std::collections::HashMap;
12use std::collections::HashSet;
13use std::collections::hash_map::DefaultHasher;
14use std::future::Future;
15use std::hash::{Hash, Hasher};
16use std::mem;
17use std::pin::Pin;
18use std::sync::Arc;
19use std::sync::atomic::AtomicUsize;
20use std::sync::atomic::Ordering;
21use std::task::{Context, Poll};
22use std::time::Duration;
23use tokio::sync::oneshot;
24use tokio::time::Instant;
25use tokio::time::interval_at;
26use tracing::warn;
27
/// All oneshot senders currently registered (waiting) for a single key.
type Registrations<V> = Vec<oneshot::Sender<V>>;
29
/// Guard that aborts the wrapped tokio task when it is dropped.
struct TaskAbortOnDrop {
    // `Option` so `Drop` can take ownership of the handle exactly once.
    handle: Option<tokio::task::JoinHandle<()>>,
}
34
35impl TaskAbortOnDrop {
36 fn new(handle: tokio::task::JoinHandle<()>) -> Self {
37 Self {
38 handle: Some(handle),
39 }
40 }
41}
42
43impl Drop for TaskAbortOnDrop {
44 fn drop(&mut self) {
45 if let Some(handle) = self.handle.take() {
46 handle.abort();
47 }
48 }
49}
50
/// How often (in seconds) the watchdog task logs keys that are still pending.
const LONG_WAIT_LOG_INTERVAL_SECS: u64 = 10;

/// Task name used by the checkpoint builder when waiting on executed effects;
/// waits under this name that exceed 60s escalate to `debug_fatal!` (see `read`).
pub const CHECKPOINT_BUILDER_NOTIFY_READ_TASK_NAME: &str =
    "CheckpointBuilder::notify_read_executed_effects";
56
/// Registry that lets readers block until a value for a key is produced.
///
/// Writers call `notify` to wake every reader registered for a key; readers
/// use `register_one`/`register_all`/`read`. Pending registrations are
/// sharded across many mutexes to reduce lock contention.
pub struct NotifyRead<K, V> {
    /// Mutex-sharded maps from key to the oneshot senders of waiting readers.
    pending: Vec<Mutex<HashMap<K, Registrations<V>>>>,
    /// Total number of outstanding registrations across all shards.
    count_pending: AtomicUsize,
}
61
62impl<K: Eq + Hash + Clone, V: Clone> NotifyRead<K, V> {
63 pub fn new() -> Self {
64 let pending = (0..255).map(|_| Default::default()).collect();
65 let count_pending = Default::default();
66 Self {
67 pending,
68 count_pending,
69 }
70 }
71
72 pub fn notify(&self, key: &K, value: &V) -> usize {
74 let registrations = self.pending(key).remove(key);
75 let Some(registrations) = registrations else {
76 return self.count_pending.load(Ordering::Relaxed);
77 };
78 let rem = self
79 .count_pending
80 .fetch_sub(registrations.len(), Ordering::Relaxed);
81 for registration in registrations {
82 registration.send(value.clone()).ok();
83 }
84 rem
85 }
86
87 pub fn register_one(&self, key: &K) -> Registration<'_, K, V> {
88 self.count_pending.fetch_add(1, Ordering::Relaxed);
89 let (sender, receiver) = oneshot::channel();
90 self.register(key, sender);
91 Registration {
92 this: self,
93 registration: Some((key.clone(), receiver)),
94 }
95 }
96
97 pub fn register_all(&self, keys: &[K]) -> Vec<Registration<'_, K, V>> {
98 self.count_pending.fetch_add(keys.len(), Ordering::Relaxed);
99 let mut registrations = vec![];
100 for key in keys.iter() {
101 let (sender, receiver) = oneshot::channel();
102 self.register(key, sender);
103 let registration = Registration {
104 this: self,
105 registration: Some((key.clone(), receiver)),
106 };
107 registrations.push(registration);
108 }
109 registrations
110 }
111
112 fn register(&self, key: &K, sender: oneshot::Sender<V>) {
113 self.pending(key)
114 .entry(key.clone())
115 .or_default()
116 .push(sender);
117 }
118
119 fn pending(&self, key: &K) -> MutexGuard<'_, HashMap<K, Registrations<V>>> {
120 let mut state = DefaultHasher::new();
121 key.hash(&mut state);
122 let hash = state.finish();
123 let pending = self
124 .pending
125 .get((hash % self.pending.len() as u64) as usize)
126 .unwrap();
127 pending.lock()
128 }
129
130 pub fn num_pending(&self) -> usize {
131 self.count_pending.load(Ordering::Relaxed)
132 }
133
134 fn cleanup(&self, key: &K) {
135 let mut pending = self.pending(key);
136 let Some(registrations) = pending.get_mut(key) else {
138 return;
139 };
140 let mut count_deleted = 0usize;
141 registrations.retain(|s| {
142 let delete = s.is_closed();
143 if delete {
144 count_deleted += 1;
145 }
146 !delete
147 });
148 self.count_pending
149 .fetch_sub(count_deleted, Ordering::Relaxed);
150 if registrations.is_empty() {
151 pending.remove(key);
152 }
153 }
154}
155
impl<K: Eq + Hash + Clone + Unpin + std::fmt::Debug + Send + Sync + 'static, V: Clone + Unpin>
    NotifyRead<K, V>
{
    /// Reads values for `keys`: first via `fetch`, then by waiting for
    /// `notify` on any key `fetch` could not resolve. Returns values in the
    /// same order as `keys`.
    ///
    /// Registration happens BEFORE `fetch` so a `notify` arriving between the
    /// two steps cannot be missed.
    ///
    /// While any key is outstanding, a background watchdog task warns every
    /// `LONG_WAIT_LOG_INTERVAL_SECS` seconds; for the checkpoint-builder task
    /// name it escalates to `debug_fatal!` after 60s. The watchdog is aborted
    /// (via `TaskAbortOnDrop`) when this future completes or is dropped.
    pub async fn read(
        &self,
        task_name: &'static str,
        keys: &[K],
        fetch: impl FnOnce(&[K]) -> Vec<Option<V>>,
    ) -> Vec<V> {
        let _metrics_scope = mysten_metrics::monitored_scope(task_name);
        // Register first — see the race note in the doc comment above.
        let registrations = self.register_all(keys);

        let results = fetch(keys);

        // Keys whose value was not found by `fetch`; these must arrive via
        // `notify`. The set is shared with the watchdog task below.
        let waiting_keys: HashSet<K> = keys
            .iter()
            .zip_debug_eq(results.iter())
            .filter(|&(_key, result)| result.is_none())
            .map(|(key, _result)| key.clone())
            .collect();
        let has_waiting_keys = !waiting_keys.is_empty();
        let waiting_keys = Arc::new(Mutex::new(waiting_keys));

        // Spawn the long-wait watchdog only when something is outstanding;
        // the returned guard aborts it when dropped.
        let _log_handle_guard = if has_waiting_keys {
            let waiting_keys_clone = waiting_keys.clone();
            let start_time = Instant::now();
            let task_name = task_name.to_string();

            let handle = spawn_monitored_task!(async move {
                // First tick fires after one full interval, not immediately.
                let start = Instant::now() + Duration::from_secs(LONG_WAIT_LOG_INTERVAL_SECS);
                let mut interval =
                    interval_at(start, Duration::from_secs(LONG_WAIT_LOG_INTERVAL_SECS));

                loop {
                    interval.tick().await;
                    let current_waiting = waiting_keys_clone.lock();
                    if current_waiting.is_empty() {
                        break;
                    }
                    let keys_vec: Vec<_> = current_waiting.iter().cloned().collect();
                    // Release the shared lock before the (potentially slow)
                    // logging below so readers are not blocked on it.
                    drop(current_waiting); let elapsed_secs = start_time.elapsed().as_secs();

                    warn!(
                        "[{}] Still waiting for {}s for {} keys: {:?}",
                        task_name,
                        elapsed_secs,
                        keys_vec.len(),
                        keys_vec
                    );

                    if task_name == CHECKPOINT_BUILDER_NOTIFY_READ_TASK_NAME && elapsed_secs >= 60 {
                        debug_fatal!("{} is stuck", task_name);
                    }
                }
            });
            Some(TaskAbortOnDrop::new(handle))
        } else {
            None
        };

        // For each key: use the fetched value when present; otherwise await
        // its registration and remove the key from the watchdog's set once
        // the notification arrives.
        let results = results
            .into_iter()
            .zip_debug_eq(registrations)
            .zip_debug_eq(keys.iter())
            .map(|((a, r), key)| match a {
                Some(ready) => Either::Left(futures::future::ready(ready)),
                None => {
                    let waiting_keys = waiting_keys.clone();
                    let key = key.clone();
                    Either::Right(async move {
                        let result = r.await;
                        waiting_keys.lock().remove(&key);
                        result
                    })
                }
            });

        join_all(results).await
    }
}
245
/// Future returned by `NotifyRead::register_*`; resolves when its key is
/// notified. Dropping it before completion deregisters the waiter.
pub struct Registration<'a, K: Eq + Hash + Clone, V: Clone> {
    /// Parent registry; used for cleanup when dropped while still pending.
    this: &'a NotifyRead<K, V>,
    /// Key plus the receiving half of the oneshot; `None` once completed.
    registration: Option<(K, oneshot::Receiver<V>)>,
}
252
253impl<K: Eq + Hash + Clone + Unpin, V: Clone + Unpin> Future for Registration<'_, K, V> {
254 type Output = V;
255
256 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
257 let receiver = self
258 .registration
259 .as_mut()
260 .map(|(_key, receiver)| receiver)
261 .expect("poll can not be called after drop");
262 let poll = Pin::new(receiver).poll(cx);
263 if poll.is_ready() {
264 self.registration.take();
266 }
267 poll.map(|r| r.expect("Sender never drops when registration is pending"))
268 }
269}
270
271impl<K: Eq + Hash + Clone, V: Clone> Drop for Registration<'_, K, V> {
272 fn drop(&mut self) {
273 if let Some((key, receiver)) = self.registration.take() {
274 mem::drop(receiver);
275 self.this.cleanup(&key)
277 }
278 }
279}
impl<K: Eq + Hash + Clone, V: Clone> Default for NotifyRead<K, V> {
    /// Equivalent to `NotifyRead::new()`.
    fn default() -> Self {
        Self::new()
    }
}
285
#[cfg(test)]
mod tests {
    use super::*;
    use futures::future::join_all;
    use std::sync::Arc;
    use tokio::time::timeout;

    /// Notified registrations resolve in key order; dropped registrations
    /// decrement the pending count; shards end up empty.
    #[tokio::test]
    pub async fn test_notify_read() {
        let notify_read = NotifyRead::<u64, u64>::new();
        let mut registrations = notify_read.register_all(&[1, 2, 3]);
        assert_eq!(3, notify_read.count_pending.load(Ordering::Relaxed));
        // Dropping a registration deregisters it and decrements the count.
        registrations.pop();
        assert_eq!(2, notify_read.count_pending.load(Ordering::Relaxed));
        notify_read.notify(&2, &2);
        notify_read.notify(&1, &1);
        let reads = join_all(registrations).await;
        assert_eq!(0, notify_read.count_pending.load(Ordering::Relaxed));
        assert_eq!(reads, vec![1, 2]);
        // All shards must be fully drained after notification.
        // (Was garbled as `¬ify_read` — `&notify_read` mis-decoded.)
        for pending in &notify_read.pending {
            assert!(pending.lock().is_empty());
        }
    }

    /// Cancelling `read` (via timeout) must leave no registrations behind.
    #[tokio::test]
    pub async fn test_notify_read_cancellation() {
        let notify_read = Arc::new(NotifyRead::<u64, u64>::new());

        // All keys miss the fetch, so `read` blocks waiting on notifications.
        let read_future =
            notify_read.read("test_task", &[1, 2, 3], |_keys| vec![None, None, None]);

        let result = timeout(Duration::from_millis(100), read_future).await;

        assert!(result.is_err());

        tokio::time::sleep(Duration::from_millis(50)).await;

        // Dropping the read future must have cleaned up every registration.
        assert_eq!(0, notify_read.count_pending.load(Ordering::Relaxed));

        for pending in &notify_read.pending {
            assert!(pending.lock().is_empty());
        }
    }
}