grafos_cache/
elastic.rs

//! Elastic sharded hash map backed by fabric memory.
//!
//! [`ElasticShardSet`] distributes key-value pairs across multiple shards,
//! each a [`FabricHashMap`] backed by its own [`grafos_std::mem::MemLease`].
//! Shards can be added or removed at runtime, and
//! [`rehash`](ElasticShardSet::rehash) redistributes entries across the
//! current shard set.
//!
//! Shard selection uses `hash(key) % shard_count` (FNV-1a after postcard
//! serialization, matching [`FabricHashMap`]'s internal hash function).
11
12extern crate alloc;
13use alloc::vec::Vec;
14
15use grafos_collections::map::FabricHashMap;
16use grafos_std::error::{FabricError, Result};
17
18use serde::{de::DeserializeOwned, Serialize};
19
20/// A set of shards that can grow and shrink at runtime.
21///
22/// Each shard is a [`FabricHashMap`] backed by its own
23/// [`grafos_std::mem::MemLease`].
24/// Keys are routed to shards via `hash(key) % shard_count`.
25///
26/// # Example
27///
28/// ```rust
29/// use grafos_cache::elastic::ElasticShardSet;
30///
31/// # grafos_std::host::reset_mock();
32/// # grafos_std::host::mock_set_fbmu_arena_size(65536);
33/// let mut shards: ElasticShardSet<u32, u32> =
34///     ElasticShardSet::new(2, 64, 8, 8)?;
35///
36/// shards.put(&1, &100)?;
37/// shards.put(&2, &200)?;
38/// assert_eq!(shards.get(&1)?, Some(100));
39/// assert_eq!(shards.shard_count(), 2);
40///
41/// // Grow by one shard and redistribute
42/// shards.add_shard()?;
43/// shards.rehash()?;
44/// assert_eq!(shards.shard_count(), 3);
45/// assert_eq!(shards.get(&1)?, Some(100));
46/// # Ok::<(), grafos_std::FabricError>(())
47/// ```
48pub struct ElasticShardSet<K, V> {
49    shards: Vec<FabricHashMap<K, V>>,
50    buckets_per_shard: usize,
51    key_stride: usize,
52    val_stride: usize,
53}
54
55impl<K, V> ElasticShardSet<K, V>
56where
57    K: Serialize + DeserializeOwned + PartialEq + Clone,
58    V: Serialize + DeserializeOwned + Clone,
59{
60    /// Create a new shard set with `initial_shards` shards.
61    ///
62    /// Each shard is a [`FabricHashMap`] allocated via
63    /// [`grafos_std::mem::MemBuilder`] with
64    /// `buckets_per_shard` buckets. `key_stride` and `val_stride` control the
65    /// maximum serialized size of keys and values respectively.
66    ///
67    /// # Errors
68    ///
69    /// Returns [`FabricError::CapacityExceeded`] if the host cannot provide
70    /// enough memory for the requested shards.
71    pub fn new(
72        initial_shards: usize,
73        buckets_per_shard: usize,
74        key_stride: usize,
75        val_stride: usize,
76    ) -> Result<Self> {
77        if initial_shards == 0 {
78            return Err(FabricError::CapacityExceeded);
79        }
80        let mut shards = Vec::with_capacity(initial_shards);
81        for _ in 0..initial_shards {
82            let map = FabricHashMap::with_capacity(buckets_per_shard, key_stride, val_stride)?;
83            shards.push(map);
84        }
85        Ok(ElasticShardSet {
86            shards,
87            buckets_per_shard,
88            key_stride,
89            val_stride,
90        })
91    }
92
93    /// Look up a value by key.
94    ///
95    /// Hashes the key to select the correct shard, then performs a lookup
96    /// within that shard.
97    pub fn get(&self, key: &K) -> Result<Option<V>> {
98        let idx = self.shard_index(key)?;
99        self.shards[idx].get(key)
100    }
101
102    /// Insert or update a key-value pair.
103    ///
104    /// Hashes the key to select the correct shard, then inserts into that
105    /// shard.
106    pub fn put(&mut self, key: &K, value: &V) -> Result<()> {
107        let idx = self.shard_index(key)?;
108        self.shards[idx].insert(key, value)?;
109        Ok(())
110    }
111
112    /// Remove a key-value pair. Returns `true` if the key existed.
113    pub fn remove(&mut self, key: &K) -> Result<bool> {
114        let idx = self.shard_index(key)?;
115        Ok(self.shards[idx].remove(key)?.is_some())
116    }
117
118    /// Add one shard to the set.
119    ///
120    /// The new shard is empty. Call [`rehash`](Self::rehash) afterwards to
121    /// redistribute existing entries across all shards.
122    pub fn add_shard(&mut self) -> Result<()> {
123        let map =
124            FabricHashMap::with_capacity(self.buckets_per_shard, self.key_stride, self.val_stride)?;
125        self.shards.push(map);
126        Ok(())
127    }
128
129    /// Remove the last shard from the set.
130    ///
131    /// All entries across every shard are drained and redistributed among
132    /// the remaining shards (since `hash % shard_count` changes for all
133    /// keys when the shard count changes). The removed shard's lease is
134    /// released on drop.
135    ///
136    /// # Errors
137    ///
138    /// Returns [`FabricError::CapacityExceeded`] if the set has only one
139    /// shard (cannot shrink below 1).
140    pub fn remove_shard(&mut self) -> Result<()> {
141        if self.shards.len() <= 1 {
142            return Err(FabricError::CapacityExceeded);
143        }
144        // Drain all entries from every shard (routing changes for all keys)
145        let mut all_entries: Vec<(K, V)> = Vec::new();
146        for shard in &self.shards {
147            for entry in shard.iter() {
148                all_entries.push(entry?);
149            }
150        }
151        // Drop the last shard
152        self.shards.pop();
153        // Replace remaining shards with fresh empty ones
154        let new_count = self.shards.len();
155        self.shards.clear();
156        for _ in 0..new_count {
157            let map = FabricHashMap::with_capacity(
158                self.buckets_per_shard,
159                self.key_stride,
160                self.val_stride,
161            )?;
162            self.shards.push(map);
163        }
164        // Re-insert all entries with the new shard count
165        for (k, v) in &all_entries {
166            let idx = self.shard_index(k)?;
167            self.shards[idx].insert(k, v)?;
168        }
169        Ok(())
170    }
171
172    /// Redistribute all entries across the current shard set.
173    ///
174    /// Drains every shard, then re-inserts all entries using the current
175    /// shard count. This is useful after [`add_shard`](Self::add_shard) to
176    /// balance load.
177    pub fn rehash(&mut self) -> Result<()> {
178        // Collect all entries from all shards
179        let mut all_entries: Vec<(K, V)> = Vec::new();
180        for shard in &self.shards {
181            for entry in shard.iter() {
182                all_entries.push(entry?);
183            }
184        }
185
186        // Replace all shards with fresh empty ones
187        let shard_count = self.shards.len();
188        self.shards.clear();
189        for _ in 0..shard_count {
190            let map = FabricHashMap::with_capacity(
191                self.buckets_per_shard,
192                self.key_stride,
193                self.val_stride,
194            )?;
195            self.shards.push(map);
196        }
197
198        // Re-insert all entries
199        for (k, v) in &all_entries {
200            let idx = self.shard_index(k)?;
201            self.shards[idx].insert(k, v)?;
202        }
203        Ok(())
204    }
205
206    /// Number of shards in the set.
207    pub fn shard_count(&self) -> usize {
208        self.shards.len()
209    }
210
211    /// Total number of entries across all shards.
212    pub fn len(&self) -> Result<usize> {
213        let mut total = 0;
214        for shard in &self.shards {
215            total += shard.len();
216        }
217        Ok(total)
218    }
219
220    /// Returns `true` if no entries exist across all shards.
221    pub fn is_empty(&self) -> Result<bool> {
222        Ok(self.len()? == 0)
223    }
224
225    /// Compute the shard index for a given key using FNV-1a.
226    fn shard_index(&self, key: &K) -> Result<usize> {
227        let key_bytes = postcard::to_allocvec(key).map_err(|_| FabricError::IoError(-1))?;
228        let hash = fnv1a(&key_bytes);
229        Ok((hash as usize) % self.shards.len())
230    }
231}
232
/// 64-bit FNV-1a hash of `data`.
///
/// Shard routing relies on this matching [`FabricHashMap`]'s internal hash:
/// for each byte, XOR it into the state, then multiply (wrapping) by the
/// FNV prime.
fn fnv1a(data: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;
    data.iter()
        .fold(OFFSET_BASIS, |state, &byte| (state ^ u64::from(byte)).wrapping_mul(PRIME))
}
242
#[cfg(test)]
mod tests {
    use super::*;
    use grafos_std::host;

    /// Reset the mock host and give it a fresh FBMU arena for each test.
    fn setup() {
        host::reset_mock();
        host::mock_set_fbmu_arena_size(65536);
    }

    #[test]
    fn create_put_get_remove() {
        setup();
        let mut shards: ElasticShardSet<u32, u32> = ElasticShardSet::new(2, 64, 8, 8).expect("new");

        assert!(shards.is_empty().unwrap());
        shards.put(&1, &100).expect("put 1");
        shards.put(&2, &200).expect("put 2");
        shards.put(&3, &300).expect("put 3");

        assert_eq!(shards.len().unwrap(), 3);
        assert_eq!(shards.get(&1).unwrap(), Some(100));
        assert_eq!(shards.get(&2).unwrap(), Some(200));
        assert_eq!(shards.get(&3).unwrap(), Some(300));
        assert_eq!(shards.get(&99).unwrap(), None);

        assert!(shards.remove(&2).unwrap());
        assert!(!shards.remove(&99).unwrap());
        assert_eq!(shards.len().unwrap(), 2);
        assert_eq!(shards.get(&2).unwrap(), None);
    }

    #[test]
    fn correct_shard_routing() {
        setup();
        let mut shards: ElasticShardSet<u32, u32> = ElasticShardSet::new(4, 64, 8, 8).expect("new");

        // Insert many keys and verify they all round-trip
        for i in 0..20u32 {
            shards.put(&i, &(i * 10)).expect("put");
        }
        assert_eq!(shards.len().unwrap(), 20);

        for i in 0..20u32 {
            assert_eq!(shards.get(&i).unwrap(), Some(i * 10), "key {i}");
        }
    }

    #[test]
    fn add_shard_entries_still_accessible() {
        setup();
        let mut shards: ElasticShardSet<u32, u32> = ElasticShardSet::new(2, 64, 8, 8).expect("new");

        shards.put(&10, &1000).expect("put");
        shards.put(&20, &2000).expect("put");
        shards.put(&30, &3000).expect("put");

        shards.add_shard().expect("add_shard");
        assert_eq!(shards.shard_count(), 3);

        // Before rehash, entries are still in their old shards
        // After rehash, they should be redistributed but still accessible
        shards.rehash().expect("rehash");
        assert_eq!(shards.len().unwrap(), 3);
        assert_eq!(shards.get(&10).unwrap(), Some(1000));
        assert_eq!(shards.get(&20).unwrap(), Some(2000));
        assert_eq!(shards.get(&30).unwrap(), Some(3000));
    }

    #[test]
    fn remove_shard_entries_migrated() {
        setup();
        let mut shards: ElasticShardSet<u32, u32> = ElasticShardSet::new(3, 64, 8, 8).expect("new");

        for i in 0..10u32 {
            shards.put(&i, &(i * 100)).expect("put");
        }
        assert_eq!(shards.shard_count(), 3);
        assert_eq!(shards.len().unwrap(), 10);

        shards.remove_shard().expect("remove_shard");
        assert_eq!(shards.shard_count(), 2);
        assert_eq!(shards.len().unwrap(), 10);

        for i in 0..10u32 {
            assert_eq!(shards.get(&i).unwrap(), Some(i * 100), "key {i}");
        }
    }

    #[test]
    fn remove_shard_below_one_fails() {
        setup();
        let mut shards: ElasticShardSet<u32, u32> = ElasticShardSet::new(1, 64, 8, 8).expect("new");

        assert_eq!(
            shards.remove_shard().unwrap_err(),
            FabricError::CapacityExceeded
        );
    }

    #[test]
    fn shrink_to_single_shard_keeps_entries() {
        setup();
        let mut shards: ElasticShardSet<u32, u32> = ElasticShardSet::new(3, 64, 8, 8).expect("new");

        for i in 0..10u32 {
            shards.put(&i, &(i + 7)).expect("put");
        }

        shards.remove_shard().expect("remove_shard 3 -> 2");
        shards.remove_shard().expect("remove_shard 2 -> 1");
        assert_eq!(shards.shard_count(), 1);
        assert_eq!(shards.len().unwrap(), 10);
        for i in 0..10u32 {
            assert_eq!(shards.get(&i).unwrap(), Some(i + 7), "key {i}");
        }

        // The floor is enforced and leaves the set untouched.
        assert_eq!(
            shards.remove_shard().unwrap_err(),
            FabricError::CapacityExceeded
        );
        assert_eq!(shards.shard_count(), 1);
        assert_eq!(shards.len().unwrap(), 10);
    }

    #[test]
    fn rehash_redistributes() {
        setup();
        let mut shards: ElasticShardSet<u32, u32> = ElasticShardSet::new(2, 64, 8, 8).expect("new");

        for i in 0..8u32 {
            shards.put(&i, &(i + 1)).expect("put");
        }

        // Add two more shards and rehash
        shards.add_shard().expect("add");
        shards.add_shard().expect("add");
        assert_eq!(shards.shard_count(), 4);

        shards.rehash().expect("rehash");
        assert_eq!(shards.len().unwrap(), 8);

        // All entries still accessible
        for i in 0..8u32 {
            assert_eq!(shards.get(&i).unwrap(), Some(i + 1), "key {i}");
        }
    }

    #[test]
    fn zero_initial_shards_fails() {
        setup();
        let result: Result<ElasticShardSet<u32, u32>> = ElasticShardSet::new(0, 64, 8, 8);
        match result {
            Err(FabricError::CapacityExceeded) => {}
            Err(e) => panic!("expected CapacityExceeded, got {e:?}"),
            Ok(_) => panic!("expected error, got Ok"),
        }
    }

    #[test]
    fn fnv1a_matches_reference_vectors() {
        // Published 64-bit FNV-1a test vectors; routing depends on this
        // hash agreeing with FabricHashMap's.
        assert_eq!(fnv1a(b""), 0xcbf2_9ce4_8422_2325);
        assert_eq!(fnv1a(b"a"), 0xaf63_dc4c_8601_ec8c);
        assert_eq!(fnv1a(b"ab"), 0x089c_4407_b545_986a);
        assert_eq!(fnv1a(b"abc"), 0xe71f_a219_0541_574b);
    }
}