// substrate/lut.rs

//! Lookup tables (LUTs) keyed on sorted axes, with optional linear interpolation for `f64` tables.

use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use splines::{Key, Spline};

7/// A 1D lookup table.
8// TODO verify that length of keys and values match, and that k1 is sorted
9#[allow(missing_docs)]
10#[derive(Debug, Default, Clone, Eq, PartialEq, Builder, Serialize, Deserialize)]
11#[builder(pattern = "owned")]
12pub struct Lut1<K1, V> {
13    k1: Vec<K1>,
14    values: Vec<V>,
15}
16
17/// A 2D lookup table.
18#[allow(missing_docs)]
19#[derive(Debug, Default, Clone, Eq, PartialEq, Builder, Serialize, Deserialize)]
20#[builder(pattern = "owned")]
21pub struct Lut2<K1, K2, V> {
22    k1: Vec<K1>,
23    k2: Vec<K2>,
24    // row major order
25    values: Vec<Vec<V>>,
26}
27
/// Extrapolation options, controlling how lookups treat keys outside a table's key range.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum Extrapolation {
    /// Do not extrapolate; keys outside the table's range yield no value.
    #[default]
    None,
    /// Raises a key that lies below the table's range up to the first key of that axis.
    RoundUp,
}

38impl<K1, K2, V> Lut2<K1, K2, V> {
39    /// Create a new [`Lut2Builder`].
40    pub fn builder() -> Lut2Builder<K1, K2, V> {
41        Default::default()
42    }
43}
44
45impl<K1, K2, V> Lut2<K1, K2, V>
46where
47    K1: Ord,
48    K2: Ord,
49{
50    /// Lookup a value for the given keys.
51    pub fn get(&self, k1: &K1, k2: &K2) -> Option<&V> {
52        let i1 = self.k1.partition_point(|k| k < k1);
53        let i2 = self.k2.partition_point(|k| k < k2);
54        if k1 < self.k1.first()? || k2 < self.k2.first()? {
55            return None;
56        }
57        self.values.get(i1)?.get(i2)
58    }
59}
60
61impl FloatLut2 {
62    /// Lookup a value for the given keys, interpolating as necessary.
63    pub fn getf(&self, k1: f64, k2: f64) -> Option<f64> {
64        let interp1 = (0..self.k1.len())
65            .map(|i| {
66                Spline::from_vec(
67                    self.k2
68                        .iter()
69                        .copied()
70                        .zip(self.values.get(i)?.iter().copied())
71                        .map(|(k, v)| Key::new(k, v, splines::Interpolation::Linear))
72                        .collect(),
73                )
74                .sample(k2)
75            })
76            .collect::<Option<Vec<f64>>>()?;
77
78        Spline::from_vec(
79            self.k1
80                .iter()
81                .copied()
82                .zip(interp1)
83                .map(|(k, v)| Key::new(k, v, splines::Interpolation::Linear))
84                .collect(),
85        )
86        .sample(k1)
87    }
88
89    /// Lookup a value for the given keys, interpolating as necessary.
90    ///
91    /// Can extrapolate beyond the bounds of the key ranges.
92    pub fn getf_extrapolate(
93        &self,
94        mut k1: f64,
95        mut k2: f64,
96        extrapolate: Extrapolation,
97    ) -> Option<f64> {
98        if extrapolate == Extrapolation::RoundUp {
99            (k1, k2) = (k1.max(*self.k1.first()?), k2.max(*self.k2.first()?));
100        }
101
102        self.getf(k1, k2)
103    }
104}
105
106/// A floating point 1D LUT.
107pub type FloatLut1 = Lut1<f64, f64>;
108
109/// A floating point 2D LUT.
110pub type FloatLut2 = Lut2<f64, f64, f64>;
111
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use super::*;

    #[test]
    fn test_lut_u64() {
        let lut = Lut2::<u64, u64, u64>::builder()
            .k1(vec![5, 6, 7])
            .k2(vec![1, 2, 3])
            .values(vec![vec![1, 5, 9], vec![2, 4, 8], vec![3, 6, 7]])
            .build()
            .unwrap();

        // (k1, k2, expected lookup result)
        let cases: [(u64, u64, Option<u64>); 6] = [
            (5, 2, Some(5)),
            (4, 2, None),
            (7, 3, Some(7)),
            (8, 3, None),
            (6, 4, None),
            (6, 0, None),
        ];
        for (k1, k2, expected) in cases {
            assert_eq!(lut.get(&k1, &k2).copied(), expected, "get({k1}, {k2})");
        }
    }

    #[test]
    fn test_lut_f64() {
        let lut = FloatLut2::builder()
            .k1(vec![5., 6., 7.])
            .k2(vec![1., 2., 3.])
            .values(vec![vec![1., 5., 9.], vec![2., 4., 8.], vec![3., 6., 7.]])
            .build()
            .unwrap();

        // (k1, k2, expected interpolated value)
        let cases: [(f64, f64, f64); 3] = [(5., 2., 5.), (5., 2.5, 7.), (6.5, 1.5, 3.75)];
        for (k1, k2, expected) in cases {
            assert_relative_eq!(lut.getf(k1, k2).unwrap(), expected, max_relative = 1e-8);
        }

        assert_eq!(lut.getf(4.5, 2.5), None);
    }
}