1use derive_builder::Builder;
4use serde::{Deserialize, Serialize};
5use splines::{Key, Spline};
6
/// One-dimensional lookup table: a key axis `k1` and the table entries `values`.
///
/// NOTE(review): no lookup method for this type is visible in this file; presumably
/// `values[i]` belongs to `k1[i]` (mirroring `Lut2`) — confirm before relying on it.
// NOTE(review): `missing_docs` is allowed here presumably because the items generated by
// `derive(Builder)` (e.g. `Lut1Builder`) carry no docs — confirm and narrow if possible.
#[allow(missing_docs)]
#[derive(Debug, Default, Clone, Eq, PartialEq, Builder, Serialize, Deserialize)]
// Setters take and return the builder by value (`pattern = "owned"`).
#[builder(pattern = "owned")]
pub struct Lut1<K1, V> {
    /// Keys along the single axis.
    k1: Vec<K1>,
    /// Table entries.
    values: Vec<V>,
}
16
/// Two-dimensional lookup table: entry `values[i][j]` belongs to the key pair `(k1[i], k2[j])`.
///
/// Exact lookups (`get`) binary-search the key axes with `partition_point`, so `k1` and `k2`
/// are expected to be sorted in ascending order.
// NOTE(review): `missing_docs` is allowed here presumably because the items generated by
// `derive(Builder)` (e.g. `Lut2Builder`) carry no docs — confirm and narrow if possible.
#[allow(missing_docs)]
#[derive(Debug, Default, Clone, Eq, PartialEq, Builder, Serialize, Deserialize)]
// Setters take and return the builder by value (`pattern = "owned"`).
#[builder(pattern = "owned")]
pub struct Lut2<K1, K2, V> {
    /// Keys along the first axis; index into the outer `values` vector (rows).
    k1: Vec<K1>,
    /// Keys along the second axis; index into each inner row of `values` (columns).
    k2: Vec<K2>,
    /// Table entries, expected to hold one row per `k1` key and one entry per `k2` key.
    values: Vec<Vec<V>>,
}
27
/// How `FloatLut2::getf_extrapolate` treats query coordinates before interpolating.
///
/// Despite the name, no values are extrapolated: `RoundUp` only moves a coordinate that lies
/// below its axis' first key onto that key.
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq)]
pub enum Extrapolation {
    /// Default: the coordinates are passed to `getf` unchanged, so queries below the first
    /// key of an axis yield `None`.
    #[default]
    None,
    /// Raise each coordinate to at least the first key of its axis before interpolating.
    /// Coordinates above the last key are not adjusted.
    RoundUp,
}
37
38impl<K1, K2, V> Lut2<K1, K2, V> {
39 pub fn builder() -> Lut2Builder<K1, K2, V> {
41 Default::default()
42 }
43}
44
45impl<K1, K2, V> Lut2<K1, K2, V>
46where
47 K1: Ord,
48 K2: Ord,
49{
50 pub fn get(&self, k1: &K1, k2: &K2) -> Option<&V> {
52 let i1 = self.k1.partition_point(|k| k < k1);
53 let i2 = self.k2.partition_point(|k| k < k2);
54 if k1 < self.k1.first()? || k2 < self.k2.first()? {
55 return None;
56 }
57 self.values.get(i1)?.get(i2)
58 }
59}
60
61impl FloatLut2 {
62 pub fn getf(&self, k1: f64, k2: f64) -> Option<f64> {
64 let interp1 = (0..self.k1.len())
65 .map(|i| {
66 Spline::from_vec(
67 self.k2
68 .iter()
69 .copied()
70 .zip(self.values.get(i)?.iter().copied())
71 .map(|(k, v)| Key::new(k, v, splines::Interpolation::Linear))
72 .collect(),
73 )
74 .sample(k2)
75 })
76 .collect::<Option<Vec<f64>>>()?;
77
78 Spline::from_vec(
79 self.k1
80 .iter()
81 .copied()
82 .zip(interp1)
83 .map(|(k, v)| Key::new(k, v, splines::Interpolation::Linear))
84 .collect(),
85 )
86 .sample(k1)
87 }
88
89 pub fn getf_extrapolate(
93 &self,
94 mut k1: f64,
95 mut k2: f64,
96 extrapolate: Extrapolation,
97 ) -> Option<f64> {
98 if extrapolate == Extrapolation::RoundUp {
99 (k1, k2) = (k1.max(*self.k1.first()?), k2.max(*self.k2.first()?));
100 }
101
102 self.getf(k1, k2)
103 }
104}
105
/// One-dimensional lookup table with `f64` keys and values.
pub type FloatLut1 = Lut1<f64, f64>;

/// Two-dimensional lookup table with `f64` keys and values.
///
/// In addition to the exact `get`, this alias provides the interpolating lookups `getf` and
/// `getf_extrapolate`.
pub type FloatLut2 = Lut2<f64, f64, f64>;
111
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use super::*;

    #[test]
    fn test_lut_u64() {
        // Rows keyed by 5, 6, 7; columns keyed by 1, 2, 3.
        let table = Lut2::<u64, u64, u64>::builder()
            .k1(vec![5, 6, 7])
            .k2(vec![1, 2, 3])
            .values(vec![vec![1, 5, 9], vec![2, 4, 8], vec![3, 6, 7]])
            .build()
            .unwrap();

        // (k1 query, k2 query, expected entry); `None` marks out-of-range queries.
        let cases: [(u64, u64, Option<u64>); 6] = [
            (5, 2, Some(5)),
            (4, 2, None),
            (7, 3, Some(7)),
            (8, 3, None),
            (6, 4, None),
            (6, 0, None),
        ];
        for (row_key, col_key, expected) in cases {
            assert_eq!(
                table.get(&row_key, &col_key).copied(),
                expected,
                "get({row_key}, {col_key})"
            );
        }
    }

    #[test]
    fn test_lut_f64() {
        let table = FloatLut2::builder()
            .k1(vec![5., 6., 7.])
            .k2(vec![1., 2., 3.])
            .values(vec![vec![1., 5., 9.], vec![2., 4., 8.], vec![3., 6., 7.]])
            .build()
            .unwrap();

        // On-grid and interior points with their hand-computed bilinear results.
        for (k1, k2, expected) in [(5., 2., 5.), (5., 2.5, 7.), (6.5, 1.5, 3.75)] {
            assert_relative_eq!(table.getf(k1, k2).unwrap(), expected, max_relative = 1e-8);
        }

        // Below the first `k1` key.
        assert_eq!(table.getf(4.5, 2.5), None);
    }
}