cache/
multi.rs

1//! A cache with multiple providers.
2
3use std::{
4    any::Any,
5    sync::{
6        Arc,
7        mpsc::{Receiver, Sender, channel},
8    },
9};
10
11use crate::{
12    CacheHandle, CacheHandleInner, CacheValueHolder, Cacheable, CacheableWithState, GenerateFn,
13    GenerateResultFn, GenerateResultWithStateFn, GenerateWithStateFn, Namespace, error::ArcResult,
14    mem::NamespaceCache, persistent::client::Client, run_generator,
15};
16
17use serde::{Serialize, de::DeserializeOwned};
18
19/// A cache with multiple providers.
20///
21/// Exposes a unified API for accessing an in-memory [`NamespaceCache`] as well as persistent
22/// cache [`Client`]s.
23#[derive(Default, Debug, Clone)]
24pub struct MultiCache {
25    namespace_cache: Option<NamespaceCache>,
26    clients: Vec<Client>,
27}
28
29/// A builder for a [`MultiCache`].
30#[derive(Default, Debug, Clone)]
31pub struct MultiCacheBuilder {
32    skip_memory: bool,
33    clients: Vec<Client>,
34}
35
36type OptionGenerateHandle<V> = GenerateHandle<V, Option<V>>;
37
38struct GenerateHandle<V, R> {
39    has_value_r: Receiver<Option<CacheHandleInner<V>>>,
40    value_s: Sender<R>,
41    handle: CacheHandleInner<V>,
42}
43
44/// A generate function dispatched to cache provider `C` in order to retrieve a cache handle to a
45/// value that the cache may or may not have, sent over the provided [`Sender`].
46///
47/// The receiver can then be used to recover value that the [`MultiCache`] gets, potentially from
48/// other caches.
49trait MultiGenerateFn<C, K, V, R>:
50    Fn(
51    &mut C,
52    Namespace,
53    Arc<K>,
54    Sender<Option<CacheHandleInner<V>>>,
55    Receiver<R>,
56) -> CacheHandleInner<V>
57{
58}
59impl<
60    C,
61    K,
62    V,
63    R,
64    T: Fn(
65        &mut C,
66        Namespace,
67        Arc<K>,
68        Sender<Option<CacheHandleInner<V>>>,
69        Receiver<R>,
70    ) -> CacheHandleInner<V>,
71> MultiGenerateFn<C, K, V, R> for T
72{
73}
74
75impl MultiCacheBuilder {
76    /// Creates a new [`MultiCacheBuilder`].
77    pub fn new() -> Self {
78        Self::default()
79    }
80
81    /// Skips caching results in memory.
82    ///
83    /// With this flag enabled, all cache accesses must go through a cache provider even if key in
84    /// question was accessed earlier by the same process.
85    pub fn skip_memory(&mut self) -> &mut Self {
86        self.skip_memory = true;
87        self
88    }
89
90    /// Adds a new provider to the cache.
91    pub fn with_provider(&mut self, client: Client) -> &mut Self {
92        self.clients.push(client);
93        self
94    }
95
96    /// Builds a [`MultiCache`] from the configured parameters.
97    pub fn build(&mut self) -> MultiCache {
98        MultiCache {
99            namespace_cache: if self.skip_memory {
100                None
101            } else {
102                Some(NamespaceCache::new())
103            },
104            clients: self.clients.clone(),
105        }
106    }
107}
108
109impl MultiCache {
110    /// Creates a [`MultiCache`] with only in-memory providers.
111    pub fn new() -> Self {
112        Self::default()
113    }
114
115    /// Creates a new [`MultiCacheBuilder`].
116    pub fn builder() -> MultiCacheBuilder {
117        MultiCacheBuilder::new()
118    }
119
120    /// Ensures that a value corresponding to `key` is generated, using `generate_fn`
121    /// to generate it if it has not already been generated.
122    ///
123    /// Returns a handle to the value. If the value is not yet generated, it is generated
124    /// in the background.
125    ///
126    /// See [`Client::generate`] and [`NamespaceCache::generate`] for related examples.
127    pub fn generate<
128        K: Serialize + Any + Send + Sync,
129        V: Serialize + DeserializeOwned + Send + Sync + Any,
130    >(
131        &mut self,
132        namespace: impl Into<Namespace>,
133        key: K,
134        generate_fn: impl GenerateFn<K, V>,
135    ) -> CacheHandle<V> {
136        CacheHandle::from_inner(Arc::new(self.generate_inner(namespace, key, generate_fn)))
137    }
138
139    /// Ensures that a value corresponding to `key` is generated, using `generate_fn`
140    /// to generate it if it has not already been generated.
141    ///
142    /// Returns a handle to the value. If the value is not yet generated, it is generated
143    /// in the background.
144    ///
145    /// See [`Client::generate_with_state`] and [`NamespaceCache::generate_with_state`] for related examples.
146    pub fn generate_with_state<
147        K: Serialize + Send + Sync + Any,
148        S: Send + Sync + Any,
149        V: Serialize + DeserializeOwned + Send + Sync + Any,
150    >(
151        &mut self,
152        namespace: impl Into<Namespace>,
153        key: K,
154        state: S,
155        generate_fn: impl GenerateWithStateFn<K, S, V>,
156    ) -> CacheHandle<V> {
157        let namespace = namespace.into();
158        self.generate(namespace, key, move |k| generate_fn(k, state))
159    }
160
161    /// Ensures that a result corresponding to `key` is generated, using `generate_fn`
162    /// to generate it if it has not already been generated.
163    ///
164    /// Does not cache on failure as errors are not constrained to be serializable/deserializable.
165    /// As such, failures should happen quickly, or should be serializable and stored as part of
166    /// cached value using [`MultiCache::generate`].
167    ///
168    /// Returns a handle to the value. If the value is not yet generated, it is generated
169    ///
170    /// See [`Client::generate_result`] and [`NamespaceCache::generate_result`] for related examples.
171    pub fn generate_result<
172        K: Serialize + Any + Send + Sync,
173        V: Serialize + DeserializeOwned + Send + Sync + Any,
174        E: Send + Sync + Any,
175    >(
176        &mut self,
177        namespace: impl Into<Namespace>,
178        key: K,
179        generate_fn: impl GenerateResultFn<K, V, E>,
180    ) -> CacheHandle<Result<V, E>> {
181        CacheHandle::from_inner(Arc::new(self.generate_result_inner(
182            namespace,
183            key,
184            generate_fn,
185        )))
186    }
187
188    /// Ensures that a value corresponding to `key` is generated, using `generate_fn`
189    /// to generate it if it has not already been generated.
190    ///
191    /// Does not cache on failure as errors are not constrained to be serializable/deserializable.
192    /// As such, failures should happen quickly, or should be serializable and stored as part of
193    /// cached value using [`MultiCache::generate_with_state`].
194    ///
195    /// Returns a handle to the value. If the value is not yet generated, it is generated
196    /// in the background.
197    ///
198    /// See [`Client::generate_result_with_state`] and
199    /// [`NamespaceCache::generate_result_with_state`] for related examples.
200    pub fn generate_result_with_state<
201        K: Serialize + Send + Sync + Any,
202        S: Send + Sync + Any,
203        V: Serialize + DeserializeOwned + Send + Sync + Any,
204        E: Send + Sync + Any,
205    >(
206        &mut self,
207        namespace: impl Into<Namespace>,
208        key: K,
209        state: S,
210        generate_fn: impl GenerateResultWithStateFn<K, S, V, E>,
211    ) -> CacheHandle<Result<V, E>> {
212        let namespace = namespace.into();
213        self.generate_result(namespace, key, move |k| generate_fn(k, state))
214    }
215
216    /// Gets a handle to a cacheable object from the cache, generating the object in the background
217    /// if needed.
218    ///
219    /// Does not cache errors, so any errors thrown should be thrown quickly. Any errors that need
220    /// to be cached should be included in the cached output or should be cached using
221    /// [`MultiCache::get_with_err`].
222    ///
223    /// See [`Client::get`] and [`NamespaceCache::get`] for related examples.
224    pub fn get<K: Cacheable>(
225        &mut self,
226        namespace: impl Into<Namespace>,
227        key: K,
228    ) -> CacheHandle<Result<K::Output, K::Error>> {
229        let namespace = namespace.into();
230        self.generate_result(namespace, key, |key| key.generate())
231    }
232
233    /// Gets a handle to a cacheable object from the cache, caching failures as well.
234    ///
235    /// Generates the object in the background if needed.
236    ///
237    /// See [`Client::get_with_err`] and [`NamespaceCache::get_with_err`] for related examples.
238    pub fn get_with_err<
239        E: Send + Sync + Serialize + DeserializeOwned + Any,
240        K: Cacheable<Error = E>,
241    >(
242        &mut self,
243        namespace: impl Into<Namespace>,
244        key: K,
245    ) -> CacheHandle<Result<K::Output, K::Error>> {
246        let namespace = namespace.into();
247        self.generate(namespace, key, |key| key.generate())
248    }
249
250    /// Gets a handle to a cacheable object from the cache, generating the object in the background
251    /// if needed.
252    ///
253    /// Does not cache errors, so any errors thrown should be thrown quickly. Any errors that need
254    /// to be cached should be included in the cached output or should be cached using
255    /// [`MultiCache::get_with_state_and_err`].
256    ///
257    /// See [`Client::get_with_state`] and [`NamespaceCache::get_with_state`] for related examples.
258    pub fn get_with_state<S: Send + Sync + Any, K: CacheableWithState<S>>(
259        &mut self,
260        namespace: impl Into<Namespace>,
261        key: K,
262        state: S,
263    ) -> CacheHandle<Result<K::Output, K::Error>> {
264        let namespace = namespace.into();
265        self.generate_result_with_state(namespace, key, state, |key, state| {
266            key.generate_with_state(state)
267        })
268    }
269
270    /// Gets a handle to a cacheable object from the cache, caching failures as well.
271    ///
272    /// Generates the object in the background if needed.
273    ///
274    /// See [`MultiCache::get_with_err`] and [`MultiCache::get_with_state`] for related examples.
275    pub fn get_with_state_and_err<
276        S: Send + Sync + Any,
277        E: Send + Sync + Serialize + DeserializeOwned + Any,
278        K: CacheableWithState<S, Error = E>,
279    >(
280        &mut self,
281        namespace: impl Into<Namespace>,
282        key: K,
283        state: S,
284    ) -> CacheHandle<Result<K::Output, K::Error>> {
285        let namespace = namespace.into();
286        self.generate_with_state(namespace, key, state, |key, state| {
287            key.generate_with_state(state)
288        })
289    }
290
291    /// Dispatches the provided generate_fn to a cache provider, attempting to recover the cached value in
292    /// the background.
293    fn start_generate<C, K, V: Send + Sync + Any, R>(
294        cache: &mut C,
295        namespace: Namespace,
296        key: Arc<K>,
297        generate_fn: impl MultiGenerateFn<C, K, V, R>,
298    ) -> GenerateHandle<V, R> {
299        let (has_value_s, has_value_r) = channel();
300        let (value_s, value_r) = channel();
301
302        let handle = generate_fn(cache, namespace, key, has_value_s.clone(), value_r);
303
304        let handle_clone = handle.clone();
305        std::thread::spawn(move || {
306            let _ = handle_clone.try_get();
307            let _ = has_value_s.send(Some(handle_clone));
308        });
309
310        GenerateHandle {
311            has_value_r,
312            value_s,
313            handle,
314        }
315    }
316
317    fn generate_inner<
318        K: Serialize + Any + Send + Sync,
319        V: Serialize + DeserializeOwned + Send + Sync + Any,
320    >(
321        &mut self,
322        namespace: impl Into<Namespace>,
323        key: K,
324        generate_fn: impl GenerateFn<K, V>,
325    ) -> CacheHandleInner<V> {
326        let namespace = namespace.into();
327        self.generate_inner_dispatch(
328            namespace,
329            key,
330            generate_fn,
331            |cache, namespace, key, has_value_s, value_r| {
332                cache.generate_inner(namespace, key, move |_| {
333                    let _ = has_value_s.send(None);
334                    value_r.recv().unwrap()
335                })
336            },
337            |cache, namespace, key, has_value_s, value_r| {
338                cache.generate_inner(namespace, key, move |_| {
339                    let _ = has_value_s.send(None);
340                    // Panics if no value is provided. Clients do not cache generator panics.
341                    value_r.recv().unwrap().unwrap()
342                })
343            },
344            MultiCache::recover_value,
345            MultiCache::send_value_to_providers,
346        )
347    }
348
349    fn generate_result_inner<
350        K: Serialize + Any + Send + Sync,
351        V: Serialize + DeserializeOwned + Send + Sync + Any,
352        E: Send + Sync + Any,
353    >(
354        &mut self,
355        namespace: impl Into<Namespace>,
356        key: K,
357        generate_fn: impl GenerateResultFn<K, V, E>,
358    ) -> CacheHandleInner<Result<V, E>> {
359        let namespace = namespace.into();
360        self.generate_inner_dispatch(
361            namespace,
362            key,
363            generate_fn,
364            |cache, namespace, key, has_value_s, value_r| {
365                cache.generate_result_inner(namespace, key, move |_| {
366                    let _ = has_value_s.send(None);
367                    value_r.recv().unwrap()
368                })
369            },
370            |cache, namespace, key, has_value_s, value_r| {
371                cache.generate_result_inner(namespace, key, move |_| {
372                    let _ = has_value_s.send(None);
373                    // Panics if no value is provided. Clients do not cache generator panics.
374                    value_r.recv().unwrap().unwrap()
375                })
376            },
377            MultiCache::recover_result,
378            MultiCache::send_result_to_providers,
379        )
380    }
381
382    #[allow(clippy::too_many_arguments)]
383    fn generate_inner_dispatch<K: Send + Sync + Any, V: Send + Sync + Any>(
384        &mut self,
385        namespace: Namespace,
386        key: K,
387        generate_fn: impl GenerateFn<K, V>,
388        namespace_generate: impl MultiGenerateFn<NamespaceCache, K, V, V>,
389        client_generate: impl MultiGenerateFn<Client, K, V, Option<V>>,
390        recover_value: impl FnOnce(ArcResult<&V>) -> Option<V> + Send + Any,
391        send_value_to_providers: impl Fn(&V, &mut [GenerateHandle<V, Option<V>>]) + Send + Any,
392    ) -> CacheHandleInner<V> {
393        let key = Arc::new(key);
394
395        let mut handle = CacheHandleInner::default();
396        let mut mem_handle = None;
397        let mut client_handles = Vec::new();
398
399        if let Some(cache) = &mut self.namespace_cache {
400            tracing::debug!("dispatching request to in-memory cache");
401            let (namespace, key) = (namespace.clone(), key.clone());
402            let generate_handle =
403                MultiCache::start_generate(cache, namespace, key, namespace_generate);
404            handle = generate_handle.handle.clone();
405            mem_handle = Some(generate_handle);
406        }
407
408        for (i, client) in self.clients.iter_mut().enumerate() {
409            tracing::debug!("dispatching request to persistent client {}", i);
410            let (namespace, key) = (namespace.clone(), key.clone());
411            client_handles.push(MultiCache::start_generate(
412                client,
413                namespace,
414                key,
415                &client_generate,
416            ));
417        }
418
419        let handle_clone = handle.clone();
420
421        tracing::debug!("spawning thread to aggregate results");
422        std::thread::spawn(move || {
423            let mut retrieved_value: Option<V> = None;
424            for (i, has_value_r) in mem_handle
425                .iter()
426                .map(|x| &x.has_value_r)
427                .chain(client_handles.iter().map(|x| &x.has_value_r))
428                .enumerate()
429            {
430                tracing::debug!("waiting on generate handle {}", i);
431                if let Some(value_handle) = has_value_r.recv().unwrap() {
432                    tracing::debug!("received value from generate handle {}", i);
433                    retrieved_value = recover_value(value_handle.try_get());
434                    break;
435                }
436                tracing::debug!(
437                    "did not receive value from generate handle {}, trying next handle",
438                    i
439                );
440            }
441
442            let value = retrieved_value.map(Ok).unwrap_or_else(|| {
443                tracing::debug!("did not receive a value, generating now");
444                run_generator(move || generate_fn(key.as_ref()))
445            });
446
447            if let Ok(value) = value.as_ref() {
448                tracing::debug!("sending generated value to all clients");
449                send_value_to_providers(value, &mut client_handles);
450            }
451
452            // Block until all clients have finished handling the received values.
453            for (i, GenerateHandle { handle, .. }) in client_handles.iter().enumerate() {
454                tracing::debug!("blocking on client {}", i);
455                let _ = handle.try_get();
456            }
457
458            match value {
459                Ok(value) => {
460                    if let Some(mem_handle) = mem_handle {
461                        let _ = mem_handle.value_s.send(value);
462                    } else {
463                        handle_clone.set(Ok(value));
464                    }
465                }
466                e @ Err(_) => handle_clone.set(e),
467            }
468        });
469
470        handle
471    }
472
473    fn recover_value<V: Serialize + DeserializeOwned>(
474        retrieved_result: ArcResult<&V>,
475    ) -> Option<V> {
476        if let Ok(value) = retrieved_result {
477            Some(flexbuffers::from_slice(&flexbuffers::to_vec(value).unwrap()).unwrap())
478        } else {
479            None
480        }
481    }
482
483    fn recover_result<V: Serialize + DeserializeOwned, E>(
484        retrieved_result: ArcResult<&Result<V, E>>,
485    ) -> Option<Result<V, E>> {
486        if let Ok(Ok(value)) = retrieved_result {
487            Some(Ok(flexbuffers::from_slice(
488                &flexbuffers::to_vec(value).unwrap(),
489            )
490            .unwrap()))
491        } else {
492            None
493        }
494    }
495
496    fn send_value_to_providers<V: Serialize + DeserializeOwned>(
497        value: &V,
498        client_handles: &mut [OptionGenerateHandle<V>],
499    ) {
500        for GenerateHandle { value_s, .. } in client_handles.iter_mut() {
501            let _ = value_s.send(Some(
502                flexbuffers::from_slice(&flexbuffers::to_vec(value).unwrap()).unwrap(),
503            ));
504        }
505    }
506
507    fn send_result_to_providers<V: Serialize + DeserializeOwned, E>(
508        value: &Result<V, E>,
509        client_handles: &mut [OptionGenerateHandle<Result<V, E>>],
510    ) {
511        for GenerateHandle { value_s, .. } in client_handles.iter_mut() {
512            if let Ok(value) = value {
513                let _ = value_s.send(Some(Ok(flexbuffers::from_slice(
514                    &flexbuffers::to_vec(value).unwrap(),
515                )
516                .unwrap())));
517            } else {
518                let _ = value_s.send(None);
519            }
520        }
521    }
522}