cache/persistent/
client.rs

1//! A client for interacting with a cache server.
2
3use std::{
4    any::Any,
5    fs::{self, OpenOptions},
6    io::{Read, Write},
7    net::TcpListener,
8    path::{Path, PathBuf},
9    sync::{
10        Arc, Mutex,
11        mpsc::{Receiver, RecvTimeoutError, Sender, channel},
12    },
13    thread,
14    time::Duration,
15};
16
17use backoff::ExponentialBackoff;
18use serde::{Deserialize, Serialize, de::DeserializeOwned};
19use tokio::runtime::{Handle, Runtime};
20use tonic::transport::{Channel, Endpoint};
21
22use crate::{
23    CacheHandle, CacheHandleInner, Cacheable, CacheableWithState, GenerateFn, GenerateResultFn,
24    GenerateResultWithStateFn, GenerateWithStateFn, Namespace,
25    error::{ArcResult, Error, Result},
26    rpc::{
27        local::{self, local_cache_client},
28        remote::{self, remote_cache_client},
29    },
30    run_generator,
31};
32
33use super::server::Server;
34
/// Default timeout, in milliseconds, for establishing a connection to the cache server.
pub const CONNECTION_TIMEOUT_MS_DEFAULT: u64 = 1_000;
37
/// Default timeout, in milliseconds, for receiving a reply to a request sent to the cache server.
pub const REQUEST_TIMEOUT_MS_DEFAULT: u64 = 1_000;
40
41/// An enumeration of client kinds.
42///
43/// Each interacts with a different cache server API, depending on the desired functionality.
44#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
45pub enum ClientKind {
46    /// A client that shares a filesystem with the server.
47    ///
48    /// Enables storing data in the cache via the filesystem without sending large bytestreams over gRPC.
49    Local,
50    /// A client that does not share a filseystem with the server.
51    ///
52    /// Sends data to the cache server via gRPC.
53    Remote,
54}
55
56#[derive(Debug)]
57struct ClientInner {
58    kind: ClientKind,
59    url: String,
60    poll_backoff: ExponentialBackoff,
61    connection_timeout: Duration,
62    request_timeout: Duration,
63    handle: Handle,
64    // Only used to own a runtime created by the builder.
65    #[allow(dead_code)]
66    runtime: Option<Runtime>,
67}
68
69/// A gRPC cache client.
70///
71/// The semantics of the [`Client`] API are the same as those of the
72/// [`NamespaceCache`](crate::mem::NamespaceCache) API.
73#[derive(Debug, Clone)]
74pub struct Client {
75    inner: Arc<ClientInner>,
76}
77
78/// A builder for a [`Client`].
79#[derive(Default, Clone, Debug)]
80pub struct ClientBuilder {
81    kind: Option<ClientKind>,
82    url: Option<String>,
83    poll_backoff: Option<ExponentialBackoff>,
84    connection_timeout: Option<Duration>,
85    request_timeout: Option<Duration>,
86    handle: Option<Handle>,
87}
88
89struct GenerateState<K, V> {
90    handle: CacheHandleInner<V>,
91    namespace: Namespace,
92    hash: Vec<u8>,
93    key: K,
94}
95
96/// Sends a heartbeat RPC to the server.
97trait HeartbeatFn: Fn(&Client) -> Result<()> + Send + Any {}
98impl<T: Fn(&Client) -> Result<()> + Send + Any> HeartbeatFn for T {}
99
100/// Writes a generated value to the given `String` path, using the provided assignment ID `u64` to
101/// notify the cache server once completed.
102trait LocalWriteValueFn<V>:
103    FnOnce(&Client, u64, String, &ArcResult<V>) -> Result<()> + Send + Any
104{
105}
106impl<V, T: FnOnce(&Client, u64, String, &ArcResult<V>) -> Result<()> + Send + Any>
107    LocalWriteValueFn<V> for T
108{
109}
110
111/// Writes a generated value to the cache server, using the provided assignment ID `u64` to
112/// tell the cache server which task completed.
113trait RemoteWriteValueFn<V>: FnOnce(&Client, u64, &ArcResult<V>) -> Result<()> + Send + Any {}
114impl<V, T: FnOnce(&Client, u64, &ArcResult<V>) -> Result<()> + Send + Any> RemoteWriteValueFn<V>
115    for T
116{
117}
118
119/// Deserializes desired value from bytes stored in the cache. If `V` is a result, would need to
120/// wrap the bytes from the cache with an `Ok` since `Err` results are not stored in the cache.
121trait DeserializeValueFn<V>: FnOnce(&[u8]) -> Result<V> + Send + Any {}
122impl<V, T: FnOnce(&[u8]) -> Result<V> + Send + Any> DeserializeValueFn<V> for T {}
123
124impl ClientBuilder {
125    /// Creates a new [`ClientBuilder`].
126    pub fn new() -> Self {
127        Self::default()
128    }
129
130    /// Sets the configured server URL.
131    pub fn url(&mut self, url: impl Into<String>) -> &mut Self {
132        self.url = Some(url.into());
133        self
134    }
135
136    /// Sets the configured client type.
137    pub fn kind(&mut self, kind: ClientKind) -> &mut Self {
138        self.kind = Some(kind);
139        self
140    }
141    /// Creates a new [`ClientBuilder`] with configured client type [`ClientKind::Local`] and a
142    /// server URL `url`.
143    pub fn local(url: impl Into<String>) -> Self {
144        let mut builder = Self::new();
145        builder.kind(ClientKind::Local).url(url);
146        builder
147    }
148
149    /// Creates a new [`ClientBuilder`] with configured client type [`ClientKind::Remote`] and a
150    /// server URL `url`.
151    pub fn remote(url: impl Into<String>) -> Self {
152        let mut builder = Self::new();
153        builder.kind(ClientKind::Remote).url(url);
154        builder
155    }
156
157    /// Configures the exponential backoff used when polling the server for cache entry
158    /// statuses.
159    ///
160    /// Defaults to [`ExponentialBackoff::default`].
161    pub fn poll_backoff(&mut self, backoff: ExponentialBackoff) -> &mut Self {
162        self.poll_backoff = Some(backoff);
163        self
164    }
165
166    /// Sets the timeout for connecting to the server.
167    ///
168    /// Defaults to [`CONNECTION_TIMEOUT_MS_DEFAULT`].
169    pub fn connection_timeout(&mut self, timeout: Duration) -> &mut Self {
170        self.connection_timeout = Some(timeout);
171        self
172    }
173
174    /// Sets the timeout for receiving a reply from the server.
175    ///
176    /// Defaults to [`REQUEST_TIMEOUT_MS_DEFAULT`].
177    pub fn request_timeout(&mut self, timeout: Duration) -> &mut Self {
178        self.request_timeout = Some(timeout);
179        self
180    }
181
182    /// Configures a [`Handle`] for making asynchronous gRPC requests.
183    ///
184    /// If no handle is specified, starts a new [`tokio::runtime::Runtime`] upon building the
185    /// [`Client`] object.
186    pub fn runtime_handle(&mut self, handle: Handle) -> &mut Self {
187        self.handle = Some(handle);
188        self
189    }
190
191    /// Builds a [`Client`] object with the configured parameters.
192    pub fn build(&mut self) -> Client {
193        let (handle, runtime) = match self.handle.clone() {
194            Some(handle) => (handle, None),
195            None => {
196                let runtime = tokio::runtime::Builder::new_multi_thread()
197                    .worker_threads(1)
198                    .enable_all()
199                    .build()
200                    .unwrap();
201                (runtime.handle().clone(), Some(runtime))
202            }
203        };
204        Client {
205            inner: Arc::new(ClientInner {
206                kind: self.kind.expect("must specify client kind"),
207                url: self.url.clone().expect("must specify server URL"),
208                poll_backoff: self.poll_backoff.clone().unwrap_or_default(),
209                connection_timeout: self
210                    .connection_timeout
211                    .unwrap_or(Duration::from_millis(CONNECTION_TIMEOUT_MS_DEFAULT)),
212                request_timeout: self
213                    .request_timeout
214                    .unwrap_or(Duration::from_millis(REQUEST_TIMEOUT_MS_DEFAULT)),
215                handle,
216                runtime,
217            }),
218        }
219    }
220}
221
222impl Client {
223    /// Creates a new gRPC cache client for a server at `url` with default configuration values.
224    pub fn with_default_config(kind: ClientKind, url: impl Into<String>) -> Self {
225        Self::builder().kind(kind).url(url).build()
226    }
227
228    /// Creates a new gRPC cache client.
229    pub fn builder() -> ClientBuilder {
230        ClientBuilder::new()
231    }
232
233    /// Creates a new local gRPC cache client.
234    ///
235    /// See [`ClientKind`] for an explanation of the different kinds of clients.
236    pub fn local(url: impl Into<String>) -> ClientBuilder {
237        ClientBuilder::local(url)
238    }
239
240    /// Creates a new remote gRPC cache client.
241    ///
242    /// See [`ClientKind`] for an explanation of the different kinds of clients.
243    pub fn remote(url: impl Into<String>) -> ClientBuilder {
244        ClientBuilder::remote(url)
245    }
246
247    /// Ensures that a value corresponding to `key` is generated, using `generate_fn`
248    /// to generate it if it has not already been generated.
249    ///
250    /// Returns a handle to the value. If the value is not yet generated, it is generated
251    /// in the background.
252    ///
253    /// For more detailed examples, refer to
254    /// [`NamespaceCache::generate`](crate::mem::NamespaceCache::generate).
255    ///
256    /// # Examples
257    ///
258    /// ```
259    /// use cache::{persistent::client::{Client, ClientKind}, error::Error, Cacheable};
260    ///
261    /// let client = Client::with_default_config(ClientKind::Local, "http://127.0.0.1:28055");
262    /// # use cache::persistent::client::{setup_test, create_server_and_clients, ServerKind};
263    /// # let (root, _, runtime) = setup_test("persistent_client_Client_generate").unwrap();
264    /// # let (_, client, _) = create_server_and_clients(root, ServerKind::Local, runtime.handle());
265    ///
266    /// fn generate_fn(tuple: &(u64, u64)) -> u64 {
267    ///     tuple.0 + tuple.1
268    /// }
269    ///
270    /// let handle = client.generate("example.namespace", (5, 6), generate_fn);
271    /// assert_eq!(*handle.get(), 11);
272    /// ```
273    pub fn generate<
274        K: Serialize + Any + Send + Sync,
275        V: Serialize + DeserializeOwned + Send + Sync + Any,
276    >(
277        &self,
278        namespace: impl Into<Namespace>,
279        key: K,
280        generate_fn: impl GenerateFn<K, V>,
281    ) -> CacheHandle<V> {
282        CacheHandle::from_inner(Arc::new(self.generate_inner(namespace, key, generate_fn)))
283    }
284
285    pub(crate) fn generate_inner<
286        K: Serialize + Any + Send + Sync,
287        V: Serialize + DeserializeOwned + Send + Sync + Any,
288    >(
289        &self,
290        namespace: impl Into<Namespace>,
291        key: K,
292        generate_fn: impl GenerateFn<K, V>,
293    ) -> CacheHandleInner<V> {
294        let namespace = namespace.into();
295        let state = Client::setup_generate(namespace, key);
296        let handle = state.handle.clone();
297
298        match self.inner.kind {
299            ClientKind::Local => self.clone().generate_inner_local(state, generate_fn),
300            ClientKind::Remote => self.clone().generate_inner_remote(state, generate_fn),
301        }
302
303        handle
304    }
305
306    /// Ensures that a value corresponding to `key` is generated, using `generate_fn`
307    /// to generate it if it has not already been generated.
308    ///
309    /// Returns a handle to the value. If the value is not yet generated, it is generated
310    /// in the background.
311    ///
312    /// For more detailed examples, refer to
313    /// [`NamespaceCache::generate_with_state`](crate::mem::NamespaceCache::generate_with_state).
314    ///
315    /// # Examples
316    ///
317    /// ```
318    /// use std::sync::{Arc, Mutex};
319    /// use cache::{persistent::client::{Client, ClientKind}, error::Error, Cacheable};
320    ///
321    /// #[derive(Clone)]
322    /// pub struct Log(Arc<Mutex<Vec<(u64, u64)>>>);
323    ///
324    /// let client = Client::with_default_config(ClientKind::Local, "http://127.0.0.1:28055");
325    /// let log = Log(Arc::new(Mutex::new(Vec::new())));
326    /// # use cache::persistent::client::{setup_test, create_server_and_clients, ServerKind};
327    /// # let (root, _, runtime) = setup_test("persistent_client_Client_generate_with_state").unwrap();
328    /// # let (_, client, _) = create_server_and_clients(root, ServerKind::Local, runtime.handle());
329    ///
330    /// fn generate_fn(tuple: &(u64, u64), state: Log) -> u64 {
331    ///     println!("Logging parameters...");
332    ///     state.0.lock().unwrap().push(*tuple);
333    ///     tuple.0 + tuple.1
334    /// }
335    ///
336    /// let handle = client.generate_with_state(
337    ///     "example.namespace", (5, 6), log.clone(), generate_fn
338    /// );
339    /// assert_eq!(*handle.get(), 11);
340    /// assert_eq!(log.0.lock().unwrap().clone(), vec![(5, 6)]);
341    /// ```
342    pub fn generate_with_state<
343        K: Serialize + Send + Sync + Any,
344        V: Serialize + DeserializeOwned + Send + Sync + Any,
345        S: Send + Sync + Any,
346    >(
347        &self,
348        namespace: impl Into<Namespace>,
349        key: K,
350        state: S,
351        generate_fn: impl GenerateWithStateFn<K, S, V>,
352    ) -> CacheHandle<V> {
353        let namespace = namespace.into();
354        self.generate(namespace, key, move |k| generate_fn(k, state))
355    }
356
357    /// Ensures that a result corresponding to `key` is generated, using `generate_fn`
358    /// to generate it if it has not already been generated.
359    ///
360    /// Does not cache on failure as errors are not constrained to be serializable/deserializable.
361    /// As such, failures should happen quickly, or should be serializable and stored as part of
362    /// cached value using [`Client::generate`].
363    ///
364    /// Returns a handle to the value. If the value is not yet generated, it is generated in the
365    /// background.
366    ///
367    /// For more detailed examples, refer to
368    /// [`NamespaceCache::generate_result`](crate::mem::NamespaceCache::generate_result).
369    ///
370    /// # Examples
371    ///
372    /// ```
373    /// use cache::{persistent::client::{Client, ClientKind}, error::Error, Cacheable};
374    ///
375    /// let client = Client::with_default_config(ClientKind::Local, "http://127.0.0.1:28055");
376    /// # use cache::persistent::client::{setup_test, create_server_and_clients, ServerKind};
377    /// # let (root, _, runtime) = setup_test("persistent_client_Client_generate_result").unwrap();
378    /// # let (_, client, _) = create_server_and_clients(root, ServerKind::Local, runtime.handle());
379    ///
380    /// fn generate_fn(tuple: &(u64, u64)) -> anyhow::Result<u64> {
381    ///     if *tuple == (5, 5) {
382    ///         Err(anyhow::anyhow!("invalid tuple"))
383    ///     } else {
384    ///         Ok(tuple.0 + tuple.1)
385    ///     }
386    /// }
387    ///
388    /// let handle = client.generate_result("example.namespace", (5, 5), generate_fn);
389    /// assert_eq!(format!("{}", handle.unwrap_err_inner().root_cause()), "invalid tuple");
390    /// ```
391    pub fn generate_result<
392        K: Serialize + Any + Send + Sync,
393        V: Serialize + DeserializeOwned + Send + Sync + Any,
394        E: Send + Sync + Any,
395    >(
396        &self,
397        namespace: impl Into<Namespace>,
398        key: K,
399        generate_fn: impl GenerateResultFn<K, V, E>,
400    ) -> CacheHandle<std::result::Result<V, E>> {
401        CacheHandle::from_inner(Arc::new(self.generate_result_inner(
402            namespace,
403            key,
404            generate_fn,
405        )))
406    }
407
408    pub(crate) fn generate_result_inner<
409        K: Serialize + Any + Send + Sync,
410        V: Serialize + DeserializeOwned + Send + Sync + Any,
411        E: Send + Sync + Any,
412    >(
413        &self,
414        namespace: impl Into<Namespace>,
415        key: K,
416        generate_fn: impl GenerateResultFn<K, V, E>,
417    ) -> CacheHandleInner<std::result::Result<V, E>> {
418        let namespace = namespace.into();
419        let state = Client::setup_generate(namespace, key);
420        let handle = state.handle.clone();
421
422        match self.inner.kind {
423            ClientKind::Local => {
424                self.clone().generate_result_inner_local(state, generate_fn);
425            }
426            ClientKind::Remote => {
427                self.clone()
428                    .generate_result_inner_remote(state, generate_fn);
429            }
430        }
431
432        handle
433    }
434
435    /// Ensures that a value corresponding to `key` is generated, using `generate_fn`
436    /// to generate it if it has not already been generated.
437    ///
438    /// Does not cache on failure as errors are not constrained to be serializable/deserializable.
439    /// As such, failures should happen quickly, or should be serializable and stored as part of
440    /// cached value using [`Client::generate_with_state`].
441    ///
442    /// Returns a handle to the value. If the value is not yet generated, it is generated
443    /// in the background.
444    ///
445    /// For more detailed examples, refer to
446    /// [`NamespaceCache::generate_result_with_state`](crate::mem::NamespaceCache::generate_result_with_state).
447    ///
448    /// # Examples
449    ///
450    /// ```
451    /// use std::sync::{Arc, Mutex};
452    /// use cache::{persistent::client::{Client, ClientKind}, error::Error, Cacheable};
453    ///
454    /// #[derive(Clone)]
455    /// pub struct Log(Arc<Mutex<Vec<(u64, u64)>>>);
456    ///
457    /// let client = Client::with_default_config(ClientKind::Local, "http://127.0.0.1:28055");
458    /// let log = Log(Arc::new(Mutex::new(Vec::new())));
459    /// # use cache::persistent::client::{setup_test, create_server_and_clients, ServerKind};
460    /// # let (root, _, runtime) = setup_test("persistent_client_Client_generate_result_with_state").unwrap();
461    /// # let (_, client, _) = create_server_and_clients(root, ServerKind::Local, runtime.handle());
462    ///
463    /// fn generate_fn(tuple: &(u64, u64), state: Log) -> anyhow::Result<u64> {
464    ///     println!("Logging parameters...");
465    ///     state.0.lock().unwrap().push(*tuple);
466    ///
467    ///     if *tuple == (5, 5) {
468    ///         Err(anyhow::anyhow!("invalid tuple"))
469    ///     } else {
470    ///         Ok(tuple.0 + tuple.1)
471    ///     }
472    /// }
473    ///
474    /// let handle = client.generate_result_with_state(
475    ///     "example.namespace", (5, 5), log.clone(), generate_fn,
476    /// );
477    /// assert_eq!(format!("{}", handle.unwrap_err_inner().root_cause()), "invalid tuple");
478    /// assert_eq!(log.0.lock().unwrap().clone(), vec![(5, 5)]);
479    /// ```
480    pub fn generate_result_with_state<
481        K: Serialize + Send + Sync + Any,
482        V: Serialize + DeserializeOwned + Send + Sync + Any,
483        E: Send + Sync + Any,
484        S: Send + Sync + Any,
485    >(
486        &self,
487        namespace: impl Into<Namespace>,
488        key: K,
489        state: S,
490        generate_fn: impl GenerateResultWithStateFn<K, S, V, E>,
491    ) -> CacheHandle<std::result::Result<V, E>> {
492        let namespace = namespace.into();
493        self.generate_result(namespace, key, move |k| generate_fn(k, state))
494    }
495
496    /// Gets a handle to a cacheable object from the cache, generating the object in the background
497    /// if needed.
498    ///
499    /// Does not cache errors, so any errors thrown should be thrown quickly. Any errors that need
500    /// to be cached should be included in the cached output or should be cached using
501    /// [`Client::get_with_err`].
502    ///
503    /// For more detailed examples, refer to
504    /// [`NamespaceCache::get`](crate::mem::NamespaceCache::get).
505    ///
506    /// # Examples
507    ///
508    /// ```
509    /// use cache::{persistent::client::{Client, ClientKind}, error::Error, Cacheable};
510    /// use serde::{Deserialize, Serialize};
511    ///
512    /// #[derive(Deserialize, Serialize, Hash, Eq, PartialEq)]
513    /// pub struct Params {
514    ///     param1: u64,
515    ///     param2: String,
516    /// };
517    ///
518    /// impl Cacheable for Params {
519    ///     type Output = u64;
520    ///     type Error = anyhow::Error;
521    ///
522    ///     fn generate(&self) -> anyhow::Result<u64> {
523    ///         Ok(2 * self.param1)
524    ///     }
525    /// }
526    ///
527    /// let client = Client::with_default_config(ClientKind::Local, "http://127.0.0.1:28055");
528    /// # use cache::persistent::client::{setup_test, create_server_and_clients, ServerKind};
529    /// # let (root, _, runtime) = setup_test("persistent_client_Client_get").unwrap();
530    /// # let (_, client, _) = create_server_and_clients(root, ServerKind::Local, runtime.handle());
531    ///
532    /// let handle = client.get(
533    ///     "example.namespace", Params { param1: 50, param2: "cache".to_string() }
534    /// );
535    /// assert_eq!(*handle.unwrap_inner(), 100);
536    /// ```
537    pub fn get<K: Cacheable>(
538        &self,
539        namespace: impl Into<Namespace>,
540        key: K,
541    ) -> CacheHandle<std::result::Result<K::Output, K::Error>> {
542        let namespace = namespace.into();
543        self.generate_result(namespace, key, |key| key.generate())
544    }
545
546    /// Gets a handle to a cacheable object from the cache, caching failures as well.
547    ///
548    /// Generates the object in the background if needed.
549    ///
550    /// For more detailed examples, refer to
551    /// [`NamespaceCache::get_with_err`](crate::mem::NamespaceCache::get_with_err).
552    ///
553    /// # Examples
554    ///
555    /// ```
556    /// use cache::{persistent::client::{Client, ClientKind}, error::Error, Cacheable};
557    /// use serde::{Deserialize, Serialize};
558    ///
559    /// #[derive(Deserialize, Serialize, Hash, Eq, PartialEq)]
560    /// pub struct Params {
561    ///     param1: u64,
562    ///     param2: String,
563    /// };
564    ///
565    /// impl Cacheable for Params {
566    ///     type Output = u64;
567    ///     type Error = String;
568    ///
569    ///     fn generate(&self) -> Result<Self::Output, Self::Error> {
570    ///         if self.param1 == 5 {
571    ///             Err("invalid param".to_string())
572    ///         } else {
573    ///             Ok(2 * self.param1)
574    ///         }
575    ///     }
576    /// }
577    ///
578    /// let client = Client::with_default_config(ClientKind::Local, "http://127.0.0.1:28055");
579    /// # use cache::persistent::client::{setup_test, create_server_and_clients, ServerKind};
580    /// # let (root, _, runtime) = setup_test("persistent_client_Client_get_with_err").unwrap();
581    /// # let (_, client, _) = create_server_and_clients(root, ServerKind::Local, runtime.handle());
582    ///
583    /// let handle = client.get_with_err(
584    ///     "example.namespace", Params { param1: 5, param2: "cache".to_string() }
585    /// );
586    /// assert_eq!(handle.unwrap_err_inner(), "invalid param");
587    /// ```
588    pub fn get_with_err<
589        E: Send + Sync + Serialize + DeserializeOwned + Any,
590        K: Cacheable<Error = E>,
591    >(
592        &self,
593        namespace: impl Into<Namespace>,
594        key: K,
595    ) -> CacheHandle<std::result::Result<K::Output, K::Error>> {
596        let namespace = namespace.into();
597        self.generate(namespace, key, |key| key.generate())
598    }
599
600    /// Gets a handle to a cacheable object from the cache, generating the object in the background
601    /// if needed.
602    ///
603    /// Does not cache errors, so any errors thrown should be thrown quickly. Any errors that need
604    /// to be cached should be included in the cached output or should be cached using
605    /// [`Client::get_with_state_and_err`].
606    ///
607    /// For more detailed examples, refer to
608    /// [`NamespaceCache::get_with_state`](crate::mem::NamespaceCache::get_with_state).
609    ///
610    /// # Examples
611    ///
612    /// ```
613    /// use std::sync::{Arc, Mutex};
614    /// use cache::{persistent::client::{Client, ClientKind}, error::Error, CacheableWithState};
615    /// use serde::{Deserialize, Serialize};
616    ///
617    /// #[derive(Debug, Deserialize, Serialize, Clone, Hash, Eq, PartialEq)]
618    /// pub struct Params(u64);
619    ///
620    /// #[derive(Clone)]
621    /// pub struct Log(Arc<Mutex<Vec<Params>>>);
622    ///
623    /// impl CacheableWithState<Log> for Params {
624    ///     type Output = u64;
625    ///     type Error = anyhow::Error;
626    ///
627    ///     fn generate_with_state(&self, state: Log) -> anyhow::Result<u64> {
628    ///         println!("Logging parameters...");
629    ///         state.0.lock().unwrap().push(self.clone());
630    ///         Ok(2 * self.0)
631    ///     }
632    /// }
633    ///
634    /// let client = Client::with_default_config(ClientKind::Local, "http://127.0.0.1:28055");
635    /// let log = Log(Arc::new(Mutex::new(Vec::new())));
636    /// # use cache::persistent::client::{setup_test, create_server_and_clients, ServerKind};
637    /// # let (root, _, runtime) = setup_test("persistent_client_Client_get_with_state").unwrap();
638    /// # let (_, client, _) = create_server_and_clients(root, ServerKind::Local, runtime.handle());
639    ///
640    /// let handle = client.get_with_state(
641    ///     "example.namespace",
642    ///     Params(0),
643    ///     log.clone(),
644    /// );
645    /// assert_eq!(*handle.unwrap_inner(), 0);
646    /// assert_eq!(log.0.lock().unwrap().clone(), vec![Params(0)]);
647    /// ```
648    pub fn get_with_state<S: Send + Sync + Any, K: CacheableWithState<S>>(
649        &self,
650        namespace: impl Into<Namespace>,
651        key: K,
652        state: S,
653    ) -> CacheHandle<std::result::Result<K::Output, K::Error>> {
654        let namespace = namespace.into();
655        self.generate_result_with_state(namespace, key, state, |key, state| {
656            key.generate_with_state(state)
657        })
658    }
659
660    /// Gets a handle to a cacheable object from the cache, caching failures as well.
661    ///
662    /// Generates the object in the background if needed.
663    ///
664    /// See [`Client::get_with_err`] and [`Client::get_with_state`] for related examples.
665    pub fn get_with_state_and_err<
666        S: Send + Sync + Any,
667        E: Send + Sync + Serialize + DeserializeOwned + Any,
668        K: CacheableWithState<S, Error = E>,
669    >(
670        &self,
671        namespace: impl Into<Namespace>,
672        key: K,
673        state: S,
674    ) -> CacheHandle<std::result::Result<K::Output, K::Error>> {
675        let namespace = namespace.into();
676        self.generate_with_state(namespace, key, state, |key, state| {
677            key.generate_with_state(state)
678        })
679    }
680
681    /// Sets up the necessary objects to be passed in to [`Client::spawn_handler`].
682    fn setup_generate<K: Serialize, V>(namespace: Namespace, key: K) -> GenerateState<K, V> {
683        GenerateState {
684            handle: CacheHandleInner::default(),
685            namespace,
686            hash: crate::hash(&flexbuffers::to_vec(&key).unwrap()),
687            key,
688        }
689    }
690
691    /// Spawns a new thread to generate the desired value asynchronously.
692    ///
693    /// If the provided handler returns a error, stores an [`Arc`]ed error in the handle.
694    fn spawn_handler<V: Send + Sync + Any>(
695        self,
696        handle: CacheHandleInner<V>,
697        handler: impl FnOnce() -> Result<()> + Send + Any,
698    ) {
699        thread::spawn(move || {
700            if let Err(e) = handler() {
701                tracing::error!("encountered error while executing handler: {}", e,);
702                handle.set(Err(Arc::new(e)));
703            }
704        });
705    }
706
707    /// Deserializes a cached value into a [`Result`] that can be stored in a [`CacheHandle`].
708    fn deserialize_cache_value<V: DeserializeOwned>(data: &[u8]) -> Result<V> {
709        let data = flexbuffers::from_slice(data)?;
710        Ok(data)
711    }
712
713    /// Deserializes a cached value into a containing result with the appropriate error type.
714    fn deserialize_cache_result<V: DeserializeOwned, E>(
715        data: &[u8],
716    ) -> Result<std::result::Result<V, E>> {
717        let data = flexbuffers::from_slice(data)?;
718        Ok(Ok(data))
719    }
720
721    /// Starts sending heartbeats to the server in a new thread .
722    ///
723    /// Returns a sender for telling the spawned thread to stop sending heartbeats and
724    /// a receiver for waiting for heartbeats to terminate.
725    fn start_heartbeats(
726        &self,
727        heartbeat_interval: Duration,
728        send_heartbeat: impl HeartbeatFn,
729    ) -> (Sender<()>, Receiver<()>) {
730        tracing::debug!("starting heartbeats");
731        let (s_heartbeat_stop, r_heartbeat_stop) = channel();
732        let (s_heartbeat_stopped, r_heartbeat_stopped) = channel();
733        let self_clone = self.clone();
734        thread::spawn(move || {
735            loop {
736                match r_heartbeat_stop.recv_timeout(heartbeat_interval) {
737                    Ok(_) | Err(RecvTimeoutError::Disconnected) => {
738                        break;
739                    }
740                    Err(RecvTimeoutError::Timeout) => {
741                        if send_heartbeat(&self_clone).is_err() {
742                            break;
743                        }
744                    }
745                }
746            }
747            let _ = s_heartbeat_stopped.send(());
748        });
749        (s_heartbeat_stop, r_heartbeat_stopped)
750    }
751
752    /// Converts a [`Result<(S, bool)>`] to a [`std::result::Result<S, backoff::Error<Error>>`].
753    ///
754    /// If the `retry` boolean is `true`, returns a [`backoff::Error::Transient`]. If the provided
755    /// result is [`Err`], returns a [`backoff::Error::Permanent`]. Otherwise, returns the entry
756    /// status of type `S`.
757    fn run_backoff_loop<S>(&self, get_status_fn: impl Fn() -> Result<(S, bool)>) -> Result<S> {
758        Ok(backoff::retry(self.inner.poll_backoff.clone(), move || {
759            tracing::debug!("attempting get request to retrieve entry status");
760            get_status_fn()
761                .map_err(backoff::Error::Permanent)
762                .and_then(|(status, retry)| {
763                    if retry {
764                        tracing::debug!("entry is loading, retrying later");
765                        Err(backoff::Error::transient(Error::EntryLoading))
766                    } else {
767                        tracing::debug!("entry status retrieved");
768                        Ok(status)
769                    }
770                })
771        })
772        .map_err(Box::new)?)
773    }
774
775    /// Handles an unassigned entry by generating it locally.
776    fn handle_unassigned<K: Send + Sync + Any, V: Send + Sync + Any>(
777        handle: CacheHandleInner<V>,
778        key: K,
779        generate_fn: impl GenerateFn<K, V>,
780    ) {
781        tracing::debug!("entry is unassigned, generating locally");
782        let v = run_generator(move || generate_fn(&key));
783        handle.set(v);
784    }
785
786    /// Handles an assigned entry by generating it locally and sending heartbeats periodically
787    /// while the generator is running.
788    fn handle_assigned<K: Send + Sync + Any, V: Send + Sync + Any>(
789        &self,
790        key: K,
791        generate_fn: impl GenerateFn<K, V>,
792        heartbeat_interval_ms: u64,
793        send_heartbeat: impl HeartbeatFn,
794    ) -> ArcResult<V> {
795        tracing::debug!("entry has been assigned to the client, generating locally");
796        let (s_heartbeat_stop, r_heartbeat_stopped) =
797            self.start_heartbeats(Duration::from_millis(heartbeat_interval_ms), send_heartbeat);
798        let v = run_generator(move || generate_fn(&key));
799        let _ = s_heartbeat_stop.send(());
800        let _ = r_heartbeat_stopped.recv();
801        tracing::debug!("finished generating, writing value to cache");
802        v
803    }
804
805    /// Connects to a local cache gRPC server.
806    async fn connect_local(&self) -> Result<local_cache_client::LocalCacheClient<Channel>> {
807        let endpoint = Endpoint::from_shared(self.inner.url.clone())?
808            .timeout(self.inner.request_timeout)
809            .connect_timeout(self.inner.connection_timeout);
810        let test = local_cache_client::LocalCacheClient::connect(endpoint).await;
811        Ok(test?)
812    }
813
814    /// Issues a `Get` RPC to a local cache gRPC server.
815    fn get_rpc_local(
816        &self,
817        namespace: String,
818        key: Vec<u8>,
819        assign: bool,
820    ) -> Result<local::get_reply::EntryStatus> {
821        let out: Result<local::GetReply> = self.inner.handle.block_on(async {
822            let mut client = self.connect_local().await?;
823            Ok(client
824                .get(local::GetRequest {
825                    namespace,
826                    key,
827                    assign,
828                })
829                .await
830                .map_err(Box::new)?
831                .into_inner())
832        });
833        Ok(out?.entry_status.unwrap())
834    }
835
836    /// Issues a `Heartbeat` RPC to a local cache gRPC server.
837    fn heartbeat_rpc_local(&self, id: u64) -> Result<()> {
838        self.inner.handle.block_on(async {
839            let mut client = self.connect_local().await?;
840            client
841                .heartbeat(local::HeartbeatRequest { id })
842                .await
843                .map_err(Box::new)?;
844            Ok(())
845        })
846    }
847
848    /// Issues a `Done` RPC to a local cache gRPC server.
849    fn done_rpc_local(&self, id: u64) -> Result<()> {
850        self.inner.handle.block_on(async {
851            let mut client = self.connect_local().await?;
852            client
853                .done(local::DoneRequest { id })
854                .await
855                .map_err(Box::new)?;
856            Ok(())
857        })
858    }
859
860    /// Issues a `Drop` RPC to a local cache gRPC server.
861    fn drop_rpc_local(&self, id: u64) -> Result<()> {
862        self.inner.handle.block_on(async {
863            let mut client = self.connect_local().await?;
864            client
865                .drop(local::DropRequest { id })
866                .await
867                .map_err(Box::new)?;
868            Ok(())
869        })
870    }
871
872    fn write_generated_data_to_disk<V: Serialize>(
873        &self,
874        id: u64,
875        path: String,
876        data: &V,
877    ) -> Result<()> {
878        let path = PathBuf::from(path);
879        if let Some(parent) = path.parent() {
880            fs::create_dir_all(parent)?;
881        }
882
883        let mut f = OpenOptions::new()
884            .read(true)
885            .write(true)
886            .create(true)
887            .truncate(true)
888            .open(&path)?;
889        f.write_all(&flexbuffers::to_vec(data).unwrap())?;
890        self.done_rpc_local(id)?;
891
892        Ok(())
893    }
894
895    /// Writes a generated value to a local cache via the `Set` RPC.
896    fn write_generated_value_local<V: Serialize>(
897        &self,
898        id: u64,
899        path: String,
900        value: &ArcResult<V>,
901    ) -> Result<()> {
902        if let Ok(data) = value {
903            self.write_generated_data_to_disk(id, path, data)?;
904        }
905        Ok(())
906    }
907
908    /// Writes data contained in a generated result to a local cache via the `Set` RPC.
909    ///
910    /// Does not write to the cache if the generated result is an [`Err`].
911    fn write_generated_result_local<V: Serialize, E>(
912        &self,
913        id: u64,
914        path: String,
915        value: &ArcResult<std::result::Result<V, E>>,
916    ) -> Result<()> {
917        if let Ok(Ok(data)) = value {
918            self.write_generated_data_to_disk(id, path, data)?;
919        }
920        Ok(())
921    }
922
923    /// Runs the generate loop for the local cache protocol, checking whether the desired entry is
924    /// loaded and generating it if needed.
925    fn generate_loop_local<K: Send + Sync + Any, V: Send + Sync + Any>(
926        &self,
927        state: GenerateState<K, V>,
928        generate_fn: impl GenerateFn<K, V>,
929        write_generated_value: impl LocalWriteValueFn<V>,
930        deserialize_cache_data: impl DeserializeValueFn<V>,
931    ) -> Result<()> {
932        let GenerateState {
933            handle,
934            namespace,
935            hash,
936            key,
937        } = state;
938
939        let status = self.run_backoff_loop(|| {
940            let status = self.get_rpc_local(namespace.clone().into_inner(), hash.clone(), true)?;
941            let retry = matches!(status, local::get_reply::EntryStatus::Loading(_));
942
943            Ok((status, retry))
944        })?;
945
946        match status {
947            local::get_reply::EntryStatus::Unassigned(_) => {
948                Client::handle_unassigned(handle, key, generate_fn);
949            }
950            local::get_reply::EntryStatus::Assign(local::AssignReply {
951                id,
952                path,
953                heartbeat_interval_ms,
954            }) => {
955                let v = self.handle_assigned(
956                    key,
957                    generate_fn,
958                    heartbeat_interval_ms,
959                    move |client| -> Result<()> { client.heartbeat_rpc_local(id) },
960                );
961                write_generated_value(self, id, path, &v)?;
962                handle.set(v);
963            }
964            local::get_reply::EntryStatus::Loading(_) => unreachable!(),
965            local::get_reply::EntryStatus::Ready(local::ReadyReply { id, path }) => {
966                tracing::debug!("entry is ready, reading from cache");
967                let mut file = std::fs::File::open(path)?;
968                let mut buf = Vec::new();
969                file.read_to_end(&mut buf)?;
970                self.drop_rpc_local(id)?;
971                tracing::debug!("finished reading entry from disk");
972                handle.set(Ok(deserialize_cache_data(&buf)?));
973            }
974        }
975        Ok(())
976    }
977
978    fn generate_inner_local<
979        K: Serialize + Any + Send + Sync,
980        V: Serialize + DeserializeOwned + Send + Sync + Any,
981    >(
982        self,
983        state: GenerateState<K, V>,
984        generate_fn: impl GenerateFn<K, V>,
985    ) {
986        tracing::debug!("generating using local cache API");
987        self.clone().spawn_handler(state.handle.clone(), move || {
988            self.generate_loop_local(
989                state,
990                generate_fn,
991                Client::write_generated_value_local,
992                Client::deserialize_cache_value,
993            )
994        });
995    }
996
997    fn generate_result_inner_local<
998        K: Serialize + Any + Send + Sync,
999        V: Serialize + DeserializeOwned + Send + Sync + Any,
1000        E: Send + Sync + Any,
1001    >(
1002        self,
1003        state: GenerateState<K, std::result::Result<V, E>>,
1004        generate_fn: impl GenerateResultFn<K, V, E>,
1005    ) {
1006        self.clone().spawn_handler(state.handle.clone(), move || {
1007            self.generate_loop_local(
1008                state,
1009                generate_fn,
1010                Client::write_generated_result_local,
1011                Client::deserialize_cache_result,
1012            )
1013        });
1014    }
1015
1016    /// Connects to a remote cache gRPC server.
1017    async fn connect_remote(&self) -> Result<remote_cache_client::RemoteCacheClient<Channel>> {
1018        let endpoint = Endpoint::from_shared(self.inner.url.clone())?
1019            .timeout(self.inner.request_timeout)
1020            .connect_timeout(self.inner.connection_timeout);
1021        Ok(remote_cache_client::RemoteCacheClient::connect(endpoint).await?)
1022    }
1023
1024    /// Issues a `Get` RPC to a remote cache gRPC server.
1025    fn get_rpc_remote(
1026        &self,
1027        namespace: String,
1028        key: Vec<u8>,
1029        assign: bool,
1030    ) -> Result<remote::get_reply::EntryStatus> {
1031        let out: Result<remote::GetReply> = self.inner.handle.block_on(async {
1032            let mut client = self.connect_remote().await?;
1033            Ok(client
1034                .get(remote::GetRequest {
1035                    namespace,
1036                    key,
1037                    assign,
1038                })
1039                .await
1040                .map_err(Box::new)?
1041                .into_inner())
1042        });
1043        Ok(out?.entry_status.unwrap())
1044    }
1045
1046    /// Issues a `Heartbeat` RPC to a remote cache gRPC server.
1047    fn heartbeat_rpc_remote(&self, id: u64) -> Result<()> {
1048        self.inner.handle.block_on(async {
1049            let mut client = self.connect_remote().await?;
1050            client
1051                .heartbeat(remote::HeartbeatRequest { id })
1052                .await
1053                .map_err(Box::new)?;
1054            Ok(())
1055        })
1056    }
1057
1058    /// Issues a `Set` RPC to a remote cache gRPC server.
1059    fn set_rpc_remote(&self, id: u64, value: Vec<u8>) -> Result<()> {
1060        self.inner.handle.block_on(async {
1061            let mut client = self.connect_remote().await?;
1062            client
1063                .set(remote::SetRequest { id, value })
1064                .await
1065                .map_err(Box::new)?;
1066            Ok(())
1067        })
1068    }
1069
1070    /// Writes a generated value to a remote cache via the `Set` RPC.
1071    fn write_generated_value_remote<V: Serialize>(
1072        &self,
1073        id: u64,
1074        value: &ArcResult<V>,
1075    ) -> Result<()> {
1076        if let Ok(data) = value {
1077            self.set_rpc_remote(id, flexbuffers::to_vec(data).unwrap())?;
1078        }
1079        Ok(())
1080    }
1081
1082    /// Writes data contained in a generated result to a remote cache via the `Set` RPC.
1083    ///
1084    /// Does not write to the cache if the generated result is an [`Err`].
1085    fn write_generated_result_remote<V: Serialize, E>(
1086        &self,
1087        id: u64,
1088        value: &ArcResult<std::result::Result<V, E>>,
1089    ) -> Result<()> {
1090        if let Ok(Ok(data)) = value {
1091            self.set_rpc_remote(id, flexbuffers::to_vec(data).unwrap())?;
1092        }
1093        Ok(())
1094    }
1095
1096    /// Runs the generate loop for the remote cache protocol, checking whether the desired entry is
1097    /// loaded and generating it if needed.
1098    fn generate_loop_remote<K: Send + Sync + Any, V: Send + Sync + Any>(
1099        &self,
1100        state: GenerateState<K, V>,
1101        generate_fn: impl GenerateFn<K, V>,
1102        write_generated_value: impl RemoteWriteValueFn<V>,
1103        deserialize_cache_data: impl DeserializeValueFn<V>,
1104    ) -> Result<()> {
1105        let GenerateState {
1106            handle,
1107            namespace,
1108            hash,
1109            key,
1110        } = state;
1111
1112        let status = self.run_backoff_loop(|| {
1113            let status = self.get_rpc_remote(namespace.clone().into_inner(), hash.clone(), true)?;
1114            let retry = matches!(status, remote::get_reply::EntryStatus::Loading(_));
1115
1116            Ok((status, retry))
1117        })?;
1118
1119        match status {
1120            remote::get_reply::EntryStatus::Unassigned(_) => {
1121                Client::handle_unassigned(handle, key, generate_fn);
1122            }
1123            remote::get_reply::EntryStatus::Assign(remote::AssignReply {
1124                id,
1125                heartbeat_interval_ms,
1126            }) => {
1127                let v = self.handle_assigned(
1128                    key,
1129                    generate_fn,
1130                    heartbeat_interval_ms,
1131                    move |client| -> Result<()> { client.heartbeat_rpc_remote(id) },
1132                );
1133                write_generated_value(self, id, &v)?;
1134                handle.set(v);
1135            }
1136            remote::get_reply::EntryStatus::Loading(_) => unreachable!(),
1137            remote::get_reply::EntryStatus::Ready(data) => {
1138                tracing::debug!("entry is ready");
1139                handle.set(Ok(deserialize_cache_data(&data)?));
1140            }
1141        }
1142        Ok(())
1143    }
1144
1145    fn generate_inner_remote<
1146        K: Serialize + Any + Send + Sync,
1147        V: Serialize + DeserializeOwned + Send + Sync + Any,
1148    >(
1149        self,
1150        state: GenerateState<K, V>,
1151        generate_fn: impl GenerateFn<K, V>,
1152    ) {
1153        tracing::debug!("generating using remote cache API");
1154        self.clone().spawn_handler(state.handle.clone(), move || {
1155            self.generate_loop_remote(
1156                state,
1157                generate_fn,
1158                Client::write_generated_value_remote,
1159                Client::deserialize_cache_value,
1160            )
1161        });
1162    }
1163
1164    fn generate_result_inner_remote<
1165        K: Serialize + Any + Send + Sync,
1166        V: Serialize + DeserializeOwned + Send + Sync + Any,
1167        E: Send + Sync + Any,
1168    >(
1169        self,
1170        state: GenerateState<K, std::result::Result<V, E>>,
1171        generate_fn: impl GenerateResultFn<K, V, E>,
1172    ) {
1173        self.clone().spawn_handler(state.handle.clone(), move || {
1174            self.generate_loop_remote(
1175                state,
1176                generate_fn,
1177                Client::write_generated_result_remote,
1178                Client::deserialize_cache_result,
1179            )
1180        });
1181    }
1182}
1183
/// Directory under the crate root inside which `setup_test` creates per-test directories.
pub(crate) const BUILD_DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/build");
/// Heartbeat interval configured on the server built by `create_server_and_clients`.
pub(crate) const TEST_SERVER_HEARTBEAT_INTERVAL: Duration = Duration::from_millis(200);
/// Heartbeat timeout configured on the server built by `create_server_and_clients`.
pub(crate) const TEST_SERVER_HEARTBEAT_TIMEOUT: Duration = Duration::from_millis(500);
1187
/// Binds `n` TCP listeners to OS-assigned ports on `127.0.0.1`, returning each listener paired
/// with its port.
///
/// # Panics
///
/// Panics if a listener cannot be bound or its local address cannot be read.
pub(crate) fn get_listeners(n: usize) -> Vec<(TcpListener, u16)> {
    (0..n)
        .map(|_| {
            let listener = TcpListener::bind("127.0.0.1:0").unwrap();
            let port = listener.local_addr().unwrap().port();
            (listener, port)
        })
        .collect()
}
1199
/// Selects which cache server API(s) the test server built by `create_server_and_clients`
/// exposes.
#[doc(hidden)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ServerKind {
    /// Only the local API (the one used by [`ClientKind::Local`] clients).
    Local,
    /// Only the remote API (the one used by [`ClientKind::Remote`] clients).
    Remote,
    /// Both the local and remote APIs.
    Both,
}
1207
1208impl From<ClientKind> for ServerKind {
1209    fn from(value: ClientKind) -> Self {
1210        match value {
1211            ClientKind::Local => ServerKind::Local,
1212            ClientKind::Remote => ServerKind::Remote,
1213        }
1214    }
1215}
1216
/// Builds the plain-HTTP URL of a server listening on `port` at the IPv4 loopback address.
pub(crate) fn client_url(port: u16) -> String {
    let host = "127.0.0.1";
    format!("http://{host}:{port}")
}
1220
1221#[doc(hidden)]
1222pub fn create_server_and_clients(
1223    root: PathBuf,
1224    kind: ServerKind,
1225    handle: &Handle,
1226) -> (CacheHandle<Result<()>>, Client, Client) {
1227    let mut listeners = handle.block_on(async {
1228        get_listeners(2)
1229            .into_iter()
1230            .map(|(listener, port)| {
1231                listener.set_nonblocking(true).unwrap();
1232                (tokio::net::TcpListener::from_std(listener).unwrap(), port)
1233            })
1234            .collect::<Vec<_>>()
1235    });
1236    let (local_listener, local_port) = listeners.pop().unwrap();
1237    let (remote_listener, remote_port) = listeners.pop().unwrap();
1238
1239    (
1240        {
1241            let mut builder = Server::builder();
1242
1243            builder = builder
1244                .heartbeat_interval(TEST_SERVER_HEARTBEAT_INTERVAL)
1245                .heartbeat_timeout(TEST_SERVER_HEARTBEAT_TIMEOUT)
1246                .root(root);
1247
1248            let server = match kind {
1249                ServerKind::Local => builder.local_with_incoming(local_listener),
1250                ServerKind::Remote => builder.remote_with_incoming(remote_listener),
1251                ServerKind::Both => builder
1252                    .local_with_incoming(local_listener)
1253                    .remote_with_incoming(remote_listener),
1254            }
1255            .build();
1256
1257            let join_handle = handle.spawn(async move { server.start().await });
1258            let handle_clone = handle.clone();
1259            CacheHandle::new(move || {
1260                let res = handle_clone.block_on(join_handle).unwrap_or_else(|res| {
1261                    if res.is_cancelled() {
1262                        Ok(())
1263                    } else {
1264                        Err(Error::Panic)
1265                    }
1266                });
1267                if let Err(e) = res.as_ref() {
1268                    tracing::error!("server failed to start: {:?}", e);
1269                }
1270                res
1271            })
1272        },
1273        Client::builder()
1274            .kind(ClientKind::Local)
1275            .url(client_url(local_port))
1276            .connection_timeout(Duration::from_secs(3))
1277            .request_timeout(Duration::from_secs(3))
1278            .build(),
1279        Client::builder()
1280            .kind(ClientKind::Remote)
1281            .url(client_url(remote_port))
1282            .connection_timeout(Duration::from_secs(3))
1283            .request_timeout(Duration::from_secs(3))
1284            .build(),
1285    )
1286}
1287
1288pub(crate) fn reset_directory(path: impl AsRef<Path>) -> Result<()> {
1289    let path = path.as_ref();
1290    if path.exists() {
1291        fs::remove_dir_all(path)?;
1292    }
1293    fs::create_dir_all(path)?;
1294    Ok(())
1295}
1296
1297pub(crate) fn create_runtime() -> Runtime {
1298    tokio::runtime::Builder::new_multi_thread()
1299        .worker_threads(1)
1300        .enable_all()
1301        .build()
1302        .unwrap()
1303}
1304
1305#[doc(hidden)]
1306pub fn setup_test(test_name: &str) -> Result<(PathBuf, Arc<Mutex<u64>>, Runtime)> {
1307    let path = PathBuf::from(BUILD_DIR).join(test_name);
1308    reset_directory(&path)?;
1309    Ok((path, Arc::new(Mutex::new(0)), create_runtime()))
1310}