mysten_common/sync/notify_read.rs

// Copyright (c) Mysten Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

use crate::debug_fatal;

use futures::future::{Either, join_all};
use mysten_metrics::spawn_monitored_task;
use parking_lot::Mutex;
use parking_lot::MutexGuard;
use std::collections::HashMap;
use std::collections::HashSet;
use std::collections::hash_map::DefaultHasher;
use std::future::Future;
use std::hash::{Hash, Hasher};
use std::mem;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::sync::oneshot;
use tokio::time::Instant;
use tokio::time::interval_at;
use tracing::warn;

type Registrations<V> = Vec<oneshot::Sender<V>>;

/// Wrapper that ensures a spawned task is aborted when dropped
struct TaskAbortOnDrop {
    handle: Option<tokio::task::JoinHandle<()>>,
}

impl TaskAbortOnDrop {
    fn new(handle: tokio::task::JoinHandle<()>) -> Self {
        Self {
            handle: Some(handle),
        }
    }
}

impl Drop for TaskAbortOnDrop {
    fn drop(&mut self) {
        if let Some(handle) = self.handle.take() {
            handle.abort();
        }
    }
}

/// Interval duration for logging waiting keys when reads take too long
const LONG_WAIT_LOG_INTERVAL_SECS: u64 = 10;

pub const CHECKPOINT_BUILDER_NOTIFY_READ_TASK_NAME: &str =
    "CheckpointBuilder::notify_read_executed_effects";

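/// Allows readers to wait for a value to become available for a key. Readers register
/// interest via `register_one` / `register_all` and are woken when `notify` is called with
/// that key. Pending registrations are sharded across several mutex-protected maps to
/// reduce lock contention.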
pub struct NotifyRead<K, V> {
    pending: Vec<Mutex<HashMap<K, Registrations<V>>>>,
    count_pending: AtomicUsize,
}

impl<K: Eq + Hash + Clone, V: Clone> NotifyRead<K, V> {
    pub fn new() -> Self {
        let pending = (0..255).map(|_| Default::default()).collect();
        let count_pending = Default::default();
        Self {
            pending,
            count_pending,
        }
    }

    /// Asynchronously notifies waiters and returns the number of remaining pending registrations.
    pub fn notify(&self, key: &K, value: &V) -> usize {
        let registrations = self.pending(key).remove(key);
        let Some(registrations) = registrations else {
            return self.count_pending.load(Ordering::Relaxed);
        };
        let rem = self
            .count_pending
            .fetch_sub(registrations.len(), Ordering::Relaxed);
        for registration in registrations {
            registration.send(value.clone()).ok();
        }
        rem
    }

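    /// Registers a single waiter for `key`. The returned `Registration` resolves to the
    /// value passed to a subsequent `notify` for the same key.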
    pub fn register_one(&self, key: &K) -> Registration<'_, K, V> {
        self.count_pending.fetch_add(1, Ordering::Relaxed);
        let (sender, receiver) = oneshot::channel();
        self.register(key, sender);
        Registration {
            this: self,
            registration: Some((key.clone(), receiver)),
        }
    }

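    /// Registers a waiter for every key in `keys`, returning one `Registration` per key in
    /// the same order.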
    pub fn register_all(&self, keys: &[K]) -> Vec<Registration<'_, K, V>> {
        self.count_pending.fetch_add(keys.len(), Ordering::Relaxed);
        let mut registrations = vec![];
        for key in keys.iter() {
            let (sender, receiver) = oneshot::channel();
            self.register(key, sender);
            let registration = Registration {
                this: self,
                registration: Some((key.clone(), receiver)),
            };
            registrations.push(registration);
        }
        registrations
    }

    fn register(&self, key: &K, sender: oneshot::Sender<V>) {
        self.pending(key)
            .entry(key.clone())
            .or_default()
            .push(sender);
    }

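    /// Returns the locked shard of the pending map that `key` hashes into.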
    fn pending(&self, key: &K) -> MutexGuard<'_, HashMap<K, Registrations<V>>> {
        let mut state = DefaultHasher::new();
        key.hash(&mut state);
        let hash = state.finish();
        let pending = self
            .pending
            .get((hash % self.pending.len() as u64) as usize)
            .unwrap();
        pending.lock()
    }

    pub fn num_pending(&self) -> usize {
        self.count_pending.load(Ordering::Relaxed)
    }

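    /// Drops registrations for `key` whose receivers have been closed, and removes the map
    /// entry entirely once no registrations remain.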
    fn cleanup(&self, key: &K) {
        let mut pending = self.pending(key);
        // it is possible that registration was fulfilled before we get here
        let Some(registrations) = pending.get_mut(key) else {
            return;
        };
        let mut count_deleted = 0usize;
        registrations.retain(|s| {
            let delete = s.is_closed();
            if delete {
                count_deleted += 1;
            }
            !delete
        });
        self.count_pending
            .fetch_sub(count_deleted, Ordering::Relaxed);
        if registrations.is_empty() {
            pending.remove(key);
        }
    }
}

impl<K: Eq + Hash + Clone + Unpin + std::fmt::Debug + Send + Sync + 'static, V: Clone + Unpin>
    NotifyRead<K, V>
{
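    /// Reads the values for `keys`, returning them in the same order. Values already
    /// available from `fetch` are returned directly; for the remaining keys, this waits
    /// until the corresponding `notify` arrives. While waiting, a background task logs the
    /// outstanding keys every `LONG_WAIT_LOG_INTERVAL_SECS` seconds.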
    pub async fn read(
        &self,
        task_name: &'static str,
        keys: &[K],
        fetch: impl FnOnce(&[K]) -> Vec<Option<V>>,
    ) -> Vec<V> {
        let _metrics_scope = mysten_metrics::monitored_scope(task_name);
        let registrations = self.register_all(keys);

        let results = fetch(keys);

        // Track which keys are still waiting
        let waiting_keys: HashSet<K> = keys
            .iter()
            .zip(results.iter())
            .filter(|&(_key, result)| result.is_none())
            .map(|(key, _result)| key.clone())
            .collect();
        let has_waiting_keys = !waiting_keys.is_empty();
        let waiting_keys = Arc::new(Mutex::new(waiting_keys));

        // Spawn logging task if there are waiting keys
        let _log_handle_guard = if has_waiting_keys {
            let waiting_keys_clone = waiting_keys.clone();
            let start_time = Instant::now();
            let task_name = task_name.to_string();

            let handle = spawn_monitored_task!(async move {
                // Only start logging after the first interval.
                let start = Instant::now() + Duration::from_secs(LONG_WAIT_LOG_INTERVAL_SECS);
                let mut interval =
                    interval_at(start, Duration::from_secs(LONG_WAIT_LOG_INTERVAL_SECS));

                loop {
                    interval.tick().await;
                    let current_waiting = waiting_keys_clone.lock();
                    if current_waiting.is_empty() {
                        break;
                    }
                    let keys_vec: Vec<_> = current_waiting.iter().cloned().collect();
                    drop(current_waiting); // Release lock before logging

                    let elapsed_secs = start_time.elapsed().as_secs();

                    warn!(
                        "[{}] Still waiting for {}s for {} keys: {:?}",
                        task_name,
                        elapsed_secs,
                        keys_vec.len(),
                        keys_vec
                    );

                    if task_name == CHECKPOINT_BUILDER_NOTIFY_READ_TASK_NAME && elapsed_secs >= 60 {
                        debug_fatal!("{} is stuck", task_name);
                    }
                }
            });
            Some(TaskAbortOnDrop::new(handle))
        } else {
            None
        };

        let results =
            results
                .into_iter()
                .zip(registrations)
                .zip(keys.iter())
                .map(|((a, r), key)| match a {
                    // Note that the Some(_) arm also drops a registration that is already fulfilled
                    Some(ready) => Either::Left(futures::future::ready(ready)),
                    None => {
                        let waiting_keys = waiting_keys.clone();
                        let key = key.clone();
                        Either::Right(async move {
                            let result = r.await;
                            // Remove this key from the waiting set
                            waiting_keys.lock().remove(&key);
                            result
                        })
                    }
                });

        // The logging task will be automatically aborted when _log_handle_guard is dropped

        join_all(results).await
    }
}

/// A Registration resolves to the value, but also provides safe cancellation:
/// when a Registration is dropped before it resolves, it is de-registered from the pending list.
pub struct Registration<'a, K: Eq + Hash + Clone, V: Clone> {
    this: &'a NotifyRead<K, V>,
    registration: Option<(K, oneshot::Receiver<V>)>,
}

impl<K: Eq + Hash + Clone + Unpin, V: Clone + Unpin> Future for Registration<'_, K, V> {
    type Output = V;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let receiver = self
            .registration
            .as_mut()
            .map(|(_key, receiver)| receiver)
            .expect("poll can not be called after drop");
        let poll = Pin::new(receiver).poll(cx);
        if poll.is_ready() {
            // Once polling completes, we no longer need to cancel
            self.registration.take();
        }
        poll.map(|r| r.expect("Sender never drops when registration is pending"))
    }
}

impl<K: Eq + Hash + Clone, V: Clone> Drop for Registration<'_, K, V> {
    fn drop(&mut self) {
        if let Some((key, receiver)) = self.registration.take() {
            mem::drop(receiver);
            // Receiver is dropped before cleanup
            self.this.cleanup(&key)
        }
    }
}

impl<K: Eq + Hash + Clone, V: Clone> Default for NotifyRead<K, V> {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use futures::future::join_all;
    use std::sync::Arc;
    use tokio::time::timeout;

    #[tokio::test]
    pub async fn test_notify_read() {
        let notify_read = NotifyRead::<u64, u64>::new();
        let mut registrations = notify_read.register_all(&[1, 2, 3]);
        assert_eq!(3, notify_read.count_pending.load(Ordering::Relaxed));
        registrations.pop();
        assert_eq!(2, notify_read.count_pending.load(Ordering::Relaxed));
        notify_read.notify(&2, &2);
        notify_read.notify(&1, &1);
        let reads = join_all(registrations).await;
        assert_eq!(0, notify_read.count_pending.load(Ordering::Relaxed));
        assert_eq!(reads, vec![1, 2]);
        // ensure cleanup is done correctly
        for pending in &notify_read.pending {
            assert!(pending.lock().is_empty());
        }
    }
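
    // Illustrative sketch of the `read` happy path: one key is already available from
    // `fetch`, the other only becomes available through a later `notify`.
    #[tokio::test]
    pub async fn test_notify_read_mixed_ready_and_pending() {
        let notify_read = Arc::new(NotifyRead::<u64, u64>::new());
        let notifier = notify_read.clone();

        // Key 1 resolves immediately via `fetch`; key 2 has no value yet and must wait.
        let read_future = notify_read.read("test_task", &[1, 2], |_keys| vec![Some(10), None]);

        let notify_future = async move {
            // Give the read future time to register before delivering the value for key 2.
            tokio::time::sleep(Duration::from_millis(50)).await;
            notifier.notify(&2, &20);
        };

        let (values, _) = tokio::join!(read_future, notify_future);
        assert_eq!(values, vec![10, 20]);
    }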

    #[tokio::test]
    pub async fn test_notify_read_cancellation() {
        let notify_read = Arc::new(NotifyRead::<u64, u64>::new());

        // Start a read that will wait indefinitely
        let read_future = notify_read.read(
            "test_task",
            &[1, 2, 3],
            |_keys| vec![None, None, None], // All keys will wait
        );

        // Use timeout to cancel the read after a short duration
        let result = timeout(Duration::from_millis(100), read_future).await;

        // Verify the read was cancelled
        assert!(result.is_err());

        // Give some time for cleanup to complete
        tokio::time::sleep(Duration::from_millis(50)).await;

        // When the read is cancelled, the registrations are cleaned up
        // so the pending count should be 0
        assert_eq!(0, notify_read.count_pending.load(Ordering::Relaxed));

        // Verify all pending maps are empty (cleanup was performed)
        for pending in &notify_read.pending {
            assert!(pending.lock().is_empty());
        }
    }
}