1use std::{
4 any::Any,
5 sync::{
6 Arc,
7 mpsc::{Receiver, Sender, channel},
8 },
9};
10
11use crate::{
12 CacheHandle, CacheHandleInner, CacheValueHolder, Cacheable, CacheableWithState, GenerateFn,
13 GenerateResultFn, GenerateResultWithStateFn, GenerateWithStateFn, Namespace, error::ArcResult,
14 mem::NamespaceCache, persistent::client::Client, run_generator,
15};
16
17use serde::{Serialize, de::DeserializeOwned};
18
19#[derive(Default, Debug, Clone)]
24pub struct MultiCache {
25 namespace_cache: Option<NamespaceCache>,
26 clients: Vec<Client>,
27}
28
29#[derive(Default, Debug, Clone)]
31pub struct MultiCacheBuilder {
32 skip_memory: bool,
33 clients: Vec<Client>,
34}
35
36type OptionGenerateHandle<V> = GenerateHandle<V, Option<V>>;
37
38struct GenerateHandle<V, R> {
39 has_value_r: Receiver<Option<CacheHandleInner<V>>>,
40 value_s: Sender<R>,
41 handle: CacheHandleInner<V>,
42}
43
44trait MultiGenerateFn<C, K, V, R>:
50 Fn(
51 &mut C,
52 Namespace,
53 Arc<K>,
54 Sender<Option<CacheHandleInner<V>>>,
55 Receiver<R>,
56) -> CacheHandleInner<V>
57{
58}
59impl<
60 C,
61 K,
62 V,
63 R,
64 T: Fn(
65 &mut C,
66 Namespace,
67 Arc<K>,
68 Sender<Option<CacheHandleInner<V>>>,
69 Receiver<R>,
70 ) -> CacheHandleInner<V>,
71> MultiGenerateFn<C, K, V, R> for T
72{
73}
74
75impl MultiCacheBuilder {
76 pub fn new() -> Self {
78 Self::default()
79 }
80
81 pub fn skip_memory(&mut self) -> &mut Self {
86 self.skip_memory = true;
87 self
88 }
89
90 pub fn with_provider(&mut self, client: Client) -> &mut Self {
92 self.clients.push(client);
93 self
94 }
95
96 pub fn build(&mut self) -> MultiCache {
98 MultiCache {
99 namespace_cache: if self.skip_memory {
100 None
101 } else {
102 Some(NamespaceCache::new())
103 },
104 clients: self.clients.clone(),
105 }
106 }
107}
108
109impl MultiCache {
110 pub fn new() -> Self {
112 Self::default()
113 }
114
115 pub fn builder() -> MultiCacheBuilder {
117 MultiCacheBuilder::new()
118 }
119
120 pub fn generate<
128 K: Serialize + Any + Send + Sync,
129 V: Serialize + DeserializeOwned + Send + Sync + Any,
130 >(
131 &mut self,
132 namespace: impl Into<Namespace>,
133 key: K,
134 generate_fn: impl GenerateFn<K, V>,
135 ) -> CacheHandle<V> {
136 CacheHandle::from_inner(Arc::new(self.generate_inner(namespace, key, generate_fn)))
137 }
138
139 pub fn generate_with_state<
147 K: Serialize + Send + Sync + Any,
148 S: Send + Sync + Any,
149 V: Serialize + DeserializeOwned + Send + Sync + Any,
150 >(
151 &mut self,
152 namespace: impl Into<Namespace>,
153 key: K,
154 state: S,
155 generate_fn: impl GenerateWithStateFn<K, S, V>,
156 ) -> CacheHandle<V> {
157 let namespace = namespace.into();
158 self.generate(namespace, key, move |k| generate_fn(k, state))
159 }
160
161 pub fn generate_result<
172 K: Serialize + Any + Send + Sync,
173 V: Serialize + DeserializeOwned + Send + Sync + Any,
174 E: Send + Sync + Any,
175 >(
176 &mut self,
177 namespace: impl Into<Namespace>,
178 key: K,
179 generate_fn: impl GenerateResultFn<K, V, E>,
180 ) -> CacheHandle<Result<V, E>> {
181 CacheHandle::from_inner(Arc::new(self.generate_result_inner(
182 namespace,
183 key,
184 generate_fn,
185 )))
186 }
187
188 pub fn generate_result_with_state<
201 K: Serialize + Send + Sync + Any,
202 S: Send + Sync + Any,
203 V: Serialize + DeserializeOwned + Send + Sync + Any,
204 E: Send + Sync + Any,
205 >(
206 &mut self,
207 namespace: impl Into<Namespace>,
208 key: K,
209 state: S,
210 generate_fn: impl GenerateResultWithStateFn<K, S, V, E>,
211 ) -> CacheHandle<Result<V, E>> {
212 let namespace = namespace.into();
213 self.generate_result(namespace, key, move |k| generate_fn(k, state))
214 }
215
216 pub fn get<K: Cacheable>(
225 &mut self,
226 namespace: impl Into<Namespace>,
227 key: K,
228 ) -> CacheHandle<Result<K::Output, K::Error>> {
229 let namespace = namespace.into();
230 self.generate_result(namespace, key, |key| key.generate())
231 }
232
233 pub fn get_with_err<
239 E: Send + Sync + Serialize + DeserializeOwned + Any,
240 K: Cacheable<Error = E>,
241 >(
242 &mut self,
243 namespace: impl Into<Namespace>,
244 key: K,
245 ) -> CacheHandle<Result<K::Output, K::Error>> {
246 let namespace = namespace.into();
247 self.generate(namespace, key, |key| key.generate())
248 }
249
250 pub fn get_with_state<S: Send + Sync + Any, K: CacheableWithState<S>>(
259 &mut self,
260 namespace: impl Into<Namespace>,
261 key: K,
262 state: S,
263 ) -> CacheHandle<Result<K::Output, K::Error>> {
264 let namespace = namespace.into();
265 self.generate_result_with_state(namespace, key, state, |key, state| {
266 key.generate_with_state(state)
267 })
268 }
269
270 pub fn get_with_state_and_err<
276 S: Send + Sync + Any,
277 E: Send + Sync + Serialize + DeserializeOwned + Any,
278 K: CacheableWithState<S, Error = E>,
279 >(
280 &mut self,
281 namespace: impl Into<Namespace>,
282 key: K,
283 state: S,
284 ) -> CacheHandle<Result<K::Output, K::Error>> {
285 let namespace = namespace.into();
286 self.generate_with_state(namespace, key, state, |key, state| {
287 key.generate_with_state(state)
288 })
289 }
290
291 fn start_generate<C, K, V: Send + Sync + Any, R>(
294 cache: &mut C,
295 namespace: Namespace,
296 key: Arc<K>,
297 generate_fn: impl MultiGenerateFn<C, K, V, R>,
298 ) -> GenerateHandle<V, R> {
299 let (has_value_s, has_value_r) = channel();
300 let (value_s, value_r) = channel();
301
302 let handle = generate_fn(cache, namespace, key, has_value_s.clone(), value_r);
303
304 let handle_clone = handle.clone();
305 std::thread::spawn(move || {
306 let _ = handle_clone.try_get();
307 let _ = has_value_s.send(Some(handle_clone));
308 });
309
310 GenerateHandle {
311 has_value_r,
312 value_s,
313 handle,
314 }
315 }
316
317 fn generate_inner<
318 K: Serialize + Any + Send + Sync,
319 V: Serialize + DeserializeOwned + Send + Sync + Any,
320 >(
321 &mut self,
322 namespace: impl Into<Namespace>,
323 key: K,
324 generate_fn: impl GenerateFn<K, V>,
325 ) -> CacheHandleInner<V> {
326 let namespace = namespace.into();
327 self.generate_inner_dispatch(
328 namespace,
329 key,
330 generate_fn,
331 |cache, namespace, key, has_value_s, value_r| {
332 cache.generate_inner(namespace, key, move |_| {
333 let _ = has_value_s.send(None);
334 value_r.recv().unwrap()
335 })
336 },
337 |cache, namespace, key, has_value_s, value_r| {
338 cache.generate_inner(namespace, key, move |_| {
339 let _ = has_value_s.send(None);
340 value_r.recv().unwrap().unwrap()
342 })
343 },
344 MultiCache::recover_value,
345 MultiCache::send_value_to_providers,
346 )
347 }
348
349 fn generate_result_inner<
350 K: Serialize + Any + Send + Sync,
351 V: Serialize + DeserializeOwned + Send + Sync + Any,
352 E: Send + Sync + Any,
353 >(
354 &mut self,
355 namespace: impl Into<Namespace>,
356 key: K,
357 generate_fn: impl GenerateResultFn<K, V, E>,
358 ) -> CacheHandleInner<Result<V, E>> {
359 let namespace = namespace.into();
360 self.generate_inner_dispatch(
361 namespace,
362 key,
363 generate_fn,
364 |cache, namespace, key, has_value_s, value_r| {
365 cache.generate_result_inner(namespace, key, move |_| {
366 let _ = has_value_s.send(None);
367 value_r.recv().unwrap()
368 })
369 },
370 |cache, namespace, key, has_value_s, value_r| {
371 cache.generate_result_inner(namespace, key, move |_| {
372 let _ = has_value_s.send(None);
373 value_r.recv().unwrap().unwrap()
375 })
376 },
377 MultiCache::recover_result,
378 MultiCache::send_result_to_providers,
379 )
380 }
381
382 #[allow(clippy::too_many_arguments)]
383 fn generate_inner_dispatch<K: Send + Sync + Any, V: Send + Sync + Any>(
384 &mut self,
385 namespace: Namespace,
386 key: K,
387 generate_fn: impl GenerateFn<K, V>,
388 namespace_generate: impl MultiGenerateFn<NamespaceCache, K, V, V>,
389 client_generate: impl MultiGenerateFn<Client, K, V, Option<V>>,
390 recover_value: impl FnOnce(ArcResult<&V>) -> Option<V> + Send + Any,
391 send_value_to_providers: impl Fn(&V, &mut [GenerateHandle<V, Option<V>>]) + Send + Any,
392 ) -> CacheHandleInner<V> {
393 let key = Arc::new(key);
394
395 let mut handle = CacheHandleInner::default();
396 let mut mem_handle = None;
397 let mut client_handles = Vec::new();
398
399 if let Some(cache) = &mut self.namespace_cache {
400 tracing::debug!("dispatching request to in-memory cache");
401 let (namespace, key) = (namespace.clone(), key.clone());
402 let generate_handle =
403 MultiCache::start_generate(cache, namespace, key, namespace_generate);
404 handle = generate_handle.handle.clone();
405 mem_handle = Some(generate_handle);
406 }
407
408 for (i, client) in self.clients.iter_mut().enumerate() {
409 tracing::debug!("dispatching request to persistent client {}", i);
410 let (namespace, key) = (namespace.clone(), key.clone());
411 client_handles.push(MultiCache::start_generate(
412 client,
413 namespace,
414 key,
415 &client_generate,
416 ));
417 }
418
419 let handle_clone = handle.clone();
420
421 tracing::debug!("spawning thread to aggregate results");
422 std::thread::spawn(move || {
423 let mut retrieved_value: Option<V> = None;
424 for (i, has_value_r) in mem_handle
425 .iter()
426 .map(|x| &x.has_value_r)
427 .chain(client_handles.iter().map(|x| &x.has_value_r))
428 .enumerate()
429 {
430 tracing::debug!("waiting on generate handle {}", i);
431 if let Some(value_handle) = has_value_r.recv().unwrap() {
432 tracing::debug!("received value from generate handle {}", i);
433 retrieved_value = recover_value(value_handle.try_get());
434 break;
435 }
436 tracing::debug!(
437 "did not receive value from generate handle {}, trying next handle",
438 i
439 );
440 }
441
442 let value = retrieved_value.map(Ok).unwrap_or_else(|| {
443 tracing::debug!("did not receive a value, generating now");
444 run_generator(move || generate_fn(key.as_ref()))
445 });
446
447 if let Ok(value) = value.as_ref() {
448 tracing::debug!("sending generated value to all clients");
449 send_value_to_providers(value, &mut client_handles);
450 }
451
452 for (i, GenerateHandle { handle, .. }) in client_handles.iter().enumerate() {
454 tracing::debug!("blocking on client {}", i);
455 let _ = handle.try_get();
456 }
457
458 match value {
459 Ok(value) => {
460 if let Some(mem_handle) = mem_handle {
461 let _ = mem_handle.value_s.send(value);
462 } else {
463 handle_clone.set(Ok(value));
464 }
465 }
466 e @ Err(_) => handle_clone.set(e),
467 }
468 });
469
470 handle
471 }
472
473 fn recover_value<V: Serialize + DeserializeOwned>(
474 retrieved_result: ArcResult<&V>,
475 ) -> Option<V> {
476 if let Ok(value) = retrieved_result {
477 Some(flexbuffers::from_slice(&flexbuffers::to_vec(value).unwrap()).unwrap())
478 } else {
479 None
480 }
481 }
482
483 fn recover_result<V: Serialize + DeserializeOwned, E>(
484 retrieved_result: ArcResult<&Result<V, E>>,
485 ) -> Option<Result<V, E>> {
486 if let Ok(Ok(value)) = retrieved_result {
487 Some(Ok(flexbuffers::from_slice(
488 &flexbuffers::to_vec(value).unwrap(),
489 )
490 .unwrap()))
491 } else {
492 None
493 }
494 }
495
496 fn send_value_to_providers<V: Serialize + DeserializeOwned>(
497 value: &V,
498 client_handles: &mut [OptionGenerateHandle<V>],
499 ) {
500 for GenerateHandle { value_s, .. } in client_handles.iter_mut() {
501 let _ = value_s.send(Some(
502 flexbuffers::from_slice(&flexbuffers::to_vec(value).unwrap()).unwrap(),
503 ));
504 }
505 }
506
507 fn send_result_to_providers<V: Serialize + DeserializeOwned, E>(
508 value: &Result<V, E>,
509 client_handles: &mut [OptionGenerateHandle<Result<V, E>>],
510 ) {
511 for GenerateHandle { value_s, .. } in client_handles.iter_mut() {
512 if let Ok(value) = value {
513 let _ = value_s.send(Some(Ok(flexbuffers::from_slice(
514 &flexbuffers::to_vec(value).unwrap(),
515 )
516 .unwrap())));
517 } else {
518 let _ = value_s.send(None);
519 }
520 }
521 }
522}