1use std::collections::HashMap;
5use std::collections::hash_map::{DefaultHasher, RandomState};
6use std::error::Error;
7use std::fmt;
8use std::hash::{BuildHasher, Hash, Hasher};
9use std::sync::Arc;
10use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
11use std::time::Duration;
12
13use parking_lot::{ArcMutexGuard, ArcRwLockReadGuard, ArcRwLockWriteGuard, Mutex, RwLock};
14use tokio::task::JoinHandle;
15use tokio::time::Instant;
16use tracing::info;
17
18use mysten_metrics::spawn_monitored_task;
19
20type OwnedMutexGuard<T> = ArcMutexGuard<parking_lot::RawMutex, T>;
21type OwnedRwLockReadGuard<T> = ArcRwLockReadGuard<parking_lot::RawRwLock, T>;
22type OwnedRwLockWriteGuard<T> = ArcRwLockWriteGuard<parking_lot::RawRwLock, T>;
23
/// A lock type that can be stored in a lock table and acquired through an `Arc`.
///
/// Every acquisition consumes an `Arc<Self>` and returns a guard that keeps the lock alive on
/// its own, so guards do not borrow from the table. `Default` supplies the fresh, unlocked
/// value used when a key is seen for the first time.
pub trait Lock: Send + Sync + Default {
    /// Guard returned by exclusive acquisition.
    type Guard;

    /// Guard returned by shared acquisition. Lock types without a shared mode may make this
    /// an exclusive guard.
    type ReadGuard;

    /// Waits until the lock is held exclusively and returns the owning guard.
    fn lock_owned(self: Arc<Self>) -> Self::Guard;

    /// Attempts exclusive acquisition without waiting; `None` means it is currently held.
    fn try_lock_owned(self: Arc<Self>) -> Option<Self::Guard>;

    /// Waits until the lock is held in shared mode (exclusive if no shared mode exists).
    fn read_lock_owned(self: Arc<Self>) -> Self::ReadGuard;
}
31
32impl Lock for Mutex<()> {
33 type Guard = OwnedMutexGuard<()>;
34 type ReadGuard = Self::Guard;
35
36 fn lock_owned(self: Arc<Self>) -> Self::Guard {
37 self.lock_arc()
38 }
39
40 fn try_lock_owned(self: Arc<Self>) -> Option<Self::Guard> {
41 self.try_lock_arc()
42 }
43
44 fn read_lock_owned(self: Arc<Self>) -> Self::ReadGuard {
45 self.lock_arc()
46 }
47}
48
49impl Lock for RwLock<()> {
50 type Guard = OwnedRwLockWriteGuard<()>;
51 type ReadGuard = OwnedRwLockReadGuard<()>;
52
53 fn lock_owned(self: Arc<Self>) -> Self::Guard {
54 self.write_arc()
55 }
56
57 fn try_lock_owned(self: Arc<Self>) -> Option<Self::Guard> {
58 self.try_write_arc()
59 }
60
61 fn read_lock_owned(self: Arc<Self>) -> Self::ReadGuard {
62 self.read_arc()
63 }
64}
65
66type InnerLockTable<K, L> = HashMap<K, Arc<L>>;
67pub struct LockTable<K: Hash, L: Lock> {
69 random_state: RandomState,
70 lock_table: Arc<Vec<RwLock<InnerLockTable<K, L>>>>,
71 _k: std::marker::PhantomData<K>,
72 _cleaner: JoinHandle<()>,
73 stop: Arc<AtomicBool>,
74 size: Arc<AtomicUsize>,
75}
76
77pub type MutexTable<K> = LockTable<K, Mutex<()>>;
78pub type RwLockTable<K> = LockTable<K, RwLock<()>>;
79
/// Reason a non-blocking lock acquisition on a lock table failed.
///
/// The enum is fieldless, so it is cheap to copy and can be compared directly
/// (e.g. `assert_eq!(err, TryAcquireLockError::LockEntryLocked)`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TryAcquireLockError {
    /// The shard map containing the key could not be locked without blocking.
    LockTableLocked,
    /// The key's own lock is currently held.
    LockEntryLocked,
}

impl fmt::Display for TryAcquireLockError {
    /// Both variants render the same message; match on the variant (or use `Debug`) to tell
    /// them apart.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("operation would block")
    }
}

impl Error for TryAcquireLockError {}
93pub type MutexGuard = OwnedMutexGuard<()>;
94pub type RwLockGuard = OwnedRwLockReadGuard<()>;
95
96impl<K: Hash + Eq + Send + Sync + 'static, L: Lock + 'static> LockTable<K, L> {
97 pub fn new_with_cleanup(
98 num_shards: usize,
99 cleanup_period: Duration,
100 cleanup_initial_delay: Duration,
101 cleanup_entries_threshold: usize,
102 ) -> Self {
103 let num_shards = if cfg!(msim) { 4 } else { num_shards };
104
105 let lock_table: Arc<Vec<RwLock<InnerLockTable<K, L>>>> = Arc::new(
106 (0..num_shards)
107 .map(|_| RwLock::new(HashMap::new()))
108 .collect(),
109 );
110 let cloned = lock_table.clone();
111 let stop = Arc::new(AtomicBool::new(false));
112 let stop_cloned = stop.clone();
113 let size: Arc<AtomicUsize> = Arc::new(AtomicUsize::new(0));
114 let size_cloned = size.clone();
115 Self {
116 random_state: RandomState::new(),
117 lock_table,
118 _k: std::marker::PhantomData {},
119 _cleaner: spawn_monitored_task!(async move {
120 tokio::time::sleep(cleanup_initial_delay).await;
121 let mut previous_cleanup_instant = Instant::now();
122 while !stop_cloned.load(Ordering::SeqCst) {
123 if size_cloned.load(Ordering::SeqCst) >= cleanup_entries_threshold
124 || previous_cleanup_instant.elapsed() >= cleanup_period
125 {
126 let num_removed = Self::cleanup(cloned.clone());
127 size_cloned.fetch_sub(num_removed, Ordering::SeqCst);
128 previous_cleanup_instant = Instant::now();
129 }
130 tokio::time::sleep(Duration::from_secs(1)).await;
131 }
132 info!("Stopping mutex table cleanup!");
133 }),
134 stop,
135 size,
136 }
137 }
138
139 pub fn new(num_shards: usize) -> Self {
140 Self::new_with_cleanup(
141 num_shards,
142 Duration::from_secs(10),
143 Duration::from_secs(10),
144 10_000,
145 )
146 }
147
148 pub fn size(&self) -> usize {
149 self.size.load(Ordering::SeqCst)
150 }
151
152 pub fn cleanup(lock_table: Arc<Vec<RwLock<InnerLockTable<K, L>>>>) -> usize {
153 let mut num_removed: usize = 0;
154 for shard in lock_table.iter() {
155 let map = shard.try_write();
156 if map.is_none() {
157 continue;
158 }
159 map.unwrap().retain(|_k, v| {
160 if Arc::strong_count(v) == 1 {
164 num_removed += 1;
165 false
166 } else {
167 true
168 }
169 });
170 }
171 num_removed
172 }
173
174 fn get_lock_idx(&self, key: &K) -> usize {
175 let mut hasher = if !cfg!(test) {
176 self.random_state.build_hasher()
177 } else {
178 DefaultHasher::new()
180 };
181
182 key.hash(&mut hasher);
183 let hash: usize = hasher.finish().try_into().unwrap();
185 hash % self.lock_table.len()
186 }
187
188 pub fn acquire_locks<I>(&self, object_iter: I) -> Vec<L::Guard>
189 where
190 I: Iterator<Item = K>,
191 K: Ord,
192 {
193 let mut objects: Vec<K> = object_iter.into_iter().collect();
194 objects.sort_unstable();
195 objects.dedup();
196
197 let mut guards = Vec::with_capacity(objects.len());
198 for object in objects.into_iter() {
199 guards.push(self.acquire_lock(object));
200 }
201 guards
202 }
203
204 pub fn acquire_read_locks(&self, mut objects: Vec<K>) -> Vec<L::ReadGuard>
205 where
206 K: Ord,
207 {
208 objects.sort_unstable();
209 objects.dedup();
210 let mut guards = Vec::with_capacity(objects.len());
211 for object in objects.into_iter() {
212 guards.push(self.get_lock(object).read_lock_owned());
213 }
214 guards
215 }
216
217 pub fn get_lock(&self, k: K) -> Arc<L> {
218 let lock_idx = self.get_lock_idx(&k);
219 let element = {
220 let map = self.lock_table[lock_idx].read();
221 map.get(&k).cloned()
222 };
223 if let Some(element) = element {
224 element
225 } else {
226 {
229 let mut map = self.lock_table[lock_idx].write();
230 map.entry(k)
231 .or_insert_with(|| {
232 self.size.fetch_add(1, Ordering::SeqCst);
233 Arc::new(L::default())
234 })
235 .clone()
236 }
237 }
238 }
239
240 pub fn acquire_lock(&self, k: K) -> L::Guard {
241 self.get_lock(k).lock_owned()
242 }
243
244 pub fn try_acquire_lock(&self, k: K) -> Result<L::Guard, TryAcquireLockError> {
245 let lock_idx = self.get_lock_idx(&k);
246 let element = {
247 let map = self.lock_table[lock_idx]
248 .try_read()
249 .ok_or(TryAcquireLockError::LockTableLocked)?;
250 map.get(&k).cloned()
251 };
252 if let Some(element) = element {
253 let lock = element.try_lock_owned();
254 lock.ok_or(TryAcquireLockError::LockEntryLocked)
255 } else {
256 let element = {
258 let mut map = self.lock_table[lock_idx]
259 .try_write()
260 .ok_or(TryAcquireLockError::LockTableLocked)?;
261 map.entry(k)
262 .or_insert_with(|| {
263 self.size.fetch_add(1, Ordering::SeqCst);
264 Arc::new(L::default())
265 })
266 .clone()
267 };
268 let lock = element.try_lock_owned();
269 lock.ok_or(TryAcquireLockError::LockEntryLocked)
270 }
271 }
272}
273
274impl<K: Hash, L: Lock> Drop for LockTable<K, L> {
275 fn drop(&mut self) {
276 self.stop.store(true, Ordering::SeqCst);
277 }
278}
279
280#[tokio::test]
281async fn test_mutex_table_concurrent_in_same_bucket() {
284 use tokio::time::{sleep, timeout};
285 let mutex_table = Arc::new(MutexTable::<String>::new(1));
286 let john = mutex_table.try_acquire_lock("john".to_string());
287 let _ = john.unwrap();
288 {
289 let mutex_table = mutex_table.clone();
290 std::thread::spawn(move || {
291 let _ = mutex_table.acquire_lock("john".to_string());
292 });
293 }
294 sleep(Duration::from_millis(50)).await;
295 let jane = mutex_table.try_acquire_lock("jane".to_string());
296 let _ = jane.unwrap();
297
298 let mutex_table = Arc::new(MutexTable::<String>::new(1));
299 let _john = mutex_table.acquire_lock("john".to_string());
300 {
301 let mutex_table = mutex_table.clone();
302 std::thread::spawn(move || {
303 let _ = mutex_table.acquire_lock("john".to_string());
304 });
305 }
306 sleep(Duration::from_millis(50)).await;
307 let jane = timeout(
308 Duration::from_secs(1),
309 tokio::task::spawn_blocking(move || {
310 let _ = mutex_table.acquire_lock("jane".to_string());
311 }),
312 )
313 .await;
314 let _ = jane.unwrap();
315}
316
317#[tokio::test]
318async fn test_mutex_table() {
319 let mutex_table =
321 MutexTable::<String>::new_with_cleanup(1, Duration::from_secs(10), Duration::MAX, 1000);
322 let john1 = mutex_table.try_acquire_lock("john".to_string());
323 assert!(john1.is_ok());
324 let john2 = mutex_table.try_acquire_lock("john".to_string());
325 assert!(john2.is_err());
326 drop(john1);
327 let john2 = mutex_table.try_acquire_lock("john".to_string());
328 assert!(john2.is_ok());
329 let jane = mutex_table.try_acquire_lock("jane".to_string());
330 assert!(jane.is_ok());
331 MutexTable::cleanup(mutex_table.lock_table.clone());
332 let map = mutex_table.lock_table.first().as_ref().unwrap().try_read();
333 assert!(map.is_some());
334 assert_eq!(map.unwrap().len(), 2);
335 drop(john2);
336 MutexTable::cleanup(mutex_table.lock_table.clone());
337 let map = mutex_table.lock_table.first().as_ref().unwrap().try_read();
338 assert!(map.is_some());
339 assert_eq!(map.unwrap().len(), 1);
340 drop(jane);
341 MutexTable::cleanup(mutex_table.lock_table.clone());
342 let map = mutex_table.lock_table.first().as_ref().unwrap().try_read();
343 assert!(map.is_some());
344 assert!(map.unwrap().is_empty());
345}
346
347#[tokio::test]
348async fn test_acquire_locks() {
349 let mutex_table =
350 RwLockTable::<String>::new_with_cleanup(1, Duration::from_secs(10), Duration::MAX, 1000);
351 let object_1 = "object 1".to_string();
352 let object_2 = "object 2".to_string();
353 let object_3 = "object 3".to_string();
354
355 let objects = vec![
357 object_1.clone(),
358 object_2.clone(),
359 object_2,
360 object_1.clone(),
361 object_3,
362 object_1,
363 ];
364
365 let locks = mutex_table.acquire_locks(objects.clone().into_iter());
366 assert_eq!(locks.len(), 3);
367
368 for object in objects.clone() {
369 assert!(mutex_table.try_acquire_lock(object).is_err());
370 }
371
372 drop(locks);
373 let locks = mutex_table.acquire_locks(objects.into_iter());
374 assert_eq!(locks.len(), 3);
375}
376
377#[tokio::test]
378async fn test_read_locks() {
379 let mutex_table =
380 RwLockTable::<String>::new_with_cleanup(1, Duration::from_secs(10), Duration::MAX, 1000);
381 let lock = "lock".to_string();
382 let locks1 = mutex_table.acquire_read_locks(vec![lock.clone()]);
383 assert!(mutex_table.try_acquire_lock(lock.clone()).is_err());
384 let locks2 = mutex_table.acquire_read_locks(vec![lock.clone()]);
385 drop(locks1);
386 drop(locks2);
387 assert!(mutex_table.try_acquire_lock(lock.clone()).is_ok());
388}
389
390#[tokio::test(flavor = "current_thread", start_paused = true)]
391async fn test_mutex_table_bg_cleanup() {
392 let mutex_table = MutexTable::<String>::new_with_cleanup(
393 1,
394 Duration::from_secs(5),
395 Duration::from_secs(1),
396 1000,
397 );
398 let lock1 = mutex_table.try_acquire_lock("lock1".to_string());
399 let lock2 = mutex_table.try_acquire_lock("lock2".to_string());
400 let lock3 = mutex_table.try_acquire_lock("lock3".to_string());
401 let lock4 = mutex_table.try_acquire_lock("lock4".to_string());
402 let lock5 = mutex_table.try_acquire_lock("lock5".to_string());
403 assert!(lock1.is_ok());
404 assert!(lock2.is_ok());
405 assert!(lock3.is_ok());
406 assert!(lock4.is_ok());
407 assert!(lock5.is_ok());
408 MutexTable::cleanup(mutex_table.lock_table.clone());
410 let lock11 = mutex_table.try_acquire_lock("lock1".to_string());
412 let lock22 = mutex_table.try_acquire_lock("lock2".to_string());
413 let lock33 = mutex_table.try_acquire_lock("lock3".to_string());
414 let lock44 = mutex_table.try_acquire_lock("lock4".to_string());
415 let lock55 = mutex_table.try_acquire_lock("lock5".to_string());
416 assert!(lock11.is_err());
417 assert!(lock22.is_err());
418 assert!(lock33.is_err());
419 assert!(lock44.is_err());
420 assert!(lock55.is_err());
421 drop(lock1);
423 drop(lock2);
424 drop(lock3);
425 drop(lock4);
426 drop(lock5);
427 tokio::time::sleep(Duration::from_secs(10)).await;
429 for entry in mutex_table.lock_table.iter() {
430 let locked = entry.read();
431 assert!(locked.is_empty());
432 }
433}
434
435#[tokio::test(flavor = "current_thread", start_paused = true)]
436async fn test_mutex_table_bg_cleanup_with_size_threshold() {
437 let mutex_table =
439 MutexTable::<String>::new_with_cleanup(1, Duration::MAX, Duration::from_secs(1), 5);
440 let lock1 = mutex_table.try_acquire_lock("lock1".to_string());
441 let lock2 = mutex_table.try_acquire_lock("lock2".to_string());
442 let lock3 = mutex_table.try_acquire_lock("lock3".to_string());
443 let lock4 = mutex_table.try_acquire_lock("lock4".to_string());
444 let lock5 = mutex_table.try_acquire_lock("lock5".to_string());
445 assert!(lock1.is_ok());
446 assert!(lock2.is_ok());
447 assert!(lock3.is_ok());
448 assert!(lock4.is_ok());
449 assert!(lock5.is_ok());
450 MutexTable::cleanup(mutex_table.lock_table.clone());
452 let lock11 = mutex_table.try_acquire_lock("lock1".to_string());
454 let lock22 = mutex_table.try_acquire_lock("lock2".to_string());
455 let lock33 = mutex_table.try_acquire_lock("lock3".to_string());
456 let lock44 = mutex_table.try_acquire_lock("lock4".to_string());
457 let lock55 = mutex_table.try_acquire_lock("lock5".to_string());
458 assert!(lock11.is_err());
459 assert!(lock22.is_err());
460 assert!(lock33.is_err());
461 assert!(lock44.is_err());
462 assert!(lock55.is_err());
463 assert_eq!(mutex_table.size(), 5);
464 drop(lock1);
466 drop(lock2);
467 drop(lock3);
468 drop(lock4);
469 drop(lock5);
470 tokio::task::yield_now().await;
471 tokio::time::advance(Duration::from_secs(5)).await;
473 tokio::task::yield_now().await;
474 assert_eq!(mutex_table.size(), 0);
475 for entry in mutex_table.lock_table.iter() {
476 let locked = entry.read();
477 assert!(locked.is_empty());
478 }
479}