1extern crate alloc;
13use alloc::vec::Vec;
14
15use grafos_collections::map::FabricHashMap;
16use grafos_std::error::{FabricError, Result};
17
18use serde::{de::DeserializeOwned, Serialize};
19
20pub struct ElasticShardSet<K, V> {
49 shards: Vec<FabricHashMap<K, V>>,
50 buckets_per_shard: usize,
51 key_stride: usize,
52 val_stride: usize,
53}
54
55impl<K, V> ElasticShardSet<K, V>
56where
57 K: Serialize + DeserializeOwned + PartialEq + Clone,
58 V: Serialize + DeserializeOwned + Clone,
59{
60 pub fn new(
72 initial_shards: usize,
73 buckets_per_shard: usize,
74 key_stride: usize,
75 val_stride: usize,
76 ) -> Result<Self> {
77 if initial_shards == 0 {
78 return Err(FabricError::CapacityExceeded);
79 }
80 let mut shards = Vec::with_capacity(initial_shards);
81 for _ in 0..initial_shards {
82 let map = FabricHashMap::with_capacity(buckets_per_shard, key_stride, val_stride)?;
83 shards.push(map);
84 }
85 Ok(ElasticShardSet {
86 shards,
87 buckets_per_shard,
88 key_stride,
89 val_stride,
90 })
91 }
92
93 pub fn get(&self, key: &K) -> Result<Option<V>> {
98 let idx = self.shard_index(key)?;
99 self.shards[idx].get(key)
100 }
101
102 pub fn put(&mut self, key: &K, value: &V) -> Result<()> {
107 let idx = self.shard_index(key)?;
108 self.shards[idx].insert(key, value)?;
109 Ok(())
110 }
111
112 pub fn remove(&mut self, key: &K) -> Result<bool> {
114 let idx = self.shard_index(key)?;
115 Ok(self.shards[idx].remove(key)?.is_some())
116 }
117
118 pub fn add_shard(&mut self) -> Result<()> {
123 let map =
124 FabricHashMap::with_capacity(self.buckets_per_shard, self.key_stride, self.val_stride)?;
125 self.shards.push(map);
126 Ok(())
127 }
128
129 pub fn remove_shard(&mut self) -> Result<()> {
141 if self.shards.len() <= 1 {
142 return Err(FabricError::CapacityExceeded);
143 }
144 let mut all_entries: Vec<(K, V)> = Vec::new();
146 for shard in &self.shards {
147 for entry in shard.iter() {
148 all_entries.push(entry?);
149 }
150 }
151 self.shards.pop();
153 let new_count = self.shards.len();
155 self.shards.clear();
156 for _ in 0..new_count {
157 let map = FabricHashMap::with_capacity(
158 self.buckets_per_shard,
159 self.key_stride,
160 self.val_stride,
161 )?;
162 self.shards.push(map);
163 }
164 for (k, v) in &all_entries {
166 let idx = self.shard_index(k)?;
167 self.shards[idx].insert(k, v)?;
168 }
169 Ok(())
170 }
171
172 pub fn rehash(&mut self) -> Result<()> {
178 let mut all_entries: Vec<(K, V)> = Vec::new();
180 for shard in &self.shards {
181 for entry in shard.iter() {
182 all_entries.push(entry?);
183 }
184 }
185
186 let shard_count = self.shards.len();
188 self.shards.clear();
189 for _ in 0..shard_count {
190 let map = FabricHashMap::with_capacity(
191 self.buckets_per_shard,
192 self.key_stride,
193 self.val_stride,
194 )?;
195 self.shards.push(map);
196 }
197
198 for (k, v) in &all_entries {
200 let idx = self.shard_index(k)?;
201 self.shards[idx].insert(k, v)?;
202 }
203 Ok(())
204 }
205
206 pub fn shard_count(&self) -> usize {
208 self.shards.len()
209 }
210
211 pub fn len(&self) -> Result<usize> {
213 let mut total = 0;
214 for shard in &self.shards {
215 total += shard.len();
216 }
217 Ok(total)
218 }
219
220 pub fn is_empty(&self) -> Result<bool> {
222 Ok(self.len()? == 0)
223 }
224
225 fn shard_index(&self, key: &K) -> Result<usize> {
227 let key_bytes = postcard::to_allocvec(key).map_err(|_| FabricError::IoError(-1))?;
228 let hash = fnv1a(&key_bytes);
229 Ok((hash as usize) % self.shards.len())
230 }
231}
232
/// 64-bit FNV-1a hash of `data`.
///
/// Shard routing uses it because it is fast and deterministic. FNV-1a is not
/// collision-resistant against adversarially chosen input.
fn fnv1a(data: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;
    // FNV-1a: xor the byte in first, then multiply (mod 2^64).
    data.iter()
        .fold(OFFSET_BASIS, |state, &byte| (state ^ u64::from(byte)).wrapping_mul(PRIME))
}
242
#[cfg(test)]
mod tests {
    use super::*;
    use grafos_std::host;

    /// Resets the host mock and gives it a fresh fabric arena for each test.
    fn setup() {
        host::reset_mock();
        host::mock_set_fbmu_arena_size(65536);
    }

    #[test]
    fn create_put_get_remove() {
        setup();
        let mut shards: ElasticShardSet<u32, u32> = ElasticShardSet::new(2, 64, 8, 8).expect("new");

        assert!(shards.is_empty().unwrap());
        shards.put(&1, &100).expect("put 1");
        shards.put(&2, &200).expect("put 2");
        shards.put(&3, &300).expect("put 3");

        assert_eq!(shards.len().unwrap(), 3);
        assert_eq!(shards.get(&1).unwrap(), Some(100));
        assert_eq!(shards.get(&2).unwrap(), Some(200));
        assert_eq!(shards.get(&3).unwrap(), Some(300));
        assert_eq!(shards.get(&99).unwrap(), None);

        assert!(shards.remove(&2).unwrap());
        assert!(!shards.remove(&99).unwrap());
        assert_eq!(shards.len().unwrap(), 2);
        assert_eq!(shards.get(&2).unwrap(), None);
    }

    #[test]
    fn correct_shard_routing() {
        setup();
        let mut shards: ElasticShardSet<u32, u32> = ElasticShardSet::new(4, 64, 8, 8).expect("new");

        for i in 0..20u32 {
            shards.put(&i, &(i * 10)).expect("put");
        }
        assert_eq!(shards.len().unwrap(), 20);

        for i in 0..20u32 {
            assert_eq!(shards.get(&i).unwrap(), Some(i * 10), "key {i}");
        }
    }

    #[test]
    fn add_shard_entries_still_accessible() {
        setup();
        let mut shards: ElasticShardSet<u32, u32> = ElasticShardSet::new(2, 64, 8, 8).expect("new");

        shards.put(&10, &1000).expect("put");
        shards.put(&20, &2000).expect("put");
        shards.put(&30, &3000).expect("put");

        shards.add_shard().expect("add_shard");
        assert_eq!(shards.shard_count(), 3);

        // Entries must be redistributed before lookups under the new shard count.
        shards.rehash().expect("rehash");
        assert_eq!(shards.len().unwrap(), 3);
        assert_eq!(shards.get(&10).unwrap(), Some(1000));
        assert_eq!(shards.get(&20).unwrap(), Some(2000));
        assert_eq!(shards.get(&30).unwrap(), Some(3000));
    }

    #[test]
    fn remove_shard_entries_migrated() {
        setup();
        let mut shards: ElasticShardSet<u32, u32> = ElasticShardSet::new(3, 64, 8, 8).expect("new");

        for i in 0..10u32 {
            shards.put(&i, &(i * 100)).expect("put");
        }
        assert_eq!(shards.shard_count(), 3);
        assert_eq!(shards.len().unwrap(), 10);

        shards.remove_shard().expect("remove_shard");
        assert_eq!(shards.shard_count(), 2);
        assert_eq!(shards.len().unwrap(), 10);

        for i in 0..10u32 {
            assert_eq!(shards.get(&i).unwrap(), Some(i * 100), "key {i}");
        }
    }

    #[test]
    fn remove_shard_below_one_fails() {
        setup();
        let mut shards: ElasticShardSet<u32, u32> = ElasticShardSet::new(1, 64, 8, 8).expect("new");

        assert_eq!(
            shards.remove_shard().unwrap_err(),
            FabricError::CapacityExceeded
        );
    }

    #[test]
    fn remove_shard_down_to_one_keeps_entries() {
        setup();
        let mut shards: ElasticShardSet<u32, u32> = ElasticShardSet::new(3, 64, 8, 8).expect("new");

        for i in 0..12u32 {
            shards.put(&i, &(i + 7)).expect("put");
        }

        // Shrink 3 -> 2 -> 1, checking that nothing is lost at each step.
        for expected in [2usize, 1] {
            shards.remove_shard().expect("remove_shard");
            assert_eq!(shards.shard_count(), expected);
            assert_eq!(shards.len().unwrap(), 12);
        }

        for i in 0..12u32 {
            assert_eq!(shards.get(&i).unwrap(), Some(i + 7), "key {i}");
        }
        assert_eq!(
            shards.remove_shard().unwrap_err(),
            FabricError::CapacityExceeded
        );
        assert_eq!(shards.shard_count(), 1);
    }

    #[test]
    fn rehash_same_shard_count_is_lossless() {
        setup();
        let mut shards: ElasticShardSet<u32, u32> = ElasticShardSet::new(3, 64, 8, 8).expect("new");

        for i in 0..9u32 {
            shards.put(&i, &(i * 3)).expect("put");
        }

        shards.rehash().expect("rehash");
        assert_eq!(shards.shard_count(), 3);
        assert_eq!(shards.len().unwrap(), 9);

        for i in 0..9u32 {
            assert_eq!(shards.get(&i).unwrap(), Some(i * 3), "key {i}");
        }
    }

    #[test]
    fn rehash_redistributes() {
        setup();
        let mut shards: ElasticShardSet<u32, u32> = ElasticShardSet::new(2, 64, 8, 8).expect("new");

        for i in 0..8u32 {
            shards.put(&i, &(i + 1)).expect("put");
        }

        shards.add_shard().expect("add");
        shards.add_shard().expect("add");
        assert_eq!(shards.shard_count(), 4);

        shards.rehash().expect("rehash");
        assert_eq!(shards.len().unwrap(), 8);

        for i in 0..8u32 {
            assert_eq!(shards.get(&i).unwrap(), Some(i + 1), "key {i}");
        }
    }

    #[test]
    fn zero_initial_shards_fails() {
        setup();
        let result: Result<ElasticShardSet<u32, u32>> = ElasticShardSet::new(0, 64, 8, 8);
        match result {
            Err(FabricError::CapacityExceeded) => {}
            Err(e) => panic!("expected CapacityExceeded, got {e:?}"),
            Ok(_) => panic!("expected error, got Ok"),
        }
    }
}