1use std::{
4 any::Any,
5 fs::{self, OpenOptions},
6 io::{Read, Write},
7 net::TcpListener,
8 path::{Path, PathBuf},
9 sync::{
10 Arc, Mutex,
11 mpsc::{Receiver, RecvTimeoutError, Sender, channel},
12 },
13 thread,
14 time::Duration,
15};
16
17use backoff::ExponentialBackoff;
18use serde::{Deserialize, Serialize, de::DeserializeOwned};
19use tokio::runtime::{Handle, Runtime};
20use tonic::transport::{Channel, Endpoint};
21
22use crate::{
23 CacheHandle, CacheHandleInner, Cacheable, CacheableWithState, GenerateFn, GenerateResultFn,
24 GenerateResultWithStateFn, GenerateWithStateFn, Namespace,
25 error::{ArcResult, Error, Result},
26 rpc::{
27 local::{self, local_cache_client},
28 remote::{self, remote_cache_client},
29 },
30 run_generator,
31};
32
33use super::server::Server;
34
/// Default timeout, in milliseconds, for establishing a connection to the
/// cache server (used when the builder's `connection_timeout` is not set).
pub const CONNECTION_TIMEOUT_MS_DEFAULT: u64 = 1_000;

/// Default timeout, in milliseconds, applied to each request sent to the
/// cache server (used when the builder's `request_timeout` is not set).
pub const REQUEST_TIMEOUT_MS_DEFAULT: u64 = 1_000;
40
41#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
45pub enum ClientKind {
46 Local,
50 Remote,
54}
55
56#[derive(Debug)]
57struct ClientInner {
58 kind: ClientKind,
59 url: String,
60 poll_backoff: ExponentialBackoff,
61 connection_timeout: Duration,
62 request_timeout: Duration,
63 handle: Handle,
64 #[allow(dead_code)]
66 runtime: Option<Runtime>,
67}
68
69#[derive(Debug, Clone)]
74pub struct Client {
75 inner: Arc<ClientInner>,
76}
77
78#[derive(Default, Clone, Debug)]
80pub struct ClientBuilder {
81 kind: Option<ClientKind>,
82 url: Option<String>,
83 poll_backoff: Option<ExponentialBackoff>,
84 connection_timeout: Option<Duration>,
85 request_timeout: Option<Duration>,
86 handle: Option<Handle>,
87}
88
89struct GenerateState<K, V> {
90 handle: CacheHandleInner<V>,
91 namespace: Namespace,
92 hash: Vec<u8>,
93 key: K,
94}
95
96trait HeartbeatFn: Fn(&Client) -> Result<()> + Send + Any {}
98impl<T: Fn(&Client) -> Result<()> + Send + Any> HeartbeatFn for T {}
99
100trait LocalWriteValueFn<V>:
103 FnOnce(&Client, u64, String, &ArcResult<V>) -> Result<()> + Send + Any
104{
105}
106impl<V, T: FnOnce(&Client, u64, String, &ArcResult<V>) -> Result<()> + Send + Any>
107 LocalWriteValueFn<V> for T
108{
109}
110
111trait RemoteWriteValueFn<V>: FnOnce(&Client, u64, &ArcResult<V>) -> Result<()> + Send + Any {}
114impl<V, T: FnOnce(&Client, u64, &ArcResult<V>) -> Result<()> + Send + Any> RemoteWriteValueFn<V>
115 for T
116{
117}
118
119trait DeserializeValueFn<V>: FnOnce(&[u8]) -> Result<V> + Send + Any {}
122impl<V, T: FnOnce(&[u8]) -> Result<V> + Send + Any> DeserializeValueFn<V> for T {}
123
124impl ClientBuilder {
125 pub fn new() -> Self {
127 Self::default()
128 }
129
130 pub fn url(&mut self, url: impl Into<String>) -> &mut Self {
132 self.url = Some(url.into());
133 self
134 }
135
136 pub fn kind(&mut self, kind: ClientKind) -> &mut Self {
138 self.kind = Some(kind);
139 self
140 }
141 pub fn local(url: impl Into<String>) -> Self {
144 let mut builder = Self::new();
145 builder.kind(ClientKind::Local).url(url);
146 builder
147 }
148
149 pub fn remote(url: impl Into<String>) -> Self {
152 let mut builder = Self::new();
153 builder.kind(ClientKind::Remote).url(url);
154 builder
155 }
156
157 pub fn poll_backoff(&mut self, backoff: ExponentialBackoff) -> &mut Self {
162 self.poll_backoff = Some(backoff);
163 self
164 }
165
166 pub fn connection_timeout(&mut self, timeout: Duration) -> &mut Self {
170 self.connection_timeout = Some(timeout);
171 self
172 }
173
174 pub fn request_timeout(&mut self, timeout: Duration) -> &mut Self {
178 self.request_timeout = Some(timeout);
179 self
180 }
181
182 pub fn runtime_handle(&mut self, handle: Handle) -> &mut Self {
187 self.handle = Some(handle);
188 self
189 }
190
191 pub fn build(&mut self) -> Client {
193 let (handle, runtime) = match self.handle.clone() {
194 Some(handle) => (handle, None),
195 None => {
196 let runtime = tokio::runtime::Builder::new_multi_thread()
197 .worker_threads(1)
198 .enable_all()
199 .build()
200 .unwrap();
201 (runtime.handle().clone(), Some(runtime))
202 }
203 };
204 Client {
205 inner: Arc::new(ClientInner {
206 kind: self.kind.expect("must specify client kind"),
207 url: self.url.clone().expect("must specify server URL"),
208 poll_backoff: self.poll_backoff.clone().unwrap_or_default(),
209 connection_timeout: self
210 .connection_timeout
211 .unwrap_or(Duration::from_millis(CONNECTION_TIMEOUT_MS_DEFAULT)),
212 request_timeout: self
213 .request_timeout
214 .unwrap_or(Duration::from_millis(REQUEST_TIMEOUT_MS_DEFAULT)),
215 handle,
216 runtime,
217 }),
218 }
219 }
220}
221
222impl Client {
223 pub fn with_default_config(kind: ClientKind, url: impl Into<String>) -> Self {
225 Self::builder().kind(kind).url(url).build()
226 }
227
228 pub fn builder() -> ClientBuilder {
230 ClientBuilder::new()
231 }
232
233 pub fn local(url: impl Into<String>) -> ClientBuilder {
237 ClientBuilder::local(url)
238 }
239
240 pub fn remote(url: impl Into<String>) -> ClientBuilder {
244 ClientBuilder::remote(url)
245 }
246
247 pub fn generate<
274 K: Serialize + Any + Send + Sync,
275 V: Serialize + DeserializeOwned + Send + Sync + Any,
276 >(
277 &self,
278 namespace: impl Into<Namespace>,
279 key: K,
280 generate_fn: impl GenerateFn<K, V>,
281 ) -> CacheHandle<V> {
282 CacheHandle::from_inner(Arc::new(self.generate_inner(namespace, key, generate_fn)))
283 }
284
285 pub(crate) fn generate_inner<
286 K: Serialize + Any + Send + Sync,
287 V: Serialize + DeserializeOwned + Send + Sync + Any,
288 >(
289 &self,
290 namespace: impl Into<Namespace>,
291 key: K,
292 generate_fn: impl GenerateFn<K, V>,
293 ) -> CacheHandleInner<V> {
294 let namespace = namespace.into();
295 let state = Client::setup_generate(namespace, key);
296 let handle = state.handle.clone();
297
298 match self.inner.kind {
299 ClientKind::Local => self.clone().generate_inner_local(state, generate_fn),
300 ClientKind::Remote => self.clone().generate_inner_remote(state, generate_fn),
301 }
302
303 handle
304 }
305
306 pub fn generate_with_state<
343 K: Serialize + Send + Sync + Any,
344 V: Serialize + DeserializeOwned + Send + Sync + Any,
345 S: Send + Sync + Any,
346 >(
347 &self,
348 namespace: impl Into<Namespace>,
349 key: K,
350 state: S,
351 generate_fn: impl GenerateWithStateFn<K, S, V>,
352 ) -> CacheHandle<V> {
353 let namespace = namespace.into();
354 self.generate(namespace, key, move |k| generate_fn(k, state))
355 }
356
357 pub fn generate_result<
392 K: Serialize + Any + Send + Sync,
393 V: Serialize + DeserializeOwned + Send + Sync + Any,
394 E: Send + Sync + Any,
395 >(
396 &self,
397 namespace: impl Into<Namespace>,
398 key: K,
399 generate_fn: impl GenerateResultFn<K, V, E>,
400 ) -> CacheHandle<std::result::Result<V, E>> {
401 CacheHandle::from_inner(Arc::new(self.generate_result_inner(
402 namespace,
403 key,
404 generate_fn,
405 )))
406 }
407
408 pub(crate) fn generate_result_inner<
409 K: Serialize + Any + Send + Sync,
410 V: Serialize + DeserializeOwned + Send + Sync + Any,
411 E: Send + Sync + Any,
412 >(
413 &self,
414 namespace: impl Into<Namespace>,
415 key: K,
416 generate_fn: impl GenerateResultFn<K, V, E>,
417 ) -> CacheHandleInner<std::result::Result<V, E>> {
418 let namespace = namespace.into();
419 let state = Client::setup_generate(namespace, key);
420 let handle = state.handle.clone();
421
422 match self.inner.kind {
423 ClientKind::Local => {
424 self.clone().generate_result_inner_local(state, generate_fn);
425 }
426 ClientKind::Remote => {
427 self.clone()
428 .generate_result_inner_remote(state, generate_fn);
429 }
430 }
431
432 handle
433 }
434
435 pub fn generate_result_with_state<
481 K: Serialize + Send + Sync + Any,
482 V: Serialize + DeserializeOwned + Send + Sync + Any,
483 E: Send + Sync + Any,
484 S: Send + Sync + Any,
485 >(
486 &self,
487 namespace: impl Into<Namespace>,
488 key: K,
489 state: S,
490 generate_fn: impl GenerateResultWithStateFn<K, S, V, E>,
491 ) -> CacheHandle<std::result::Result<V, E>> {
492 let namespace = namespace.into();
493 self.generate_result(namespace, key, move |k| generate_fn(k, state))
494 }
495
496 pub fn get<K: Cacheable>(
538 &self,
539 namespace: impl Into<Namespace>,
540 key: K,
541 ) -> CacheHandle<std::result::Result<K::Output, K::Error>> {
542 let namespace = namespace.into();
543 self.generate_result(namespace, key, |key| key.generate())
544 }
545
546 pub fn get_with_err<
589 E: Send + Sync + Serialize + DeserializeOwned + Any,
590 K: Cacheable<Error = E>,
591 >(
592 &self,
593 namespace: impl Into<Namespace>,
594 key: K,
595 ) -> CacheHandle<std::result::Result<K::Output, K::Error>> {
596 let namespace = namespace.into();
597 self.generate(namespace, key, |key| key.generate())
598 }
599
600 pub fn get_with_state<S: Send + Sync + Any, K: CacheableWithState<S>>(
649 &self,
650 namespace: impl Into<Namespace>,
651 key: K,
652 state: S,
653 ) -> CacheHandle<std::result::Result<K::Output, K::Error>> {
654 let namespace = namespace.into();
655 self.generate_result_with_state(namespace, key, state, |key, state| {
656 key.generate_with_state(state)
657 })
658 }
659
660 pub fn get_with_state_and_err<
666 S: Send + Sync + Any,
667 E: Send + Sync + Serialize + DeserializeOwned + Any,
668 K: CacheableWithState<S, Error = E>,
669 >(
670 &self,
671 namespace: impl Into<Namespace>,
672 key: K,
673 state: S,
674 ) -> CacheHandle<std::result::Result<K::Output, K::Error>> {
675 let namespace = namespace.into();
676 self.generate_with_state(namespace, key, state, |key, state| {
677 key.generate_with_state(state)
678 })
679 }
680
681 fn setup_generate<K: Serialize, V>(namespace: Namespace, key: K) -> GenerateState<K, V> {
683 GenerateState {
684 handle: CacheHandleInner::default(),
685 namespace,
686 hash: crate::hash(&flexbuffers::to_vec(&key).unwrap()),
687 key,
688 }
689 }
690
691 fn spawn_handler<V: Send + Sync + Any>(
695 self,
696 handle: CacheHandleInner<V>,
697 handler: impl FnOnce() -> Result<()> + Send + Any,
698 ) {
699 thread::spawn(move || {
700 if let Err(e) = handler() {
701 tracing::error!("encountered error while executing handler: {}", e,);
702 handle.set(Err(Arc::new(e)));
703 }
704 });
705 }
706
707 fn deserialize_cache_value<V: DeserializeOwned>(data: &[u8]) -> Result<V> {
709 let data = flexbuffers::from_slice(data)?;
710 Ok(data)
711 }
712
713 fn deserialize_cache_result<V: DeserializeOwned, E>(
715 data: &[u8],
716 ) -> Result<std::result::Result<V, E>> {
717 let data = flexbuffers::from_slice(data)?;
718 Ok(Ok(data))
719 }
720
721 fn start_heartbeats(
726 &self,
727 heartbeat_interval: Duration,
728 send_heartbeat: impl HeartbeatFn,
729 ) -> (Sender<()>, Receiver<()>) {
730 tracing::debug!("starting heartbeats");
731 let (s_heartbeat_stop, r_heartbeat_stop) = channel();
732 let (s_heartbeat_stopped, r_heartbeat_stopped) = channel();
733 let self_clone = self.clone();
734 thread::spawn(move || {
735 loop {
736 match r_heartbeat_stop.recv_timeout(heartbeat_interval) {
737 Ok(_) | Err(RecvTimeoutError::Disconnected) => {
738 break;
739 }
740 Err(RecvTimeoutError::Timeout) => {
741 if send_heartbeat(&self_clone).is_err() {
742 break;
743 }
744 }
745 }
746 }
747 let _ = s_heartbeat_stopped.send(());
748 });
749 (s_heartbeat_stop, r_heartbeat_stopped)
750 }
751
752 fn run_backoff_loop<S>(&self, get_status_fn: impl Fn() -> Result<(S, bool)>) -> Result<S> {
758 Ok(backoff::retry(self.inner.poll_backoff.clone(), move || {
759 tracing::debug!("attempting get request to retrieve entry status");
760 get_status_fn()
761 .map_err(backoff::Error::Permanent)
762 .and_then(|(status, retry)| {
763 if retry {
764 tracing::debug!("entry is loading, retrying later");
765 Err(backoff::Error::transient(Error::EntryLoading))
766 } else {
767 tracing::debug!("entry status retrieved");
768 Ok(status)
769 }
770 })
771 })
772 .map_err(Box::new)?)
773 }
774
775 fn handle_unassigned<K: Send + Sync + Any, V: Send + Sync + Any>(
777 handle: CacheHandleInner<V>,
778 key: K,
779 generate_fn: impl GenerateFn<K, V>,
780 ) {
781 tracing::debug!("entry is unassigned, generating locally");
782 let v = run_generator(move || generate_fn(&key));
783 handle.set(v);
784 }
785
786 fn handle_assigned<K: Send + Sync + Any, V: Send + Sync + Any>(
789 &self,
790 key: K,
791 generate_fn: impl GenerateFn<K, V>,
792 heartbeat_interval_ms: u64,
793 send_heartbeat: impl HeartbeatFn,
794 ) -> ArcResult<V> {
795 tracing::debug!("entry has been assigned to the client, generating locally");
796 let (s_heartbeat_stop, r_heartbeat_stopped) =
797 self.start_heartbeats(Duration::from_millis(heartbeat_interval_ms), send_heartbeat);
798 let v = run_generator(move || generate_fn(&key));
799 let _ = s_heartbeat_stop.send(());
800 let _ = r_heartbeat_stopped.recv();
801 tracing::debug!("finished generating, writing value to cache");
802 v
803 }
804
805 async fn connect_local(&self) -> Result<local_cache_client::LocalCacheClient<Channel>> {
807 let endpoint = Endpoint::from_shared(self.inner.url.clone())?
808 .timeout(self.inner.request_timeout)
809 .connect_timeout(self.inner.connection_timeout);
810 let test = local_cache_client::LocalCacheClient::connect(endpoint).await;
811 Ok(test?)
812 }
813
814 fn get_rpc_local(
816 &self,
817 namespace: String,
818 key: Vec<u8>,
819 assign: bool,
820 ) -> Result<local::get_reply::EntryStatus> {
821 let out: Result<local::GetReply> = self.inner.handle.block_on(async {
822 let mut client = self.connect_local().await?;
823 Ok(client
824 .get(local::GetRequest {
825 namespace,
826 key,
827 assign,
828 })
829 .await
830 .map_err(Box::new)?
831 .into_inner())
832 });
833 Ok(out?.entry_status.unwrap())
834 }
835
836 fn heartbeat_rpc_local(&self, id: u64) -> Result<()> {
838 self.inner.handle.block_on(async {
839 let mut client = self.connect_local().await?;
840 client
841 .heartbeat(local::HeartbeatRequest { id })
842 .await
843 .map_err(Box::new)?;
844 Ok(())
845 })
846 }
847
848 fn done_rpc_local(&self, id: u64) -> Result<()> {
850 self.inner.handle.block_on(async {
851 let mut client = self.connect_local().await?;
852 client
853 .done(local::DoneRequest { id })
854 .await
855 .map_err(Box::new)?;
856 Ok(())
857 })
858 }
859
860 fn drop_rpc_local(&self, id: u64) -> Result<()> {
862 self.inner.handle.block_on(async {
863 let mut client = self.connect_local().await?;
864 client
865 .drop(local::DropRequest { id })
866 .await
867 .map_err(Box::new)?;
868 Ok(())
869 })
870 }
871
872 fn write_generated_data_to_disk<V: Serialize>(
873 &self,
874 id: u64,
875 path: String,
876 data: &V,
877 ) -> Result<()> {
878 let path = PathBuf::from(path);
879 if let Some(parent) = path.parent() {
880 fs::create_dir_all(parent)?;
881 }
882
883 let mut f = OpenOptions::new()
884 .read(true)
885 .write(true)
886 .create(true)
887 .truncate(true)
888 .open(&path)?;
889 f.write_all(&flexbuffers::to_vec(data).unwrap())?;
890 self.done_rpc_local(id)?;
891
892 Ok(())
893 }
894
895 fn write_generated_value_local<V: Serialize>(
897 &self,
898 id: u64,
899 path: String,
900 value: &ArcResult<V>,
901 ) -> Result<()> {
902 if let Ok(data) = value {
903 self.write_generated_data_to_disk(id, path, data)?;
904 }
905 Ok(())
906 }
907
908 fn write_generated_result_local<V: Serialize, E>(
912 &self,
913 id: u64,
914 path: String,
915 value: &ArcResult<std::result::Result<V, E>>,
916 ) -> Result<()> {
917 if let Ok(Ok(data)) = value {
918 self.write_generated_data_to_disk(id, path, data)?;
919 }
920 Ok(())
921 }
922
923 fn generate_loop_local<K: Send + Sync + Any, V: Send + Sync + Any>(
926 &self,
927 state: GenerateState<K, V>,
928 generate_fn: impl GenerateFn<K, V>,
929 write_generated_value: impl LocalWriteValueFn<V>,
930 deserialize_cache_data: impl DeserializeValueFn<V>,
931 ) -> Result<()> {
932 let GenerateState {
933 handle,
934 namespace,
935 hash,
936 key,
937 } = state;
938
939 let status = self.run_backoff_loop(|| {
940 let status = self.get_rpc_local(namespace.clone().into_inner(), hash.clone(), true)?;
941 let retry = matches!(status, local::get_reply::EntryStatus::Loading(_));
942
943 Ok((status, retry))
944 })?;
945
946 match status {
947 local::get_reply::EntryStatus::Unassigned(_) => {
948 Client::handle_unassigned(handle, key, generate_fn);
949 }
950 local::get_reply::EntryStatus::Assign(local::AssignReply {
951 id,
952 path,
953 heartbeat_interval_ms,
954 }) => {
955 let v = self.handle_assigned(
956 key,
957 generate_fn,
958 heartbeat_interval_ms,
959 move |client| -> Result<()> { client.heartbeat_rpc_local(id) },
960 );
961 write_generated_value(self, id, path, &v)?;
962 handle.set(v);
963 }
964 local::get_reply::EntryStatus::Loading(_) => unreachable!(),
965 local::get_reply::EntryStatus::Ready(local::ReadyReply { id, path }) => {
966 tracing::debug!("entry is ready, reading from cache");
967 let mut file = std::fs::File::open(path)?;
968 let mut buf = Vec::new();
969 file.read_to_end(&mut buf)?;
970 self.drop_rpc_local(id)?;
971 tracing::debug!("finished reading entry from disk");
972 handle.set(Ok(deserialize_cache_data(&buf)?));
973 }
974 }
975 Ok(())
976 }
977
978 fn generate_inner_local<
979 K: Serialize + Any + Send + Sync,
980 V: Serialize + DeserializeOwned + Send + Sync + Any,
981 >(
982 self,
983 state: GenerateState<K, V>,
984 generate_fn: impl GenerateFn<K, V>,
985 ) {
986 tracing::debug!("generating using local cache API");
987 self.clone().spawn_handler(state.handle.clone(), move || {
988 self.generate_loop_local(
989 state,
990 generate_fn,
991 Client::write_generated_value_local,
992 Client::deserialize_cache_value,
993 )
994 });
995 }
996
997 fn generate_result_inner_local<
998 K: Serialize + Any + Send + Sync,
999 V: Serialize + DeserializeOwned + Send + Sync + Any,
1000 E: Send + Sync + Any,
1001 >(
1002 self,
1003 state: GenerateState<K, std::result::Result<V, E>>,
1004 generate_fn: impl GenerateResultFn<K, V, E>,
1005 ) {
1006 self.clone().spawn_handler(state.handle.clone(), move || {
1007 self.generate_loop_local(
1008 state,
1009 generate_fn,
1010 Client::write_generated_result_local,
1011 Client::deserialize_cache_result,
1012 )
1013 });
1014 }
1015
1016 async fn connect_remote(&self) -> Result<remote_cache_client::RemoteCacheClient<Channel>> {
1018 let endpoint = Endpoint::from_shared(self.inner.url.clone())?
1019 .timeout(self.inner.request_timeout)
1020 .connect_timeout(self.inner.connection_timeout);
1021 Ok(remote_cache_client::RemoteCacheClient::connect(endpoint).await?)
1022 }
1023
1024 fn get_rpc_remote(
1026 &self,
1027 namespace: String,
1028 key: Vec<u8>,
1029 assign: bool,
1030 ) -> Result<remote::get_reply::EntryStatus> {
1031 let out: Result<remote::GetReply> = self.inner.handle.block_on(async {
1032 let mut client = self.connect_remote().await?;
1033 Ok(client
1034 .get(remote::GetRequest {
1035 namespace,
1036 key,
1037 assign,
1038 })
1039 .await
1040 .map_err(Box::new)?
1041 .into_inner())
1042 });
1043 Ok(out?.entry_status.unwrap())
1044 }
1045
1046 fn heartbeat_rpc_remote(&self, id: u64) -> Result<()> {
1048 self.inner.handle.block_on(async {
1049 let mut client = self.connect_remote().await?;
1050 client
1051 .heartbeat(remote::HeartbeatRequest { id })
1052 .await
1053 .map_err(Box::new)?;
1054 Ok(())
1055 })
1056 }
1057
1058 fn set_rpc_remote(&self, id: u64, value: Vec<u8>) -> Result<()> {
1060 self.inner.handle.block_on(async {
1061 let mut client = self.connect_remote().await?;
1062 client
1063 .set(remote::SetRequest { id, value })
1064 .await
1065 .map_err(Box::new)?;
1066 Ok(())
1067 })
1068 }
1069
1070 fn write_generated_value_remote<V: Serialize>(
1072 &self,
1073 id: u64,
1074 value: &ArcResult<V>,
1075 ) -> Result<()> {
1076 if let Ok(data) = value {
1077 self.set_rpc_remote(id, flexbuffers::to_vec(data).unwrap())?;
1078 }
1079 Ok(())
1080 }
1081
1082 fn write_generated_result_remote<V: Serialize, E>(
1086 &self,
1087 id: u64,
1088 value: &ArcResult<std::result::Result<V, E>>,
1089 ) -> Result<()> {
1090 if let Ok(Ok(data)) = value {
1091 self.set_rpc_remote(id, flexbuffers::to_vec(data).unwrap())?;
1092 }
1093 Ok(())
1094 }
1095
1096 fn generate_loop_remote<K: Send + Sync + Any, V: Send + Sync + Any>(
1099 &self,
1100 state: GenerateState<K, V>,
1101 generate_fn: impl GenerateFn<K, V>,
1102 write_generated_value: impl RemoteWriteValueFn<V>,
1103 deserialize_cache_data: impl DeserializeValueFn<V>,
1104 ) -> Result<()> {
1105 let GenerateState {
1106 handle,
1107 namespace,
1108 hash,
1109 key,
1110 } = state;
1111
1112 let status = self.run_backoff_loop(|| {
1113 let status = self.get_rpc_remote(namespace.clone().into_inner(), hash.clone(), true)?;
1114 let retry = matches!(status, remote::get_reply::EntryStatus::Loading(_));
1115
1116 Ok((status, retry))
1117 })?;
1118
1119 match status {
1120 remote::get_reply::EntryStatus::Unassigned(_) => {
1121 Client::handle_unassigned(handle, key, generate_fn);
1122 }
1123 remote::get_reply::EntryStatus::Assign(remote::AssignReply {
1124 id,
1125 heartbeat_interval_ms,
1126 }) => {
1127 let v = self.handle_assigned(
1128 key,
1129 generate_fn,
1130 heartbeat_interval_ms,
1131 move |client| -> Result<()> { client.heartbeat_rpc_remote(id) },
1132 );
1133 write_generated_value(self, id, &v)?;
1134 handle.set(v);
1135 }
1136 remote::get_reply::EntryStatus::Loading(_) => unreachable!(),
1137 remote::get_reply::EntryStatus::Ready(data) => {
1138 tracing::debug!("entry is ready");
1139 handle.set(Ok(deserialize_cache_data(&data)?));
1140 }
1141 }
1142 Ok(())
1143 }
1144
1145 fn generate_inner_remote<
1146 K: Serialize + Any + Send + Sync,
1147 V: Serialize + DeserializeOwned + Send + Sync + Any,
1148 >(
1149 self,
1150 state: GenerateState<K, V>,
1151 generate_fn: impl GenerateFn<K, V>,
1152 ) {
1153 tracing::debug!("generating using remote cache API");
1154 self.clone().spawn_handler(state.handle.clone(), move || {
1155 self.generate_loop_remote(
1156 state,
1157 generate_fn,
1158 Client::write_generated_value_remote,
1159 Client::deserialize_cache_value,
1160 )
1161 });
1162 }
1163
1164 fn generate_result_inner_remote<
1165 K: Serialize + Any + Send + Sync,
1166 V: Serialize + DeserializeOwned + Send + Sync + Any,
1167 E: Send + Sync + Any,
1168 >(
1169 self,
1170 state: GenerateState<K, std::result::Result<V, E>>,
1171 generate_fn: impl GenerateResultFn<K, V, E>,
1172 ) {
1173 self.clone().spawn_handler(state.handle.clone(), move || {
1174 self.generate_loop_remote(
1175 state,
1176 generate_fn,
1177 Client::write_generated_result_remote,
1178 Client::deserialize_cache_result,
1179 )
1180 });
1181 }
1182}
1183
1184pub(crate) const BUILD_DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/build");
/// Heartbeat interval configured on servers built by `create_server_and_clients`.
pub(crate) const TEST_SERVER_HEARTBEAT_INTERVAL: Duration = Duration::from_millis(200);
/// Heartbeat timeout configured on servers built by `create_server_and_clients`.
pub(crate) const TEST_SERVER_HEARTBEAT_TIMEOUT: Duration = Duration::from_millis(500);
1187
/// Binds `n` TCP listeners to OS-chosen ports on `127.0.0.1` and returns each
/// listener together with its port.
///
/// # Panics
///
/// Panics if a bind fails or a bound address cannot be queried.
pub(crate) fn get_listeners(n: usize) -> Vec<(TcpListener, u16)> {
    (0..n)
        .map(|_| {
            let socket = TcpListener::bind("127.0.0.1:0").unwrap();
            let bound_port = socket.local_addr().unwrap().port();
            (socket, bound_port)
        })
        .collect()
}
1199
/// Which cache APIs the server built by `create_server_and_clients` exposes.
#[doc(hidden)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ServerKind {
    /// Serve only the local cache API.
    Local,
    /// Serve only the remote cache API.
    Remote,
    /// Serve both APIs, each on its own listener.
    Both,
}
1207
1208impl From<ClientKind> for ServerKind {
1209 fn from(value: ClientKind) -> Self {
1210 match value {
1211 ClientKind::Local => ServerKind::Local,
1212 ClientKind::Remote => ServerKind::Remote,
1213 }
1214 }
1215}
1216
/// Returns the plain-HTTP URL of a server listening on `port` on the IPv4
/// loopback address.
pub(crate) fn client_url(port: u16) -> String {
    const LOOPBACK: &str = "127.0.0.1";
    format!("http://{LOOPBACK}:{port}")
}
1220
1221#[doc(hidden)]
1222pub fn create_server_and_clients(
1223 root: PathBuf,
1224 kind: ServerKind,
1225 handle: &Handle,
1226) -> (CacheHandle<Result<()>>, Client, Client) {
1227 let mut listeners = handle.block_on(async {
1228 get_listeners(2)
1229 .into_iter()
1230 .map(|(listener, port)| {
1231 listener.set_nonblocking(true).unwrap();
1232 (tokio::net::TcpListener::from_std(listener).unwrap(), port)
1233 })
1234 .collect::<Vec<_>>()
1235 });
1236 let (local_listener, local_port) = listeners.pop().unwrap();
1237 let (remote_listener, remote_port) = listeners.pop().unwrap();
1238
1239 (
1240 {
1241 let mut builder = Server::builder();
1242
1243 builder = builder
1244 .heartbeat_interval(TEST_SERVER_HEARTBEAT_INTERVAL)
1245 .heartbeat_timeout(TEST_SERVER_HEARTBEAT_TIMEOUT)
1246 .root(root);
1247
1248 let server = match kind {
1249 ServerKind::Local => builder.local_with_incoming(local_listener),
1250 ServerKind::Remote => builder.remote_with_incoming(remote_listener),
1251 ServerKind::Both => builder
1252 .local_with_incoming(local_listener)
1253 .remote_with_incoming(remote_listener),
1254 }
1255 .build();
1256
1257 let join_handle = handle.spawn(async move { server.start().await });
1258 let handle_clone = handle.clone();
1259 CacheHandle::new(move || {
1260 let res = handle_clone.block_on(join_handle).unwrap_or_else(|res| {
1261 if res.is_cancelled() {
1262 Ok(())
1263 } else {
1264 Err(Error::Panic)
1265 }
1266 });
1267 if let Err(e) = res.as_ref() {
1268 tracing::error!("server failed to start: {:?}", e);
1269 }
1270 res
1271 })
1272 },
1273 Client::builder()
1274 .kind(ClientKind::Local)
1275 .url(client_url(local_port))
1276 .connection_timeout(Duration::from_secs(3))
1277 .request_timeout(Duration::from_secs(3))
1278 .build(),
1279 Client::builder()
1280 .kind(ClientKind::Remote)
1281 .url(client_url(remote_port))
1282 .connection_timeout(Duration::from_secs(3))
1283 .request_timeout(Duration::from_secs(3))
1284 .build(),
1285 )
1286}
1287
1288pub(crate) fn reset_directory(path: impl AsRef<Path>) -> Result<()> {
1289 let path = path.as_ref();
1290 if path.exists() {
1291 fs::remove_dir_all(path)?;
1292 }
1293 fs::create_dir_all(path)?;
1294 Ok(())
1295}
1296
1297pub(crate) fn create_runtime() -> Runtime {
1298 tokio::runtime::Builder::new_multi_thread()
1299 .worker_threads(1)
1300 .enable_all()
1301 .build()
1302 .unwrap()
1303}
1304
1305#[doc(hidden)]
1306pub fn setup_test(test_name: &str) -> Result<(PathBuf, Arc<Mutex<u64>>, Runtime)> {
1307 let path = PathBuf::from(BUILD_DIR).join(test_name);
1308 reset_directory(&path)?;
1309 Ok((path, Arc::new(Mutex::new(0)), create_runtime()))
1310}