// sui_storage/sharded_lru.rs
// Copyright (c) Mysten Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

use std::{
    collections::hash_map::RandomState,
    hash::{BuildHasher, Hash},
};

use lru::LruCache;
use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use std::collections::HashMap;
use std::fmt::Debug;
use std::num::NonZeroUsize;

/// A concurrent LRU cache split into independently locked shards.
///
/// Each key is routed to exactly one shard by hashing it with `hasher`. Every
/// shard is its own [`LruCache`] behind its own [`RwLock`]. Operations on keys
/// that map to different shards therefore take different locks.
pub struct ShardedLruCache<K, V, S = RandomState> {
    shards: Vec<RwLock<LruCache<K, V>>>,
    hasher: S,
}

// SAFETY: all mutable state lives inside `LruCache`s guarded by `RwLock`s.
// Sending the cache to another thread moves the owned keys, values and hasher,
// so each of them must be `Send`.
unsafe impl<K: Send, V: Send, S: Send> Send for ShardedLruCache<K, V, S> {}

// SAFETY: sharing `&ShardedLruCache` lets any thread do three things:
// - take a write lock, which mutates, moves and drops `K`/`V`, so they must be
//   `Send`;
// - take a read lock, which hashes and compares `&K` and hands out `&V`
//   concurrently, so they must be `Sync`;
// - call the hasher through `&S`, so `S` must be `Sync`.
// Without these bounds, non-thread-safe types such as `Rc` could be shared
// across threads.
unsafe impl<K: Send + Sync, V: Send + Sync, S: Sync> Sync for ShardedLruCache<K, V, S> {}

impl<K, V> ShardedLruCache<K, V, RandomState>
where
    K: Send + Sync + Hash + Eq + Clone,
    V: Send + Sync + Clone,
{
    /// Creates a cache holding about `capacity` entries spread over
    /// `num_shards` shards. Keys are routed with a freshly seeded
    /// [`RandomState`].
    ///
    /// Each shard gets `ceil(capacity / num_shards)` slots. The total capacity
    /// can therefore exceed `capacity` by up to `num_shards - 1` entries.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - `num_shards` is zero;
    /// - `capacity` is zero;
    /// - the per-shard capacity does not fit in `usize`.
    pub fn new(capacity: u64, num_shards: u64) -> Self {
        assert!(num_shards > 0, "ShardedLruCache requires at least one shard");
        // `as usize` would silently truncate on 32-bit targets; fail loudly instead.
        let cap_per_shard = usize::try_from(capacity.div_ceil(num_shards))
            .expect("ShardedLruCache per-shard capacity must fit in usize");
        let cap_per_shard = NonZeroUsize::new(cap_per_shard)
            .expect("ShardedLruCache capacity must be non-zero");
        Self {
            hasher: RandomState::default(),
            shards: (0..num_shards)
                .map(|_| RwLock::new(LruCache::new(cap_per_shard)))
                .collect(),
        }
    }
}

impl<K, V, S> ShardedLruCache<K, V, S>
where
    K: Hash + Eq + Clone + Debug,
    V: Clone,
    S: BuildHasher,
{
    /// Returns the index of the shard responsible for `key`.
    fn shard_id(&self, key: &K) -> usize {
        // The hash only selects a shard, so truncating it to `usize` on
        // 32-bit targets merely drops high bits.
        let h = self.hasher.hash_one(key) as usize;
        h % self.shards.len()
    }

    /// Acquires the read lock of the shard that owns `key`.
    fn read_shard(&self, key: &K) -> RwLockReadGuard<'_, LruCache<K, V>> {
        self.shards[self.shard_id(key)].read()
    }

    /// Acquires the write lock of the shard that owns `key`.
    fn write_shard(&self, key: &K) -> RwLockWriteGuard<'_, LruCache<K, V>> {
        self.shards[self.shard_id(key)].write()
    }

    /// Buckets `items` by owning shard so batch operations lock each shard
    /// only once.
    ///
    /// `key_of` extracts the cache key from an item. Within a bucket, items
    /// keep their input order.
    fn group_by_shard<T>(
        &self,
        items: impl IntoIterator<Item = T>,
        key_of: impl Fn(&T) -> &K,
    ) -> HashMap<usize, Vec<T>> {
        let mut grouped: HashMap<usize, Vec<T>> = HashMap::new();
        for item in items {
            let shard_idx = self.shard_id(key_of(&item));
            grouped.entry(shard_idx).or_default().push(item);
        }
        grouped
    }

    /// Removes `key` from the cache and returns its value if it was cached.
    pub fn invalidate(&self, key: &K) -> Option<V> {
        self.write_shard(key).pop(key)
    }

    /// Removes every key in `keys` from the cache.
    ///
    /// Each shard's write lock is taken once for all of its keys. The batch is
    /// not atomic across shards.
    pub fn batch_invalidate(&self, keys: impl IntoIterator<Item = K>) {
        for (shard_idx, keys) in self.group_by_shard(keys, |key| key) {
            let mut shard = self.shards[shard_idx].write();
            for key in keys {
                shard.pop(&key);
            }
        }
    }

    /// Replaces the cached entry for `key` with `f(old_value, value)`.
    ///
    /// Only keys that are already cached are updated. If `key` is absent, the
    /// cache is left unchanged and `value` is ignored.
    pub fn merge(&self, key: K, value: &V, f: fn(&V, &V) -> V) {
        let mut shard = self.write_shard(&key);
        if let Some(old_value) = shard.get(&key) {
            let new_value = f(old_value, value);
            shard.put(key, new_value);
        }
    }

    /// Batch form of [`Self::merge`]. For each `(key, value)` whose key is
    /// cached, the entry becomes `f(old_value, &value)`. Pairs with uncached
    /// keys are ignored.
    ///
    /// Each shard's write lock is taken once. Pairs for the same key are
    /// applied in input order.
    pub fn batch_merge(&self, key_values: impl IntoIterator<Item = (K, V)>, f: fn(&V, &V) -> V) {
        for (shard_idx, pairs) in self.group_by_shard(key_values, |(key, _)| key) {
            let mut shard = self.shards[shard_idx].write();
            for (key, value) in pairs {
                if let Some(old_value) = shard.get(&key) {
                    let new_value = f(old_value, &value);
                    shard.put(key, new_value);
                }
            }
        }
    }

    /// Returns a clone of the cached value for `key`, if present.
    ///
    /// This uses `peek` under a read lock, so a hit does not refresh the
    /// entry's LRU position.
    pub fn get(&self, key: &K) -> Option<V> {
        self.read_shard(key).peek(key).cloned()
    }

    /// Returns the value for `key`, computing and caching it with `init` on a
    /// miss.
    ///
    /// The read lock is released before the write lock is taken. The key is
    /// re-checked under the write lock, so `init` only runs if no other thread
    /// inserted the key in between.
    ///
    /// `init` runs while this shard's write lock is held. If it touches this
    /// cache for a key in the same shard, it will deadlock.
    pub fn get_with(&self, key: K, init: impl FnOnce() -> V) -> V {
        // Fast path: shared lock and `peek`, so no LRU update on a hit.
        if let Some(value) = self.read_shard(&key).peek(&key) {
            return value.clone();
        }
        let mut shard = self.write_shard(&key);
        // Another thread may have filled the entry between the two locks.
        if let Some(value) = shard.get(&key) {
            return value.clone();
        }
        let value = init();
        // `push` returns any entry evicted to make room; it is simply dropped.
        shard.push(key, value.clone());
        value
    }
}