1use crate::debug_fatal;
5
6use futures::future::{Either, join_all};
7use mysten_metrics::spawn_monitored_task;
8use parking_lot::Mutex;
9use parking_lot::MutexGuard;
10use std::collections::HashMap;
11use std::collections::HashSet;
12use std::collections::hash_map::DefaultHasher;
13use std::future::Future;
14use std::hash::{Hash, Hasher};
15use std::mem;
16use std::pin::Pin;
17use std::sync::Arc;
18use std::sync::atomic::AtomicUsize;
19use std::sync::atomic::Ordering;
20use std::task::{Context, Poll};
21use std::time::Duration;
22use tokio::sync::oneshot;
23use tokio::time::Instant;
24use tokio::time::interval_at;
25use tracing::warn;
26
/// The sending halves of every waiter currently registered on one key; the
/// matching `oneshot::Receiver` is held by the waiter's `Registration`.
type Registrations<V> = Vec<oneshot::Sender<V>>;
28
29struct TaskAbortOnDrop {
31 handle: Option<tokio::task::JoinHandle<()>>,
32}
33
34impl TaskAbortOnDrop {
35 fn new(handle: tokio::task::JoinHandle<()>) -> Self {
36 Self {
37 handle: Some(handle),
38 }
39 }
40}
41
42impl Drop for TaskAbortOnDrop {
43 fn drop(&mut self) {
44 if let Some(handle) = self.handle.take() {
45 handle.abort();
46 }
47 }
48}
49
/// Seconds between the "still waiting" warnings emitted by `NotifyRead::read`;
/// the same value is used as the delay before the first warning.
const LONG_WAIT_LOG_INTERVAL_SECS: u64 = 10;

/// Task name for which `NotifyRead::read` additionally raises `debug_fatal!`
/// once the wait has lasted 60 seconds or more.
pub const CHECKPOINT_BUILDER_NOTIFY_READ_TASK_NAME: &str =
    "CheckpointBuilder::notify_read_executed_effects";
55
/// Lets callers register interest in a key and be woken with a value once
/// `notify` is called for that key.
pub struct NotifyRead<K, V> {
    /// Shards of pending registrations. A key's shard is chosen from its hash,
    /// and every shard sits behind its own mutex.
    pending: Vec<Mutex<HashMap<K, Registrations<V>>>>,
    /// Running count of registrations that were added and have not yet been
    /// notified or cleaned up.
    count_pending: AtomicUsize,
}
60
61impl<K: Eq + Hash + Clone, V: Clone> NotifyRead<K, V> {
62 pub fn new() -> Self {
63 let pending = (0..255).map(|_| Default::default()).collect();
64 let count_pending = Default::default();
65 Self {
66 pending,
67 count_pending,
68 }
69 }
70
71 pub fn notify(&self, key: &K, value: &V) -> usize {
73 let registrations = self.pending(key).remove(key);
74 let Some(registrations) = registrations else {
75 return self.count_pending.load(Ordering::Relaxed);
76 };
77 let rem = self
78 .count_pending
79 .fetch_sub(registrations.len(), Ordering::Relaxed);
80 for registration in registrations {
81 registration.send(value.clone()).ok();
82 }
83 rem
84 }
85
86 pub fn register_one(&self, key: &K) -> Registration<'_, K, V> {
87 self.count_pending.fetch_add(1, Ordering::Relaxed);
88 let (sender, receiver) = oneshot::channel();
89 self.register(key, sender);
90 Registration {
91 this: self,
92 registration: Some((key.clone(), receiver)),
93 }
94 }
95
96 pub fn register_all(&self, keys: &[K]) -> Vec<Registration<'_, K, V>> {
97 self.count_pending.fetch_add(keys.len(), Ordering::Relaxed);
98 let mut registrations = vec![];
99 for key in keys.iter() {
100 let (sender, receiver) = oneshot::channel();
101 self.register(key, sender);
102 let registration = Registration {
103 this: self,
104 registration: Some((key.clone(), receiver)),
105 };
106 registrations.push(registration);
107 }
108 registrations
109 }
110
111 fn register(&self, key: &K, sender: oneshot::Sender<V>) {
112 self.pending(key)
113 .entry(key.clone())
114 .or_default()
115 .push(sender);
116 }
117
118 fn pending(&self, key: &K) -> MutexGuard<'_, HashMap<K, Registrations<V>>> {
119 let mut state = DefaultHasher::new();
120 key.hash(&mut state);
121 let hash = state.finish();
122 let pending = self
123 .pending
124 .get((hash % self.pending.len() as u64) as usize)
125 .unwrap();
126 pending.lock()
127 }
128
129 pub fn num_pending(&self) -> usize {
130 self.count_pending.load(Ordering::Relaxed)
131 }
132
133 fn cleanup(&self, key: &K) {
134 let mut pending = self.pending(key);
135 let Some(registrations) = pending.get_mut(key) else {
137 return;
138 };
139 let mut count_deleted = 0usize;
140 registrations.retain(|s| {
141 let delete = s.is_closed();
142 if delete {
143 count_deleted += 1;
144 }
145 !delete
146 });
147 self.count_pending
148 .fetch_sub(count_deleted, Ordering::Relaxed);
149 if registrations.is_empty() {
150 pending.remove(key);
151 }
152 }
153}
154
155impl<K: Eq + Hash + Clone + Unpin + std::fmt::Debug + Send + Sync + 'static, V: Clone + Unpin>
156 NotifyRead<K, V>
157{
158 pub async fn read(
159 &self,
160 task_name: &'static str,
161 keys: &[K],
162 fetch: impl FnOnce(&[K]) -> Vec<Option<V>>,
163 ) -> Vec<V> {
164 let _metrics_scope = mysten_metrics::monitored_scope(task_name);
165 let registrations = self.register_all(keys);
166
167 let results = fetch(keys);
168
169 let waiting_keys: HashSet<K> = keys
171 .iter()
172 .zip(results.iter())
173 .filter(|&(_key, result)| result.is_none())
174 .map(|(key, _result)| key.clone())
175 .collect();
176 let has_waiting_keys = !waiting_keys.is_empty();
177 let waiting_keys = Arc::new(Mutex::new(waiting_keys));
178
179 let _log_handle_guard = if has_waiting_keys {
181 let waiting_keys_clone = waiting_keys.clone();
182 let start_time = Instant::now();
183 let task_name = task_name.to_string();
184
185 let handle = spawn_monitored_task!(async move {
186 let start = Instant::now() + Duration::from_secs(LONG_WAIT_LOG_INTERVAL_SECS);
188 let mut interval =
189 interval_at(start, Duration::from_secs(LONG_WAIT_LOG_INTERVAL_SECS));
190
191 loop {
192 interval.tick().await;
193 let current_waiting = waiting_keys_clone.lock();
194 if current_waiting.is_empty() {
195 break;
196 }
197 let keys_vec: Vec<_> = current_waiting.iter().cloned().collect();
198 drop(current_waiting); let elapsed_secs = start_time.elapsed().as_secs();
201
202 warn!(
203 "[{}] Still waiting for {}s for {} keys: {:?}",
204 task_name,
205 elapsed_secs,
206 keys_vec.len(),
207 keys_vec
208 );
209
210 if task_name == CHECKPOINT_BUILDER_NOTIFY_READ_TASK_NAME && elapsed_secs >= 60 {
211 debug_fatal!("{} is stuck", task_name);
212 }
213 }
214 });
215 Some(TaskAbortOnDrop::new(handle))
216 } else {
217 None
218 };
219
220 let results =
221 results
222 .into_iter()
223 .zip(registrations)
224 .zip(keys.iter())
225 .map(|((a, r), key)| match a {
226 Some(ready) => Either::Left(futures::future::ready(ready)),
228 None => {
229 let waiting_keys = waiting_keys.clone();
230 let key = key.clone();
231 Either::Right(async move {
232 let result = r.await;
233 waiting_keys.lock().remove(&key);
235 result
236 })
237 }
238 });
239
240 join_all(results).await
243 }
244}
245
/// A pending wait on a single key of a `NotifyRead`. Awaiting it yields the
/// value delivered through its channel; dropping it before completion triggers
/// a cleanup of the key's registrations.
pub struct Registration<'a, K: Eq + Hash + Clone, V: Clone> {
    /// The `NotifyRead` this registration belongs to; used for cleanup on drop.
    this: &'a NotifyRead<K, V>,
    /// The registered key and the receiving half of its channel. `None` once
    /// the value has been received or the registration has been dropped.
    registration: Option<(K, oneshot::Receiver<V>)>,
}
252
253impl<K: Eq + Hash + Clone + Unpin, V: Clone + Unpin> Future for Registration<'_, K, V> {
254 type Output = V;
255
256 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
257 let receiver = self
258 .registration
259 .as_mut()
260 .map(|(_key, receiver)| receiver)
261 .expect("poll can not be called after drop");
262 let poll = Pin::new(receiver).poll(cx);
263 if poll.is_ready() {
264 self.registration.take();
266 }
267 poll.map(|r| r.expect("Sender never drops when registration is pending"))
268 }
269}
270
271impl<K: Eq + Hash + Clone, V: Clone> Drop for Registration<'_, K, V> {
272 fn drop(&mut self) {
273 if let Some((key, receiver)) = self.registration.take() {
274 mem::drop(receiver);
275 self.this.cleanup(&key)
277 }
278 }
279}
280impl<K: Eq + Hash + Clone, V: Clone> Default for NotifyRead<K, V> {
281 fn default() -> Self {
282 Self::new()
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use futures::future::join_all;
    use std::sync::Arc;
    use tokio::time::timeout;

    /// Registering, dropping and notifying keep the pending counter and the
    /// shard maps consistent.
    #[tokio::test]
    pub async fn test_notify_read() {
        let notify_read = NotifyRead::<u64, u64>::new();
        let mut registrations = notify_read.register_all(&[1, 2, 3]);
        assert_eq!(3, notify_read.count_pending.load(Ordering::Relaxed));
        // Dropping the registration for key 3 must clean it up immediately.
        registrations.pop();
        assert_eq!(2, notify_read.count_pending.load(Ordering::Relaxed));
        notify_read.notify(&2, &2);
        notify_read.notify(&1, &1);
        let reads = join_all(registrations).await;
        assert_eq!(0, notify_read.count_pending.load(Ordering::Relaxed));
        assert_eq!(reads, vec![1, 2]);
        for pending in &notify_read.pending {
            assert!(pending.lock().is_empty());
        }
    }

    /// Cancelling a `read` (here through a timeout) must leave no
    /// registrations behind.
    #[tokio::test]
    pub async fn test_notify_read_cancellation() {
        let notify_read = Arc::new(NotifyRead::<u64, u64>::new());

        // Nothing is fetched and nothing is ever notified, so this can only
        // end through the timeout.
        let read_future = notify_read.read(
            "test_task",
            &[1, 2, 3],
            |_keys| vec![None, None, None],
        );

        let result = timeout(Duration::from_millis(100), read_future).await;

        assert!(result.is_err());

        tokio::time::sleep(Duration::from_millis(50)).await;

        assert_eq!(0, notify_read.count_pending.load(Ordering::Relaxed));

        for pending in &notify_read.pending {
            assert!(pending.lock().is_empty());
        }
    }
}