cache/persistent/
server.rs

1//! A persistent cache gRPC server.
2
3use std::collections::HashMap;
4use std::collections::hash_map::Entry;
5use std::path::Path;
6use std::sync::Arc;
7use std::time::Duration;
8use std::{net::SocketAddr, path::PathBuf};
9
10use fs4::tokio::AsyncFileExt;
11use path_absolutize::Absolutize;
12use serde::{Deserialize, Serialize};
13use tokio::fs::{self, File, OpenOptions};
14use tokio::io::{AsyncReadExt, AsyncWriteExt};
15use tokio::net::TcpListener;
16use tokio::sync::Mutex;
17use tokio::time::Instant;
18use tokio_rusqlite::Connection;
19use tonic::Response;
20
21use crate::Namespace;
22use crate::error::Result;
23use crate::rpc::local::{
24    self,
25    local_cache_server::{LocalCache, LocalCacheServer},
26};
27use crate::rpc::remote::{
28    self,
29    remote_cache_server::{RemoteCache, RemoteCacheServer},
30};
31
/// File name of the TOML manifest, inside the cache root, that records the running server's
/// configuration.
pub const CONFIG_MANIFEST_NAME: &str = "Cache.toml";

/// File name of the SQLite database, inside the cache root, that persists entry statuses.
pub const MANIFEST_DB_NAME: &str = "cache.sqlite";

/// Default number of seconds that assigned clients are told to wait between heartbeats.
pub const HEARTBEAT_INTERVAL_SECS_DEFAULT: u64 = 2;

/// Default number of seconds without a heartbeat after which an assigned task is assumed to
/// have failed.
///
/// Defined as two seconds more than [`HEARTBEAT_INTERVAL_SECS_DEFAULT`] so that the defaults
/// satisfy the `interval < timeout` requirement checked by [`ServerBuilder::build`].
pub const HEARTBEAT_TIMEOUT_SECS_DEFAULT: u64 = HEARTBEAT_INTERVAL_SECS_DEFAULT + 2;
43
/// Creates the `manifest` table if it does not already exist.
///
/// `namespace` is declared `TEXT`, not `STRING`: SQLite gives a declared type of `STRING`
/// NUMERIC affinity, so a numeric-looking namespace such as `"123"` would be stored as an
/// integer and could no longer be read back as a string when the manifest is reloaded.
/// Because of `IF NOT EXISTS`, tables created by earlier versions are left unchanged.
const CREATE_MANIFEST_TABLE_STMT: &str = r#"
    CREATE TABLE IF NOT EXISTS manifest (
        namespace TEXT NOT NULL,
        key BLOB NOT NULL,
        status INTEGER NOT NULL,
        PRIMARY KEY (namespace, key)
    );
"#;

/// Reads every manifest row as `(namespace, key, status)`.
const READ_MANIFEST_STMT: &str = r#"
    SELECT namespace, key, status FROM manifest;
"#;

/// Deletes every manifest row with the given status. Parameter: `status`.
const DELETE_ENTRIES_WITH_STATUS_STMT: &str = r#"
    DELETE FROM manifest WHERE status = ?;
"#;

/// Inserts a manifest row. Parameters: `(namespace, key, status)`.
const INSERT_STATUS_STMT: &str = r#"
    INSERT INTO manifest (namespace, key, status) VALUES (?, ?, ?);
"#;

/// Updates the status of one manifest row. Parameters: `(status, namespace, key)`.
const UPDATE_STATUS_STMT: &str = r#"
    UPDATE manifest SET status = ? WHERE namespace = ? AND key = ?;
"#;

/// Deletes one manifest row. Parameters: `(namespace, key)`.
const DELETE_STATUS_STMT: &str = r#"
    DELETE FROM manifest WHERE namespace = ? AND key = ?;
"#;
72
73/// A gRPC cache server.
74#[derive(Debug)]
75pub struct Server {
76    root: Arc<PathBuf>,
77    local: Option<TcpListener>,
78    remote: Option<TcpListener>,
79    heartbeat_interval: Duration,
80    heartbeat_timeout: Duration,
81}
82
83/// A builder for a gRPC cache server.
84#[derive(Default, Debug)]
85pub struct ServerBuilder {
86    root: Option<Arc<PathBuf>>,
87    local: Option<TcpListener>,
88    remote: Option<TcpListener>,
89    heartbeat_interval: Option<Duration>,
90    heartbeat_timeout: Option<Duration>,
91}
92
93#[derive(Serialize, Deserialize, Copy, Clone, Debug)]
94pub(crate) struct ConfigManifest {
95    pub(crate) local_addr: Option<SocketAddr>,
96    pub(crate) remote_addr: Option<SocketAddr>,
97    pub(crate) heartbeat_interval: Duration,
98    pub(crate) heartbeat_timeout: Duration,
99}
100
101impl ServerBuilder {
102    /// Creates a new [`ServerBuilder`].
103    pub fn new() -> Self {
104        Self::default()
105    }
106
107    /// Sets the root directory of the cache server.
108    pub fn root(mut self, path: PathBuf) -> Self {
109        self.root = Some(Arc::new(path));
110        self
111    }
112
113    /// Configures the local cache gRPC server.
114    ///
115    /// Returns an error if the provided address cannot be bound.
116    pub async fn local(mut self, addr: SocketAddr) -> std::io::Result<Self> {
117        self.local = Some(TcpListener::bind(addr).await?);
118        Ok(self)
119    }
120
121    /// Configures the remote cache gRPC server.
122    ///
123    /// Returns an error if the provided address cannot be bound.
124    pub async fn remote(mut self, addr: SocketAddr) -> std::io::Result<Self> {
125        self.remote = Some(TcpListener::bind(addr).await?);
126        Ok(self)
127    }
128
129    /// Configures the local cache gRPC server to use the provided [`TcpListener`].
130    pub fn local_with_incoming(mut self, incoming: TcpListener) -> Self {
131        self.local = Some(incoming);
132        self
133    }
134
135    /// Configures the remote cache gRPC server to use the provided [`TcpListener`].
136    pub fn remote_with_incoming(mut self, incoming: TcpListener) -> Self {
137        self.remote = Some(incoming);
138        self
139    }
140
141    /// Sets the expected interval between hearbeats.
142    ///
143    /// Defaults to [`HEARTBEAT_INTERVAL_SECS_DEFAULT`].
144    pub fn heartbeat_interval(mut self, duration: Duration) -> Self {
145        self.heartbeat_interval = Some(duration);
146        self
147    }
148
149    /// Sets the timeout before an assigned task is marked for reassignment.
150    ///
151    /// Defaults to [`HEARTBEAT_TIMEOUT_SECS_DEFAULT`].
152    pub fn heartbeat_timeout(mut self, duration: Duration) -> Self {
153        self.heartbeat_timeout = Some(duration);
154        self
155    }
156
157    /// Builds a [`Server`] from the configured options.
158    pub fn build(self) -> Server {
159        let server = Server {
160            root: self.root.clone().unwrap(),
161            local: self.local,
162            remote: self.remote,
163            heartbeat_interval: self
164                .heartbeat_interval
165                .unwrap_or(Duration::from_secs(HEARTBEAT_INTERVAL_SECS_DEFAULT)),
166            heartbeat_timeout: self
167                .heartbeat_timeout
168                .unwrap_or(Duration::from_secs(HEARTBEAT_TIMEOUT_SECS_DEFAULT)),
169        };
170
171        assert!(
172            server.heartbeat_interval < server.heartbeat_timeout,
173            "heartbeat interval must be less than the heartbeat interval"
174        );
175
176        assert_eq!(
177            server.heartbeat_interval.subsec_micros() % 1000,
178            0,
179            "heartbeat interval cannot have finer than millisecond resolution"
180        );
181
182        server
183    }
184}
185
186impl Server {
187    /// Creates a new [`ServerBuilder`] object.
188    pub fn builder() -> ServerBuilder {
189        ServerBuilder::new()
190    }
191
192    /// Starts the gRPC server, listening on the configured address.
193    pub async fn start(self) -> Result<()> {
194        if let (None, None) = (&self.local, &self.remote) {
195            tracing::warn!("no local or remote listener specified so no server is being run");
196            return Ok(());
197        }
198
199        // Write configuration options to the config manifest.
200        let mut config_manifest = OpenOptions::new()
201            .read(true)
202            .write(true)
203            .create(true)
204            .truncate(true)
205            .open(self.root.join(CONFIG_MANIFEST_NAME))
206            .await?;
207        config_manifest.try_lock_exclusive()?;
208        config_manifest
209            .write_all(
210                &toml::to_string(&ConfigManifest {
211                    local_addr: self
212                        .local
213                        .as_ref()
214                        .map(|value| value.local_addr())
215                        .map_or(Ok(None), |v| v.map(Some))?,
216                    remote_addr: self
217                        .remote
218                        .as_ref()
219                        .map(|value| value.local_addr())
220                        .map_or(Ok(None), |v| v.map(Some))?,
221                    heartbeat_interval: self.heartbeat_interval,
222                    heartbeat_timeout: self.heartbeat_timeout,
223                })
224                .unwrap()
225                .into_bytes(),
226            )
227            .await?;
228
229        let db_path = self.root.join(MANIFEST_DB_NAME);
230        let inner = Arc::new(Mutex::new(CacheInner::new(&db_path).await?));
231
232        let imp = CacheImpl::new(
233            self.root.clone(),
234            self.heartbeat_interval,
235            self.heartbeat_timeout,
236            inner,
237        );
238
239        let Server { local, remote, .. } = self;
240
241        let local_handle = if let Some(local) = local {
242            tracing::debug!("local server listening on address {}", local.local_addr()?);
243            let local_svc = LocalCacheServer::new(imp.clone());
244            Some(tokio::spawn(
245                tonic::transport::Server::builder()
246                    .add_service(local_svc)
247                    .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(local)),
248            ))
249        } else {
250            None
251        };
252        let remote_handle = if let Some(remote) = remote {
253            tracing::debug!(
254                "remote server listening on address {}",
255                remote.local_addr()?
256            );
257            let remote_svc = RemoteCacheServer::new(imp);
258            Some(tokio::spawn(
259                tonic::transport::Server::builder()
260                    .add_service(remote_svc)
261                    .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(remote)),
262            ))
263        } else {
264            None
265        };
266
267        if let Some(local_handle) = local_handle {
268            local_handle.await??;
269        }
270
271        if let Some(remote_handle) = remote_handle {
272            remote_handle.await??;
273        }
274
275        // Hold file lock until server terminates.
276        drop(config_manifest);
277
278        Ok(())
279    }
280}
281
282/// Cache state.
283#[derive(Clone, Debug)]
284struct CacheInner {
285    next_assignment_id: AssignmentId,
286    next_handle_id: HandleId,
287    /// Status of entries currently in the cache.
288    entry_status: HashMap<Arc<EntryKey>, EntryStatus>,
289    /// Status of entries that are currently loading.
290    loading: HashMap<AssignmentId, LoadingData>,
291    /// Status of entries that have active handles.
292    handles: HashMap<HandleId, Arc<EntryKey>>,
293    /// A wrapper around a [`tokio_rusqlite::Connection`].
294    conn: CacheInnerConn,
295}
296
297impl CacheInner {
298    async fn new(db_path: impl AsRef<Path>) -> Result<Self> {
299        tracing::debug!("connecting to manifest database");
300        // Set up the manifest database.
301        let conn = Connection::open(db_path.as_ref()).await?;
302        conn.call(|conn| {
303            let tx = conn.transaction()?;
304            tx.execute(CREATE_MANIFEST_TABLE_STMT, ())?;
305            tx.commit()?;
306            tracing::debug!("ensured that manifest table has been created");
307            Ok(())
308        })
309        .await?;
310
311        let mut cache = Self {
312            next_assignment_id: AssignmentId(0),
313            next_handle_id: HandleId(0),
314            entry_status: HashMap::new(),
315            loading: HashMap::new(),
316            handles: HashMap::new(),
317            conn: CacheInnerConn(conn),
318        };
319
320        // Load persisted state.
321        cache.load_from_disk().await?;
322
323        Ok(cache)
324    }
325
326    async fn load_from_disk(&mut self) -> Result<()> {
327        tracing::debug!("loading cache state from disk");
328        let rows = self
329            .conn
330            .0
331            .call(|conn| {
332                let tx = conn.transaction()?;
333
334                // Delete loading entries as we cannot recover assignment IDs on restart.
335                tracing::debug!("deleting loading entries from database");
336                let mut stmt = tx.prepare(DELETE_ENTRIES_WITH_STATUS_STMT)?;
337                stmt.execute([DbEntryStatus::Loading.to_int()])?;
338                drop(stmt);
339
340                // Read remaining rows from the manifest, converting them into tuples mapping
341                // `EntryKey` to a `DbEntryStatus`.
342                tracing::debug!("reading remaining entries from database");
343                let mut stmt = tx.prepare(READ_MANIFEST_STMT)?;
344                let rows = stmt.query_map(
345                    [],
346                    |row| -> rusqlite::Result<(Arc<EntryKey>, DbEntryStatus)> {
347                        Ok((
348                            Arc::new(EntryKey {
349                                namespace: Namespace::new(row.get::<_, String>(0)?),
350                                key: row.get(1)?,
351                            }),
352                            DbEntryStatus::from_int(row.get(2)?).unwrap(),
353                        ))
354                    },
355                )?;
356                let res = Ok(rows.collect::<Vec<_>>());
357                drop(stmt);
358
359                tx.commit()?;
360                res
361            })
362            .await?
363            .into_iter()
364            .map(|res| res.map_err(|e| e.into()))
365            .collect::<std::result::Result<Vec<_>, tokio_rusqlite::Error>>()?;
366
367        // Map database entries into in-memory cache state.
368        self.entry_status = HashMap::from_iter(rows.into_iter().filter_map(|v| {
369            Some((
370                v.0,
371                match v.1 {
372                    DbEntryStatus::Loading => None,
373                    DbEntryStatus::Ready => Some(EntryStatus::Ready(0)),
374                    DbEntryStatus::Evicting => Some(EntryStatus::Evicting),
375                }?,
376            ))
377        }));
378
379        Ok(())
380    }
381}
382
383#[derive(Clone, Debug)]
384struct CacheInnerConn(Connection);
385
386impl CacheInnerConn {
387    async fn insert_status(&self, key: Arc<EntryKey>, status: DbEntryStatus) -> Result<()> {
388        self.0
389            .call(move |conn| {
390                let mut stmt = conn.prepare(INSERT_STATUS_STMT)?;
391                stmt.execute((
392                    key.namespace.clone().into_inner(),
393                    key.key.clone(),
394                    status.to_int(),
395                ))?;
396                Ok(())
397            })
398            .await?;
399        Ok(())
400    }
401
402    async fn update_status(&self, key: Arc<EntryKey>, status: DbEntryStatus) -> Result<()> {
403        self.0
404            .call(move |conn| {
405                let mut stmt = conn.prepare(UPDATE_STATUS_STMT)?;
406                stmt.execute((
407                    status.to_int(),
408                    key.namespace.clone().into_inner(),
409                    key.key.clone(),
410                ))?;
411                Ok(())
412            })
413            .await?;
414        Ok(())
415    }
416
417    async fn delete_status(&self, key: Arc<EntryKey>) -> Result<()> {
418        self.0
419            .call(move |conn| {
420                let mut stmt = conn.prepare(DELETE_STATUS_STMT)?;
421                stmt.execute((key.namespace.clone().into_inner(), key.key.clone()))?;
422                Ok(())
423            })
424            .await?;
425        Ok(())
426    }
427}
428
/// Identifies a client that has been assigned to generate the value of an entry.
#[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq)]
struct AssignmentId(u64);

impl AssignmentId {
    /// Advances this ID to the next value in the sequence.
    fn increment(&mut self) {
        *self = Self(self.0 + 1);
    }
}
438
/// Identifies a handle that a local client holds on a ready entry.
#[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq)]
struct HandleId(u64);

impl HandleId {
    /// Advances this ID to the next value in the sequence.
    fn increment(&mut self) {
        *self = Self(self.0 + 1);
    }
}
448
449#[derive(Clone, Debug, Hash, PartialEq, Eq)]
450struct EntryKey {
451    namespace: Namespace,
452    key: Vec<u8>,
453}
454
455#[derive(Clone, Copy, Debug)]
456enum EntryStatus {
457    Loading(AssignmentId),
458    /// Number of local requests that are using this entry.
459    Ready(u64),
460    Evicting,
461}
462
/// Entry status as stored in the `status` column of the manifest database.
#[derive(Clone, Copy, Debug)]
enum DbEntryStatus {
    /// The entry is being generated.
    Loading,
    /// The entry's value is available.
    Ready,
    /// An entry that is marked for eviction.
    ///
    /// Currently unused: nothing in the server writes this status.
    Evicting,
}

impl DbEntryStatus {
    /// Returns the integer persisted in the database for this status.
    fn to_int(self) -> u64 {
        match self {
            DbEntryStatus::Loading => 0,
            DbEntryStatus::Ready => 1,
            DbEntryStatus::Evicting => 2,
        }
    }

    /// Decodes an integer produced by [`DbEntryStatus::to_int`], returning `None` for any
    /// other value.
    fn from_int(val: u64) -> Option<Self> {
        let status = match val {
            0 => DbEntryStatus::Loading,
            1 => DbEntryStatus::Ready,
            2 => DbEntryStatus::Evicting,
            _ => return None,
        };
        Some(status)
    }
}
491
492#[derive(Clone, Debug)]
493struct LoadingData {
494    last_heartbeat: Instant,
495    key: Arc<EntryKey>,
496}
497
498#[derive(Clone, Debug)]
499enum GetReplyStatus {
500    Unassigned,
501    Assign(AssignmentId, Duration),
502    Loading,
503    ReadyRemote(Vec<u8>),
504    ReadyLocal(HandleId),
505}
506
507impl GetReplyStatus {
508    fn into_local(self, path: String) -> local::get_reply::EntryStatus {
509        match self {
510            Self::Unassigned => local::get_reply::EntryStatus::Unassigned(()),
511            Self::Assign(id, heartbeat_interval) => {
512                local::get_reply::EntryStatus::Assign(local::AssignReply {
513                    id: id.0,
514                    path,
515                    heartbeat_interval_ms: heartbeat_interval.as_millis() as u64,
516                })
517            }
518            Self::Loading => local::get_reply::EntryStatus::Loading(()),
519            Self::ReadyLocal(id) => {
520                local::get_reply::EntryStatus::Ready(local::ReadyReply { id: id.0, path })
521            }
522            Self::ReadyRemote(_) => panic!("cannot convert remote statuses to local statuses"),
523        }
524    }
525    fn into_remote(self) -> remote::get_reply::EntryStatus {
526        match self {
527            Self::Unassigned => remote::get_reply::EntryStatus::Unassigned(()),
528            Self::Assign(id, heartbeat_interval) => {
529                remote::get_reply::EntryStatus::Assign(remote::AssignReply {
530                    id: id.0,
531                    heartbeat_interval_ms: heartbeat_interval.as_millis() as u64,
532                })
533            }
534            Self::Loading => remote::get_reply::EntryStatus::Loading(()),
535            Self::ReadyRemote(val) => remote::get_reply::EntryStatus::Ready(val),
536            Self::ReadyLocal(_) => panic!("cannot convert local statuses to remote statuses"),
537        }
538    }
539}
540
541#[derive(Clone, Debug)]
542struct CacheImpl {
543    root: Arc<PathBuf>,
544    heartbeat_interval: Duration,
545    heartbeat_timeout: Duration,
546    inner: Arc<Mutex<CacheInner>>,
547}
548
549impl CacheImpl {
550    fn new(
551        root: Arc<PathBuf>,
552        heartbeat_interval: Duration,
553        heartbeat_timeout: Duration,
554        inner: Arc<Mutex<CacheInner>>,
555    ) -> Self {
556        Self {
557            root,
558            heartbeat_interval,
559            heartbeat_timeout,
560            inner,
561        }
562    }
563
564    /// Responds to a `Get` RPC request for the given entry key, assigning unassigned tasks if
565    /// `assign` is `true`.
566    ///
567    /// If `local` is `true`, getting an existing key in the cache requires assigning a new entry
568    /// handle, which must be dropped by the client to allow the key to be evicted.
569    async fn get_impl(
570        &self,
571        entry_key: Arc<EntryKey>,
572        assign: bool,
573        local: bool,
574    ) -> std::result::Result<GetReplyStatus, tonic::Status> {
575        tracing::debug!("received get request");
576        let mut inner = self.inner.lock().await;
577
578        let CacheInner {
579            next_assignment_id,
580            next_handle_id,
581            entry_status,
582            loading,
583            handles,
584            conn,
585            ..
586        } = &mut *inner;
587
588        let path = get_file(self.root.as_ref(), &entry_key);
589        Ok(match entry_status.entry(entry_key.clone()) {
590            Entry::Occupied(mut o) => {
591                let v = o.get_mut();
592                match v {
593                    EntryStatus::Loading(id) => {
594                        let data = loading
595                            .get(id)
596                            .ok_or(tonic::Status::internal("unable to retrieve status of key"))?;
597
598                        // If the entry is loading but hasn't received a heartbeat recently,
599                        // reassign it to be loaded by the new requester.
600                        //
601                        // Otherwise, notify the requester that the entry is currently loading.
602                        if Instant::now().duration_since(data.last_heartbeat)
603                            > self.heartbeat_timeout
604                        {
605                            tracing::debug!(
606                                "assigned worker has not sent a heartbeat recently, entry is no longer loading"
607                            );
608                            if assign {
609                                loading.remove(id);
610                                next_assignment_id.increment();
611                                *id = *next_assignment_id;
612                                tracing::debug!("assigning task with id {:?}", id);
613                                loading.insert(
614                                    *id,
615                                    LoadingData {
616                                        last_heartbeat: Instant::now(),
617                                        key: entry_key,
618                                    },
619                                );
620                                GetReplyStatus::Assign(*id, self.heartbeat_interval)
621                            } else {
622                                conn.delete_status(entry_key.clone()).await.map_err(|_| {
623                                    tonic::Status::internal("unable to persist changes")
624                                })?;
625                                o.remove_entry();
626                                GetReplyStatus::Unassigned
627                            }
628                        } else {
629                            tracing::debug!("entry is currently loading");
630                            GetReplyStatus::Loading
631                        }
632                    }
633                    EntryStatus::Ready(in_use) => {
634                        tracing::debug!("entry is ready, sending relevant data to client");
635                        if local {
636                            // If the requested entry is ready, assign a new handle to the entry.
637                            *in_use += 1;
638                            next_handle_id.increment();
639                            handles.insert(*next_handle_id, entry_key);
640                            GetReplyStatus::ReadyLocal(*next_handle_id)
641                        } else {
642                            // If the requested entry is ready, read it from disk and send it back
643                            // to the requester.
644                            let mut file = File::open(path).await?;
645                            let mut buf = Vec::new();
646                            file.read_to_end(&mut buf).await?;
647                            GetReplyStatus::ReadyRemote(buf)
648                        }
649                    }
650                    // If the entry is currently being evicted, do not assign it.
651                    //
652                    // The client is free to generate on their own, but the cache will not accept a
653                    // new value for the entry.
654                    EntryStatus::Evicting => {
655                        tracing::debug!("entry is currently being evicted");
656                        GetReplyStatus::Unassigned
657                    }
658                }
659            }
660            Entry::Vacant(v) => {
661                // If the entry doesn't exist, assign it to be loaded if needed.
662                tracing::debug!("entry does not exist, creating a new entry");
663                if assign {
664                    next_assignment_id.increment();
665                    conn.insert_status(entry_key.clone(), DbEntryStatus::Loading)
666                        .await
667                        .map_err(|_| tonic::Status::internal("unable to persist changes"))?;
668                    v.insert(EntryStatus::Loading(*next_assignment_id));
669                    tracing::debug!("assigning task with id {:?}", next_assignment_id);
670                    loading.insert(
671                        *next_assignment_id,
672                        LoadingData {
673                            last_heartbeat: Instant::now(),
674                            key: entry_key,
675                        },
676                    );
677                    GetReplyStatus::Assign(*next_assignment_id, self.heartbeat_interval)
678                } else {
679                    GetReplyStatus::Unassigned
680                }
681            }
682        })
683    }
684
685    async fn heartbeat_impl(&self, id: AssignmentId) -> std::result::Result<(), tonic::Status> {
686        tracing::debug!("received heartbeat request for id {:?}", id);
687        let mut inner = self.inner.lock().await;
688        match inner.loading.entry(id) {
689            Entry::Vacant(_) => {
690                tracing::error!(
691                    "received heartbeat request for invalid assignment id {:?}",
692                    id
693                );
694                return Err(tonic::Status::invalid_argument("invalid assignment id"));
695            }
696            Entry::Occupied(o) => {
697                o.into_mut().last_heartbeat = Instant::now();
698            }
699        }
700        Ok(())
701    }
702
703    async fn set_impl(
704        &self,
705        id: AssignmentId,
706        value: Option<Vec<u8>>,
707    ) -> std::result::Result<(), tonic::Status> {
708        tracing::debug!("received set request for id {:?}", id);
709        let mut inner = self.inner.lock().await;
710        let data = inner.loading.get(&id).ok_or_else(|| {
711            tracing::error!("received set request for invalid id {:?}", id);
712            tonic::Status::invalid_argument("invalid assignment id")
713        })?;
714
715        let key = data.key.clone();
716
717        // If there is a value to write to disk, write it to the appropriate file.
718        if let Some(value) = value {
719            let path = get_file(self.root.as_ref(), &key);
720
721            if let Some(parent) = path.parent() {
722                fs::create_dir_all(parent).await?;
723            }
724
725            let mut f = OpenOptions::new()
726                .read(true)
727                .write(true)
728                .create(true)
729                .truncate(true)
730                .open(&path)
731                .await?;
732            f.write_all(&value).await?;
733        }
734
735        // Mark the entry as ready in the database and in memory.
736        inner
737            .conn
738            .update_status(key.clone(), DbEntryStatus::Ready)
739            .await
740            .map_err(|_| tonic::Status::internal("unable to persist changes"))?;
741        let status = inner
742            .entry_status
743            .get_mut(&key)
744            .ok_or(tonic::Status::internal("unable to retrieve status of key"))?;
745        *status = EntryStatus::Ready(0);
746
747        Ok(())
748    }
749}
750
751#[tonic::async_trait]
752impl LocalCache for CacheImpl {
753    async fn get(
754        &self,
755        request: tonic::Request<local::GetRequest>,
756    ) -> std::result::Result<tonic::Response<local::GetReply>, tonic::Status> {
757        let request = request.into_inner();
758
759        if !Namespace::validate(&request.namespace) {
760            return Err(tonic::Status::invalid_argument("invalid namespace"));
761        }
762
763        let entry_key = Arc::new(EntryKey {
764            namespace: Namespace::new(request.namespace),
765            key: request.key,
766        });
767
768        let path = get_file(self.root.as_ref(), &entry_key)
769            .absolutize()
770            .unwrap()
771            .to_str()
772            .unwrap()
773            .to_string();
774
775        let entry_status = self
776            .get_impl(entry_key, request.assign, true)
777            .await?
778            .into_local(path);
779
780        Ok(Response::new(local::GetReply {
781            entry_status: Some(entry_status),
782        }))
783    }
784
785    async fn heartbeat(
786        &self,
787        request: tonic::Request<local::HeartbeatRequest>,
788    ) -> std::result::Result<tonic::Response<()>, tonic::Status> {
789        self.heartbeat_impl(AssignmentId(request.into_inner().id))
790            .await?;
791        Ok(Response::new(()))
792    }
793
794    async fn done(
795        &self,
796        request: tonic::Request<local::DoneRequest>,
797    ) -> std::result::Result<tonic::Response<()>, tonic::Status> {
798        let request = request.into_inner();
799        self.set_impl(AssignmentId(request.id), None).await?;
800        Ok(Response::new(()))
801    }
802
803    // TODO: Untested since eviction is not yet implemented.
804    async fn drop(
805        &self,
806        request: tonic::Request<local::DropRequest>,
807    ) -> std::result::Result<tonic::Response<()>, tonic::Status> {
808        let request = request.into_inner();
809        let mut inner = self.inner.lock().await;
810
811        let CacheInner {
812            handles,
813            entry_status,
814            ..
815        } = &mut *inner;
816
817        let handle_id = HandleId(request.id);
818        let entry_key = handles
819            .get(&handle_id)
820            .ok_or(tonic::Status::invalid_argument("invalid handle id"))?;
821        let entry_status = entry_status
822            .get_mut(entry_key)
823            .ok_or(tonic::Status::internal("unable to retrieve status of key"))?;
824        if let EntryStatus::Ready(in_use) = entry_status {
825            *in_use -= 1;
826            handles.remove(&handle_id);
827        } else {
828            return Err(tonic::Status::internal("inconsistent internal state"));
829        }
830        Ok(Response::new(()))
831    }
832}
833
834#[tonic::async_trait]
835impl RemoteCache for CacheImpl {
836    async fn get(
837        &self,
838        request: tonic::Request<remote::GetRequest>,
839    ) -> std::result::Result<tonic::Response<remote::GetReply>, tonic::Status> {
840        let request = request.into_inner();
841
842        if !Namespace::validate(&request.namespace) {
843            return Err(tonic::Status::invalid_argument("invalid namespace"));
844        }
845
846        let entry_key = Arc::new(EntryKey {
847            namespace: Namespace::new(request.namespace),
848            key: request.key,
849        });
850
851        let entry_status = self
852            .get_impl(entry_key, request.assign, false)
853            .await?
854            .into_remote();
855
856        Ok(Response::new(remote::GetReply {
857            entry_status: Some(entry_status),
858        }))
859    }
860
861    async fn heartbeat(
862        &self,
863        request: tonic::Request<remote::HeartbeatRequest>,
864    ) -> std::result::Result<tonic::Response<()>, tonic::Status> {
865        self.heartbeat_impl(AssignmentId(request.into_inner().id))
866            .await?;
867        Ok(Response::new(()))
868    }
869
870    async fn set(
871        &self,
872        request: tonic::Request<remote::SetRequest>,
873    ) -> std::result::Result<tonic::Response<()>, tonic::Status> {
874        let request = request.into_inner();
875        self.set_impl(AssignmentId(request.id), Some(request.value))
876            .await?;
877        Ok(Response::new(()))
878    }
879}
880
881fn get_file(root: impl AsRef<Path>, key: impl AsRef<EntryKey>) -> PathBuf {
882    let root = root.as_ref();
883    let key = key.as_ref();
884    // TODO: Require namespace to be filesystem compatible so that cache folder names don't need to
885    // be hashed.
886    root.join(key.namespace.as_ref())
887        .join(hex::encode(crate::hash(&key.key)))
888}