1use std::collections::HashMap;
5use std::collections::hash_map::{DefaultHasher, RandomState};
6use std::error::Error;
7use std::fmt;
8use std::hash::{BuildHasher, Hash, Hasher};
9use std::sync::Arc;
10use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
11use std::time::Duration;
12
13use parking_lot::{ArcMutexGuard, ArcRwLockReadGuard, ArcRwLockWriteGuard, Mutex, RwLock};
14use tokio::task::JoinHandle;
15use tokio::time::Instant;
16use tracing::info;
17
18use mysten_metrics::spawn_monitored_task;
19
20type OwnedMutexGuard<T> = ArcMutexGuard<parking_lot::RawMutex, T>;
21type OwnedRwLockReadGuard<T> = ArcRwLockReadGuard<parking_lot::RawRwLock, T>;
22type OwnedRwLockWriteGuard<T> = ArcRwLockWriteGuard<parking_lot::RawRwLock, T>;
23
/// Abstraction over the lock primitive stored in a lock table entry.
///
/// Every method consumes an `Arc<Self>` and returns a guard type chosen by the
/// implementation, so implementations can hand out guards that own their lock.
pub trait Lock: Send + Sync + Default {
    /// Guard returned by the exclusive acquisition methods.
    type Guard;
    /// Guard returned by [`Lock::read_lock_owned`].
    type ReadGuard;

    /// Acquires the lock exclusively and returns its guard.
    fn lock_owned(self: Arc<Self>) -> Self::Guard;

    /// Attempts an exclusive acquisition; `None` means the lock was not obtained.
    fn try_lock_owned(self: Arc<Self>) -> Option<Self::Guard>;

    /// Acquires the lock for reading and returns the read guard.
    fn read_lock_owned(self: Arc<Self>) -> Self::ReadGuard;
}
31
32impl Lock for Mutex<()> {
33    type Guard = OwnedMutexGuard<()>;
34    type ReadGuard = Self::Guard;
35
36    fn lock_owned(self: Arc<Self>) -> Self::Guard {
37        self.lock_arc()
38    }
39
40    fn try_lock_owned(self: Arc<Self>) -> Option<Self::Guard> {
41        self.try_lock_arc()
42    }
43
44    fn read_lock_owned(self: Arc<Self>) -> Self::ReadGuard {
45        self.lock_arc()
46    }
47}
48
49impl Lock for RwLock<()> {
50    type Guard = OwnedRwLockWriteGuard<()>;
51    type ReadGuard = OwnedRwLockReadGuard<()>;
52
53    fn lock_owned(self: Arc<Self>) -> Self::Guard {
54        self.write_arc()
55    }
56
57    fn try_lock_owned(self: Arc<Self>) -> Option<Self::Guard> {
58        self.try_write_arc()
59    }
60
61    fn read_lock_owned(self: Arc<Self>) -> Self::ReadGuard {
62        self.read_arc()
63    }
64}
65
66type InnerLockTable<K, L> = HashMap<K, Arc<L>>;
67pub struct LockTable<K: Hash, L: Lock> {
69    random_state: RandomState,
70    lock_table: Arc<Vec<RwLock<InnerLockTable<K, L>>>>,
71    _k: std::marker::PhantomData<K>,
72    _cleaner: JoinHandle<()>,
73    stop: Arc<AtomicBool>,
74    size: Arc<AtomicUsize>,
75}
76
77pub type MutexTable<K> = LockTable<K, Mutex<()>>;
78pub type RwLockTable<K> = LockTable<K, RwLock<()>>;
79
/// Reasons a non-blocking lock acquisition can fail.
#[derive(Debug)]
pub enum TryAcquireLockError {
    /// The shard holding the key could not be locked without blocking.
    LockTableLocked,
    /// The shard was accessible, but the entry's own lock could not be taken.
    LockEntryLocked,
}

impl fmt::Display for TryAcquireLockError {
    /// Both variants render as the same message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("operation would block")
    }
}

impl Error for TryAcquireLockError {}
93pub type MutexGuard = OwnedMutexGuard<()>;
94pub type RwLockGuard = OwnedRwLockReadGuard<()>;
95
96impl<K: Hash + Eq + Send + Sync + 'static, L: Lock + 'static> LockTable<K, L> {
97    pub fn new_with_cleanup(
98        num_shards: usize,
99        cleanup_period: Duration,
100        cleanup_initial_delay: Duration,
101        cleanup_entries_threshold: usize,
102    ) -> Self {
103        let num_shards = if cfg!(msim) { 4 } else { num_shards };
104
105        let lock_table: Arc<Vec<RwLock<InnerLockTable<K, L>>>> = Arc::new(
106            (0..num_shards)
107                .map(|_| RwLock::new(HashMap::new()))
108                .collect(),
109        );
110        let cloned = lock_table.clone();
111        let stop = Arc::new(AtomicBool::new(false));
112        let stop_cloned = stop.clone();
113        let size: Arc<AtomicUsize> = Arc::new(AtomicUsize::new(0));
114        let size_cloned = size.clone();
115        Self {
116            random_state: RandomState::new(),
117            lock_table,
118            _k: std::marker::PhantomData {},
119            _cleaner: spawn_monitored_task!(async move {
120                tokio::time::sleep(cleanup_initial_delay).await;
121                let mut previous_cleanup_instant = Instant::now();
122                while !stop_cloned.load(Ordering::SeqCst) {
123                    if size_cloned.load(Ordering::SeqCst) >= cleanup_entries_threshold
124                        || previous_cleanup_instant.elapsed() >= cleanup_period
125                    {
126                        let num_removed = Self::cleanup(cloned.clone());
127                        size_cloned.fetch_sub(num_removed, Ordering::SeqCst);
128                        previous_cleanup_instant = Instant::now();
129                    }
130                    tokio::time::sleep(Duration::from_secs(1)).await;
131                }
132                info!("Stopping mutex table cleanup!");
133            }),
134            stop,
135            size,
136        }
137    }
138
139    pub fn new(num_shards: usize) -> Self {
140        Self::new_with_cleanup(
141            num_shards,
142            Duration::from_secs(10),
143            Duration::from_secs(10),
144            10_000,
145        )
146    }
147
148    pub fn size(&self) -> usize {
149        self.size.load(Ordering::SeqCst)
150    }
151
152    pub fn cleanup(lock_table: Arc<Vec<RwLock<InnerLockTable<K, L>>>>) -> usize {
153        let mut num_removed: usize = 0;
154        for shard in lock_table.iter() {
155            let map = shard.try_write();
156            if map.is_none() {
157                continue;
158            }
159            map.unwrap().retain(|_k, v| {
160                if Arc::strong_count(v) == 1 {
164                    num_removed += 1;
165                    false
166                } else {
167                    true
168                }
169            });
170        }
171        num_removed
172    }
173
174    fn get_lock_idx(&self, key: &K) -> usize {
175        let mut hasher = if !cfg!(test) {
176            self.random_state.build_hasher()
177        } else {
178            DefaultHasher::new()
180        };
181
182        key.hash(&mut hasher);
183        let hash: usize = hasher.finish().try_into().unwrap();
185        hash % self.lock_table.len()
186    }
187
188    pub fn acquire_locks<I>(&self, object_iter: I) -> Vec<L::Guard>
189    where
190        I: Iterator<Item = K>,
191        K: Ord,
192    {
193        let mut objects: Vec<K> = object_iter.into_iter().collect();
194        objects.sort_unstable();
195        objects.dedup();
196
197        let mut guards = Vec::with_capacity(objects.len());
198        for object in objects.into_iter() {
199            guards.push(self.acquire_lock(object));
200        }
201        guards
202    }
203
204    pub fn acquire_read_locks(&self, mut objects: Vec<K>) -> Vec<L::ReadGuard>
205    where
206        K: Ord,
207    {
208        objects.sort_unstable();
209        objects.dedup();
210        let mut guards = Vec::with_capacity(objects.len());
211        for object in objects.into_iter() {
212            guards.push(self.get_lock(object).read_lock_owned());
213        }
214        guards
215    }
216
217    pub fn get_lock(&self, k: K) -> Arc<L> {
218        let lock_idx = self.get_lock_idx(&k);
219        let element = {
220            let map = self.lock_table[lock_idx].read();
221            map.get(&k).cloned()
222        };
223        if let Some(element) = element {
224            element
225        } else {
226            {
229                let mut map = self.lock_table[lock_idx].write();
230                map.entry(k)
231                    .or_insert_with(|| {
232                        self.size.fetch_add(1, Ordering::SeqCst);
233                        Arc::new(L::default())
234                    })
235                    .clone()
236            }
237        }
238    }
239
240    pub fn acquire_lock(&self, k: K) -> L::Guard {
241        self.get_lock(k).lock_owned()
242    }
243
244    pub fn try_acquire_lock(&self, k: K) -> Result<L::Guard, TryAcquireLockError> {
245        let lock_idx = self.get_lock_idx(&k);
246        let element = {
247            let map = self.lock_table[lock_idx]
248                .try_read()
249                .ok_or(TryAcquireLockError::LockTableLocked)?;
250            map.get(&k).cloned()
251        };
252        if let Some(element) = element {
253            let lock = element.try_lock_owned();
254            lock.ok_or(TryAcquireLockError::LockEntryLocked)
255        } else {
256            let element = {
258                let mut map = self.lock_table[lock_idx]
259                    .try_write()
260                    .ok_or(TryAcquireLockError::LockTableLocked)?;
261                map.entry(k)
262                    .or_insert_with(|| {
263                        self.size.fetch_add(1, Ordering::SeqCst);
264                        Arc::new(L::default())
265                    })
266                    .clone()
267            };
268            let lock = element.try_lock_owned();
269            lock.ok_or(TryAcquireLockError::LockEntryLocked)
270        }
271    }
272}
273
274impl<K: Hash, L: Lock> Drop for LockTable<K, L> {
275    fn drop(&mut self) {
276        self.stop.store(true, Ordering::SeqCst);
277    }
278}
279
280#[tokio::test]
281async fn test_mutex_table_concurrent_in_same_bucket() {
284    use tokio::time::{sleep, timeout};
285    let mutex_table = Arc::new(MutexTable::<String>::new(1));
286    let john = mutex_table.try_acquire_lock("john".to_string());
287    let _ = john.unwrap();
288    {
289        let mutex_table = mutex_table.clone();
290        std::thread::spawn(move || {
291            let _ = mutex_table.acquire_lock("john".to_string());
292        });
293    }
294    sleep(Duration::from_millis(50)).await;
295    let jane = mutex_table.try_acquire_lock("jane".to_string());
296    let _ = jane.unwrap();
297
298    let mutex_table = Arc::new(MutexTable::<String>::new(1));
299    let _john = mutex_table.acquire_lock("john".to_string());
300    {
301        let mutex_table = mutex_table.clone();
302        std::thread::spawn(move || {
303            let _ = mutex_table.acquire_lock("john".to_string());
304        });
305    }
306    sleep(Duration::from_millis(50)).await;
307    let jane = timeout(
308        Duration::from_secs(1),
309        tokio::task::spawn_blocking(move || {
310            let _ = mutex_table.acquire_lock("jane".to_string());
311        }),
312    )
313    .await;
314    let _ = jane.unwrap();
315}
316
317#[tokio::test]
318async fn test_mutex_table() {
319    let mutex_table =
321        MutexTable::<String>::new_with_cleanup(1, Duration::from_secs(10), Duration::MAX, 1000);
322    let john1 = mutex_table.try_acquire_lock("john".to_string());
323    assert!(john1.is_ok());
324    let john2 = mutex_table.try_acquire_lock("john".to_string());
325    assert!(john2.is_err());
326    drop(john1);
327    let john2 = mutex_table.try_acquire_lock("john".to_string());
328    assert!(john2.is_ok());
329    let jane = mutex_table.try_acquire_lock("jane".to_string());
330    assert!(jane.is_ok());
331    MutexTable::cleanup(mutex_table.lock_table.clone());
332    let map = mutex_table.lock_table.first().as_ref().unwrap().try_read();
333    assert!(map.is_some());
334    assert_eq!(map.unwrap().len(), 2);
335    drop(john2);
336    MutexTable::cleanup(mutex_table.lock_table.clone());
337    let map = mutex_table.lock_table.first().as_ref().unwrap().try_read();
338    assert!(map.is_some());
339    assert_eq!(map.unwrap().len(), 1);
340    drop(jane);
341    MutexTable::cleanup(mutex_table.lock_table.clone());
342    let map = mutex_table.lock_table.first().as_ref().unwrap().try_read();
343    assert!(map.is_some());
344    assert!(map.unwrap().is_empty());
345}
346
347#[tokio::test]
348async fn test_acquire_locks() {
349    let mutex_table =
350        RwLockTable::<String>::new_with_cleanup(1, Duration::from_secs(10), Duration::MAX, 1000);
351    let object_1 = "object 1".to_string();
352    let object_2 = "object 2".to_string();
353    let object_3 = "object 3".to_string();
354
355    let objects = vec![
357        object_1.clone(),
358        object_2.clone(),
359        object_2,
360        object_1.clone(),
361        object_3,
362        object_1,
363    ];
364
365    let locks = mutex_table.acquire_locks(objects.clone().into_iter());
366    assert_eq!(locks.len(), 3);
367
368    for object in objects.clone() {
369        assert!(mutex_table.try_acquire_lock(object).is_err());
370    }
371
372    drop(locks);
373    let locks = mutex_table.acquire_locks(objects.into_iter());
374    assert_eq!(locks.len(), 3);
375}
376
377#[tokio::test]
378async fn test_read_locks() {
379    let mutex_table =
380        RwLockTable::<String>::new_with_cleanup(1, Duration::from_secs(10), Duration::MAX, 1000);
381    let lock = "lock".to_string();
382    let locks1 = mutex_table.acquire_read_locks(vec![lock.clone()]);
383    assert!(mutex_table.try_acquire_lock(lock.clone()).is_err());
384    let locks2 = mutex_table.acquire_read_locks(vec![lock.clone()]);
385    drop(locks1);
386    drop(locks2);
387    assert!(mutex_table.try_acquire_lock(lock.clone()).is_ok());
388}
389
390#[tokio::test(flavor = "current_thread", start_paused = true)]
391async fn test_mutex_table_bg_cleanup() {
392    let mutex_table = MutexTable::<String>::new_with_cleanup(
393        1,
394        Duration::from_secs(5),
395        Duration::from_secs(1),
396        1000,
397    );
398    let lock1 = mutex_table.try_acquire_lock("lock1".to_string());
399    let lock2 = mutex_table.try_acquire_lock("lock2".to_string());
400    let lock3 = mutex_table.try_acquire_lock("lock3".to_string());
401    let lock4 = mutex_table.try_acquire_lock("lock4".to_string());
402    let lock5 = mutex_table.try_acquire_lock("lock5".to_string());
403    assert!(lock1.is_ok());
404    assert!(lock2.is_ok());
405    assert!(lock3.is_ok());
406    assert!(lock4.is_ok());
407    assert!(lock5.is_ok());
408    MutexTable::cleanup(mutex_table.lock_table.clone());
410    let lock11 = mutex_table.try_acquire_lock("lock1".to_string());
412    let lock22 = mutex_table.try_acquire_lock("lock2".to_string());
413    let lock33 = mutex_table.try_acquire_lock("lock3".to_string());
414    let lock44 = mutex_table.try_acquire_lock("lock4".to_string());
415    let lock55 = mutex_table.try_acquire_lock("lock5".to_string());
416    assert!(lock11.is_err());
417    assert!(lock22.is_err());
418    assert!(lock33.is_err());
419    assert!(lock44.is_err());
420    assert!(lock55.is_err());
421    drop(lock1);
423    drop(lock2);
424    drop(lock3);
425    drop(lock4);
426    drop(lock5);
427    tokio::time::sleep(Duration::from_secs(10)).await;
429    for entry in mutex_table.lock_table.iter() {
430        let locked = entry.read();
431        assert!(locked.is_empty());
432    }
433}
434
435#[tokio::test(flavor = "current_thread", start_paused = true)]
436async fn test_mutex_table_bg_cleanup_with_size_threshold() {
437    let mutex_table =
439        MutexTable::<String>::new_with_cleanup(1, Duration::MAX, Duration::from_secs(1), 5);
440    let lock1 = mutex_table.try_acquire_lock("lock1".to_string());
441    let lock2 = mutex_table.try_acquire_lock("lock2".to_string());
442    let lock3 = mutex_table.try_acquire_lock("lock3".to_string());
443    let lock4 = mutex_table.try_acquire_lock("lock4".to_string());
444    let lock5 = mutex_table.try_acquire_lock("lock5".to_string());
445    assert!(lock1.is_ok());
446    assert!(lock2.is_ok());
447    assert!(lock3.is_ok());
448    assert!(lock4.is_ok());
449    assert!(lock5.is_ok());
450    MutexTable::cleanup(mutex_table.lock_table.clone());
452    let lock11 = mutex_table.try_acquire_lock("lock1".to_string());
454    let lock22 = mutex_table.try_acquire_lock("lock2".to_string());
455    let lock33 = mutex_table.try_acquire_lock("lock3".to_string());
456    let lock44 = mutex_table.try_acquire_lock("lock4".to_string());
457    let lock55 = mutex_table.try_acquire_lock("lock5".to_string());
458    assert!(lock11.is_err());
459    assert!(lock22.is_err());
460    assert!(lock33.is_err());
461    assert!(lock44.is_err());
462    assert!(lock55.is_err());
463    assert_eq!(mutex_table.size(), 5);
464    drop(lock1);
466    drop(lock2);
467    drop(lock3);
468    drop(lock4);
469    drop(lock5);
470    tokio::task::yield_now().await;
471    tokio::time::advance(Duration::from_secs(5)).await;
473    tokio::task::yield_now().await;
474    assert_eq!(mutex_table.size(), 0);
475    for entry in mutex_table.lock_table.iter() {
476        let locked = entry.read();
477        assert!(locked.is_empty());
478    }
479}