// mysten_common/sync/notify_read.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::ZipDebugEqIteratorExt;
5use crate::debug_fatal;
6
7use futures::future::{Either, join_all};
8use mysten_metrics::spawn_monitored_task;
9use parking_lot::Mutex;
10use parking_lot::MutexGuard;
11use std::collections::HashMap;
12use std::collections::HashSet;
13use std::collections::hash_map::DefaultHasher;
14use std::future::Future;
15use std::hash::{Hash, Hasher};
16use std::mem;
17use std::pin::Pin;
18use std::sync::Arc;
19use std::sync::atomic::AtomicUsize;
20use std::sync::atomic::Ordering;
21use std::task::{Context, Poll};
22use std::time::Duration;
23use tokio::sync::oneshot;
24use tokio::time::Instant;
25use tokio::time::interval_at;
26use tracing::warn;
27
/// The list of waiters (oneshot senders) registered for a single key.
type Registrations<V> = Vec<oneshot::Sender<V>>;
29
/// Wrapper that ensures a spawned task is aborted when dropped
struct TaskAbortOnDrop {
    // `Some` until Drop runs; the Option lets Drop take ownership of the
    // handle in order to call `abort()`.
    handle: Option<tokio::task::JoinHandle<()>>,
}
34
35impl TaskAbortOnDrop {
36    fn new(handle: tokio::task::JoinHandle<()>) -> Self {
37        Self {
38            handle: Some(handle),
39        }
40    }
41}
42
43impl Drop for TaskAbortOnDrop {
44    fn drop(&mut self) {
45        if let Some(handle) = self.handle.take() {
46            handle.abort();
47        }
48    }
49}
50
/// Interval duration for logging waiting keys when reads take too long
const LONG_WAIT_LOG_INTERVAL_SECS: u64 = 10;

/// Task name for which `NotifyRead::read` escalates a long wait (>= 60s)
/// to `debug_fatal!`, since a stuck checkpoint builder is a serious condition.
pub const CHECKPOINT_BUILDER_NOTIFY_READ_TASK_NAME: &str =
    "CheckpointBuilder::notify_read_executed_effects";
56
/// Async notification map: waiters register interest in a key and are woken
/// with the value when `notify` is called for that key.
pub struct NotifyRead<K, V> {
    // Key -> waiting senders, sharded across multiple mutexes (shard chosen by
    // key hash) to reduce lock contention.
    pending: Vec<Mutex<HashMap<K, Registrations<V>>>>,
    // Total number of outstanding registrations across all shards.
    count_pending: AtomicUsize,
}
61
62impl<K: Eq + Hash + Clone, V: Clone> NotifyRead<K, V> {
63    pub fn new() -> Self {
64        let pending = (0..255).map(|_| Default::default()).collect();
65        let count_pending = Default::default();
66        Self {
67            pending,
68            count_pending,
69        }
70    }
71
72    /// Asynchronously notifies waiters and return number of remaining pending registration
73    pub fn notify(&self, key: &K, value: &V) -> usize {
74        let registrations = self.pending(key).remove(key);
75        let Some(registrations) = registrations else {
76            return self.count_pending.load(Ordering::Relaxed);
77        };
78        let rem = self
79            .count_pending
80            .fetch_sub(registrations.len(), Ordering::Relaxed);
81        for registration in registrations {
82            registration.send(value.clone()).ok();
83        }
84        rem
85    }
86
87    pub fn register_one(&self, key: &K) -> Registration<'_, K, V> {
88        self.count_pending.fetch_add(1, Ordering::Relaxed);
89        let (sender, receiver) = oneshot::channel();
90        self.register(key, sender);
91        Registration {
92            this: self,
93            registration: Some((key.clone(), receiver)),
94        }
95    }
96
97    pub fn register_all(&self, keys: &[K]) -> Vec<Registration<'_, K, V>> {
98        self.count_pending.fetch_add(keys.len(), Ordering::Relaxed);
99        let mut registrations = vec![];
100        for key in keys.iter() {
101            let (sender, receiver) = oneshot::channel();
102            self.register(key, sender);
103            let registration = Registration {
104                this: self,
105                registration: Some((key.clone(), receiver)),
106            };
107            registrations.push(registration);
108        }
109        registrations
110    }
111
112    fn register(&self, key: &K, sender: oneshot::Sender<V>) {
113        self.pending(key)
114            .entry(key.clone())
115            .or_default()
116            .push(sender);
117    }
118
119    fn pending(&self, key: &K) -> MutexGuard<'_, HashMap<K, Registrations<V>>> {
120        let mut state = DefaultHasher::new();
121        key.hash(&mut state);
122        let hash = state.finish();
123        let pending = self
124            .pending
125            .get((hash % self.pending.len() as u64) as usize)
126            .unwrap();
127        pending.lock()
128    }
129
130    pub fn num_pending(&self) -> usize {
131        self.count_pending.load(Ordering::Relaxed)
132    }
133
134    fn cleanup(&self, key: &K) {
135        let mut pending = self.pending(key);
136        // it is possible that registration was fulfilled before we get here
137        let Some(registrations) = pending.get_mut(key) else {
138            return;
139        };
140        let mut count_deleted = 0usize;
141        registrations.retain(|s| {
142            let delete = s.is_closed();
143            if delete {
144                count_deleted += 1;
145            }
146            !delete
147        });
148        self.count_pending
149            .fetch_sub(count_deleted, Ordering::Relaxed);
150        if registrations.is_empty() {
151            pending.remove(key);
152        }
153    }
154}
155
impl<K: Eq + Hash + Clone + Unpin + std::fmt::Debug + Send + Sync + 'static, V: Clone + Unpin>
    NotifyRead<K, V>
{
    /// Reads the values for `keys`, combining a synchronous lookup with
    /// notification-based waiting.
    ///
    /// Registers for all `keys` first, then calls `fetch`; keys whose fetch
    /// result is `Some` resolve immediately, the rest resolve when `notify`
    /// is called for them. Registering *before* fetching ensures that a
    /// notification arriving between the two steps cannot be missed.
    ///
    /// While any key is still outstanding, a background task logs the waiting
    /// keys every `LONG_WAIT_LOG_INTERVAL_SECS` seconds; the task is aborted
    /// when this future completes or is dropped. For the checkpoint-builder
    /// task name, a wait of 60s or more triggers `debug_fatal!`.
    pub async fn read(
        &self,
        task_name: &'static str,
        keys: &[K],
        fetch: impl FnOnce(&[K]) -> Vec<Option<V>>,
    ) -> Vec<V> {
        let _metrics_scope = mysten_metrics::monitored_scope(task_name);
        // Register before fetching so no notify can be lost in between.
        let registrations = self.register_all(keys);

        let results = fetch(keys);

        // Track which keys are still waiting
        let waiting_keys: HashSet<K> = keys
            .iter()
            .zip_debug_eq(results.iter())
            .filter(|&(_key, result)| result.is_none())
            .map(|(key, _result)| key.clone())
            .collect();
        let has_waiting_keys = !waiting_keys.is_empty();
        // Shared with both the logging task and the per-key wait futures below.
        let waiting_keys = Arc::new(Mutex::new(waiting_keys));

        // Spawn logging task if there are waiting keys
        let _log_handle_guard = if has_waiting_keys {
            let waiting_keys_clone = waiting_keys.clone();
            let start_time = Instant::now();
            let task_name = task_name.to_string();

            let handle = spawn_monitored_task!(async move {
                // Only start logging after the first interval.
                let start = Instant::now() + Duration::from_secs(LONG_WAIT_LOG_INTERVAL_SECS);
                let mut interval =
                    interval_at(start, Duration::from_secs(LONG_WAIT_LOG_INTERVAL_SECS));

                loop {
                    interval.tick().await;
                    let current_waiting = waiting_keys_clone.lock();
                    if current_waiting.is_empty() {
                        // All keys resolved; stop logging.
                        break;
                    }
                    let keys_vec: Vec<_> = current_waiting.iter().cloned().collect();
                    drop(current_waiting); // Release lock before logging

                    let elapsed_secs = start_time.elapsed().as_secs();

                    warn!(
                        "[{}] Still waiting for {}s for {} keys: {:?}",
                        task_name,
                        elapsed_secs,
                        keys_vec.len(),
                        keys_vec
                    );

                    if task_name == CHECKPOINT_BUILDER_NOTIFY_READ_TASK_NAME && elapsed_secs >= 60 {
                        debug_fatal!("{} is stuck", task_name);
                    }
                }
            });
            // Guard aborts the logging task when `read` completes or is
            // cancelled, so it never outlives this call.
            Some(TaskAbortOnDrop::new(handle))
        } else {
            None
        };

        let results = results
            .into_iter()
            .zip_debug_eq(registrations)
            .zip_debug_eq(keys.iter())
            .map(|((a, r), key)| match a {
                // Note that Some() clause also drops registration that is already fulfilled
                Some(ready) => Either::Left(futures::future::ready(ready)),
                None => {
                    let waiting_keys = waiting_keys.clone();
                    let key = key.clone();
                    Either::Right(async move {
                        let result = r.await;
                        // Remove this key from the waiting set
                        waiting_keys.lock().remove(&key);
                        result
                    })
                }
            });

        // The logging task will be automatically aborted when _log_handle_guard is dropped

        join_all(results).await
    }
}
245
/// Registration resolves to the value but also provides safe cancellation
/// When Registration is dropped before it is resolved, we de-register from the pending list
pub struct Registration<'a, K: Eq + Hash + Clone, V: Clone> {
    // Back-reference used by Drop to de-register on cancellation.
    this: &'a NotifyRead<K, V>,
    // `Some((key, receiver))` while pending; taken once the future resolves
    // (or on drop), so Drop knows whether cleanup is still needed.
    registration: Option<(K, oneshot::Receiver<V>)>,
}
252
253impl<K: Eq + Hash + Clone + Unpin, V: Clone + Unpin> Future for Registration<'_, K, V> {
254    type Output = V;
255
256    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
257        let receiver = self
258            .registration
259            .as_mut()
260            .map(|(_key, receiver)| receiver)
261            .expect("poll can not be called after drop");
262        let poll = Pin::new(receiver).poll(cx);
263        if poll.is_ready() {
264            // When polling complete we no longer need to cancel
265            self.registration.take();
266        }
267        poll.map(|r| r.expect("Sender never drops when registration is pending"))
268    }
269}
270
271impl<K: Eq + Hash + Clone, V: Clone> Drop for Registration<'_, K, V> {
272    fn drop(&mut self) {
273        if let Some((key, receiver)) = self.registration.take() {
274            mem::drop(receiver);
275            // Receiver is dropped before cleanup
276            self.this.cleanup(&key)
277        }
278    }
279}
280impl<K: Eq + Hash + Clone, V: Clone> Default for NotifyRead<K, V> {
281    fn default() -> Self {
282        Self::new()
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use futures::future::join_all;
    use std::sync::Arc;
    use tokio::time::timeout;

    // Exercises register_all / notify / drop interaction and verifies both the
    // pending counter and the shard maps end up empty.
    #[tokio::test]
    pub async fn test_notify_read() {
        let notify_read = NotifyRead::<u64, u64>::new();
        let mut registrations = notify_read.register_all(&[1, 2, 3]);
        assert_eq!(3, notify_read.count_pending.load(Ordering::Relaxed));
        // Dropping a registration de-registers it (count drops from 3 to 2).
        registrations.pop();
        assert_eq!(2, notify_read.count_pending.load(Ordering::Relaxed));
        notify_read.notify(&2, &2);
        notify_read.notify(&1, &1);
        let reads = join_all(registrations).await;
        assert_eq!(0, notify_read.count_pending.load(Ordering::Relaxed));
        // Results come back in registration order, not notification order.
        assert_eq!(reads, vec![1, 2]);
        // ensure cleanup is done correctly
        for pending in &notify_read.pending {
            assert!(pending.lock().is_empty());
        }
    }

    // Verifies that cancelling `read` (dropping its future via timeout)
    // cleans up all registrations and shard entries.
    #[tokio::test]
    pub async fn test_notify_read_cancellation() {
        let notify_read = Arc::new(NotifyRead::<u64, u64>::new());

        // Start a read that will wait indefinitely
        let read_future = notify_read.read(
            "test_task",
            &[1, 2, 3],
            |_keys| vec![None, None, None], // All keys will wait
        );

        // Use timeout to cancel the read after a short duration
        let result = timeout(Duration::from_millis(100), read_future).await;

        // Verify the read was cancelled
        assert!(result.is_err());

        // Give some time for cleanup to complete
        tokio::time::sleep(Duration::from_millis(50)).await;

        // When the read is cancelled, the registrations are cleaned up
        // so the pending count should be 0
        assert_eq!(0, notify_read.count_pending.load(Ordering::Relaxed));

        // Verify all pending maps are empty (cleanup was performed)
        for pending in &notify_read.pending {
            assert!(pending.lock().is_empty());
        }
    }
}