1use fugue::{addr, ChoiceValue, Trace};
7use rand::Rng;
8use serde::{Deserialize, Serialize};
9
10use crate::error::GenomeError;
11use crate::genome::bounds::MultiBounds;
12use crate::genome::traits::{EvolutionaryGenome, RealValuedGenome};
13
14#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
19pub struct RealVector {
20 genes: Vec<f64>,
22}
23
24impl RealVector {
25 pub fn new(genes: Vec<f64>) -> Self {
27 Self { genes }
28 }
29
30 pub fn zeros(dimension: usize) -> Self {
32 Self {
33 genes: vec![0.0; dimension],
34 }
35 }
36
37 pub fn filled(dimension: usize, value: f64) -> Self {
39 Self {
40 genes: vec![value; dimension],
41 }
42 }
43
44 pub fn collect_from<I: IntoIterator<Item = f64>>(iter: I) -> Self {
46 Self {
47 genes: iter.into_iter().collect(),
48 }
49 }
50
51 pub fn into_inner(self) -> Vec<f64> {
53 self.genes
54 }
55
56 pub fn as_vec(&self) -> &Vec<f64> {
58 &self.genes
59 }
60
61 pub fn norm(&self) -> f64 {
63 self.genes.iter().map(|x| x * x).sum::<f64>().sqrt()
64 }
65
66 pub fn norm_squared(&self) -> f64 {
68 self.genes.iter().map(|x| x * x).sum::<f64>()
69 }
70
71 pub fn add(&self, other: &Self) -> Result<Self, GenomeError> {
73 if self.genes.len() != other.genes.len() {
74 return Err(GenomeError::DimensionMismatch {
75 expected: self.genes.len(),
76 actual: other.genes.len(),
77 });
78 }
79 Ok(Self {
80 genes: self
81 .genes
82 .iter()
83 .zip(other.genes.iter())
84 .map(|(a, b)| a + b)
85 .collect(),
86 })
87 }
88
89 pub fn sub(&self, other: &Self) -> Result<Self, GenomeError> {
91 if self.genes.len() != other.genes.len() {
92 return Err(GenomeError::DimensionMismatch {
93 expected: self.genes.len(),
94 actual: other.genes.len(),
95 });
96 }
97 Ok(Self {
98 genes: self
99 .genes
100 .iter()
101 .zip(other.genes.iter())
102 .map(|(a, b)| a - b)
103 .collect(),
104 })
105 }
106
107 pub fn scale(&self, scalar: f64) -> Self {
109 Self {
110 genes: self.genes.iter().map(|x| x * scalar).collect(),
111 }
112 }
113}
114
115impl EvolutionaryGenome for RealVector {
116 type Allele = f64;
117 type Phenotype = Vec<f64>;
118
119 fn to_trace(&self) -> Trace {
123 let mut trace = Trace::default();
124 for (i, &gene) in self.genes.iter().enumerate() {
125 trace.insert_choice(addr!("gene", i), ChoiceValue::F64(gene), 0.0);
126 }
127 trace
128 }
129
130 fn from_trace(trace: &Trace) -> Result<Self, GenomeError> {
134 let mut genes = Vec::new();
135 let mut i = 0;
136 while let Some(val) = trace.get_f64(&addr!("gene", i)) {
137 genes.push(val);
138 i += 1;
139 }
140 if genes.is_empty() {
141 return Err(GenomeError::InvalidStructure(
142 "No genes found in trace".to_string(),
143 ));
144 }
145 Ok(Self { genes })
146 }
147
148 fn decode(&self) -> Self::Phenotype {
149 self.genes.clone()
150 }
151
152 fn dimension(&self) -> usize {
153 self.genes.len()
154 }
155
156 fn generate<R: Rng>(rng: &mut R, bounds: &MultiBounds) -> Self {
157 let genes = bounds
158 .bounds
159 .iter()
160 .map(|b| rng.gen_range(b.min..=b.max))
161 .collect();
162 Self { genes }
163 }
164
165 fn as_slice(&self) -> Option<&[f64]> {
166 Some(&self.genes)
167 }
168
169 fn as_mut_slice(&mut self) -> Option<&mut [f64]> {
170 Some(&mut self.genes)
171 }
172
173 fn distance(&self, other: &Self) -> f64 {
174 self.genes
175 .iter()
176 .zip(other.genes.iter())
177 .map(|(a, b)| (a - b).powi(2))
178 .sum::<f64>()
179 .sqrt()
180 }
181}
182
183impl RealValuedGenome for RealVector {
184 fn genes(&self) -> &[f64] {
185 &self.genes
186 }
187
188 fn genes_mut(&mut self) -> &mut [f64] {
189 &mut self.genes
190 }
191
192 fn from_genes(genes: Vec<f64>) -> Result<Self, GenomeError> {
193 Ok(Self { genes })
194 }
195}
196
197impl std::ops::Index<usize> for RealVector {
198 type Output = f64;
199
200 fn index(&self, index: usize) -> &Self::Output {
201 &self.genes[index]
202 }
203}
204
205impl std::ops::IndexMut<usize> for RealVector {
206 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
207 &mut self.genes[index]
208 }
209}
210
211impl From<Vec<f64>> for RealVector {
212 fn from(genes: Vec<f64>) -> Self {
213 Self { genes }
214 }
215}
216
217impl From<RealVector> for Vec<f64> {
218 fn from(genome: RealVector) -> Self {
219 genome.genes
220 }
221}
222
223impl<const N: usize> From<[f64; N]> for RealVector {
224 fn from(arr: [f64; N]) -> Self {
225 Self {
226 genes: arr.to_vec(),
227 }
228 }
229}
230
231impl IntoIterator for RealVector {
232 type Item = f64;
233 type IntoIter = std::vec::IntoIter<f64>;
234
235 fn into_iter(self) -> Self::IntoIter {
236 self.genes.into_iter()
237 }
238}
239
240impl<'a> IntoIterator for &'a RealVector {
241 type Item = &'a f64;
242 type IntoIter = std::slice::Iter<'a, f64>;
243
244 fn into_iter(self) -> Self::IntoIter {
245 self.genes.iter()
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use fugue::addr;

    #[test]
    fn test_real_vector_new() {
        let v = RealVector::new(vec![1.0, 2.0, 3.0]);
        assert_eq!(v.dimension(), 3);
        assert_eq!(v.genes(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_real_vector_zeros() {
        let v = RealVector::zeros(5);
        assert_eq!(v.dimension(), 5);
        assert!(v.genes().iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_real_vector_filled() {
        let v = RealVector::filled(3, 42.0);
        assert_eq!(v.genes(), &[42.0, 42.0, 42.0]);
    }

    #[test]
    fn test_real_vector_from_array() {
        let v: RealVector = [1.0, 2.0, 3.0].into();
        assert_eq!(v.genes(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_real_vector_decode() {
        let v = RealVector::new(vec![1.0, 2.0, 3.0]);
        let phenotype = v.decode();
        assert_eq!(phenotype, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_real_vector_generate() {
        let mut rng = rand::thread_rng();
        let bounds = MultiBounds::symmetric(5.0, 10);
        let v = RealVector::generate(&mut rng, &bounds);

        assert_eq!(v.dimension(), 10);
        for gene in v.genes() {
            assert!(*gene >= -5.0 && *gene <= 5.0);
        }
    }

    #[test]
    fn test_real_vector_norm() {
        let v = RealVector::new(vec![3.0, 4.0]);
        assert_relative_eq!(v.norm(), 5.0);
        assert_relative_eq!(v.norm_squared(), 25.0);
    }

    #[test]
    fn test_real_vector_distance() {
        let v1 = RealVector::new(vec![0.0, 0.0]);
        let v2 = RealVector::new(vec![3.0, 4.0]);
        assert_relative_eq!(v1.distance(&v2), 5.0);
    }

    #[test]
    fn test_real_vector_add() {
        let v1 = RealVector::new(vec![1.0, 2.0, 3.0]);
        let v2 = RealVector::new(vec![4.0, 5.0, 6.0]);
        let sum = v1.add(&v2).unwrap();
        assert_eq!(sum.genes(), &[5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_real_vector_add_dimension_mismatch() {
        let v1 = RealVector::new(vec![1.0, 2.0]);
        let v2 = RealVector::new(vec![1.0, 2.0, 3.0]);
        let result = v1.add(&v2);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            GenomeError::DimensionMismatch {
                expected: 2,
                actual: 3
            }
        ));
    }

    #[test]
    fn test_real_vector_sub() {
        let v1 = RealVector::new(vec![5.0, 7.0, 9.0]);
        let v2 = RealVector::new(vec![1.0, 2.0, 3.0]);
        let diff = v1.sub(&v2).unwrap();
        assert_eq!(diff.genes(), &[4.0, 5.0, 6.0]);
    }

    /// `sub` must report mismatches the same way `add` does:
    /// `expected` is the receiver's length and `actual` is the argument's.
    #[test]
    fn test_real_vector_sub_dimension_mismatch() {
        let v1 = RealVector::new(vec![1.0, 2.0, 3.0]);
        let v2 = RealVector::new(vec![1.0]);
        assert!(matches!(
            v1.sub(&v2),
            Err(GenomeError::DimensionMismatch {
                expected: 3,
                actual: 1
            })
        ));
    }

    #[test]
    fn test_real_vector_scale() {
        let v = RealVector::new(vec![1.0, 2.0, 3.0]);
        let scaled = v.scale(2.0);
        assert_eq!(scaled.genes(), &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_real_vector_indexing() {
        let mut v = RealVector::new(vec![1.0, 2.0, 3.0]);
        assert_eq!(v[0], 1.0);
        assert_eq!(v[1], 2.0);
        assert_eq!(v[2], 3.0);

        v[1] = 42.0;
        assert_eq!(v[1], 42.0);
    }

    #[test]
    #[should_panic]
    fn test_real_vector_index_out_of_bounds() {
        let v = RealVector::zeros(2);
        assert_eq!(v[2], 0.0);
    }

    #[test]
    fn test_real_vector_apply_bounds() {
        let mut v = RealVector::new(vec![-10.0, 0.0, 10.0]);
        let bounds = MultiBounds::symmetric(5.0, 3);
        v.apply_bounds(&bounds);
        assert_eq!(v.genes(), &[-5.0, 0.0, 5.0]);
    }

    #[test]
    fn test_real_vector_iteration() {
        let v = RealVector::new(vec![1.0, 2.0, 3.0]);
        let sum: f64 = v.into_iter().sum();
        assert_relative_eq!(sum, 6.0);
    }

    #[test]
    fn test_real_vector_borrowed_iteration() {
        let v = RealVector::new(vec![1.0, 2.0, 3.0]);
        let mut total = 0.0;
        for gene in &v {
            total += gene;
        }
        assert_relative_eq!(total, 6.0);
        // Iterating by reference must leave the genome usable afterwards.
        assert_eq!(v.dimension(), 3);
    }

    #[test]
    fn test_real_vector_into_inner() {
        let v = RealVector::new(vec![1.0, 2.0, 3.0]);
        let inner: Vec<f64> = v.into_inner();
        assert_eq!(inner, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_real_vector_collect_from() {
        let v = RealVector::collect_from((1..=3).map(|i| i as f64 * 0.5));
        assert_eq!(v.genes(), &[0.5, 1.0, 1.5]);
    }

    #[test]
    fn test_real_vector_vec_conversions() {
        let v = RealVector::from(vec![1.0, 2.0]);
        assert_eq!(v.as_vec(), &vec![1.0, 2.0]);
        let back: Vec<f64> = v.into();
        assert_eq!(back, vec![1.0, 2.0]);
    }

    #[test]
    fn test_real_vector_serialization() {
        let v = RealVector::new(vec![1.0, 2.0, 3.0]);
        let serialized = serde_json::to_string(&v).unwrap();
        let deserialized: RealVector = serde_json::from_str(&serialized).unwrap();
        assert_eq!(v, deserialized);
    }

    #[test]
    fn test_real_vector_to_trace() {
        let v = RealVector::new(vec![1.5, 2.5, 3.5]);
        let trace = v.to_trace();

        assert_eq!(trace.get_f64(&addr!("gene", 0)), Some(1.5));
        assert_eq!(trace.get_f64(&addr!("gene", 1)), Some(2.5));
        assert_eq!(trace.get_f64(&addr!("gene", 2)), Some(3.5));
        assert_eq!(trace.get_f64(&addr!("gene", 3)), None);
    }

    #[test]
    fn test_real_vector_from_trace() {
        let mut trace = Trace::default();
        trace.insert_choice(addr!("gene", 0), ChoiceValue::F64(1.0), 0.0);
        trace.insert_choice(addr!("gene", 1), ChoiceValue::F64(2.0), 0.0);
        trace.insert_choice(addr!("gene", 2), ChoiceValue::F64(3.0), 0.0);

        let v = RealVector::from_trace(&trace).unwrap();
        assert_eq!(v.genes(), &[1.0, 2.0, 3.0]);
    }

    /// Pins the current contract: decoding stops at the first missing index,
    /// so genes stored after a gap are not recovered.
    #[test]
    fn test_real_vector_from_trace_stops_at_first_gap() {
        let mut trace = Trace::default();
        trace.insert_choice(addr!("gene", 0), ChoiceValue::F64(1.0), 0.0);
        trace.insert_choice(addr!("gene", 2), ChoiceValue::F64(3.0), 0.0);

        let v = RealVector::from_trace(&trace).unwrap();
        assert_eq!(v.genes(), &[1.0]);
    }

    #[test]
    fn test_real_vector_trace_roundtrip() {
        let original = RealVector::new(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let trace = original.to_trace();
        let recovered = RealVector::from_trace(&trace).unwrap();
        assert_eq!(original, recovered);
    }

    #[test]
    fn test_real_vector_from_trace_empty() {
        let trace = Trace::default();
        let result = RealVector::from_trace(&trace);
        assert!(result.is_err());
    }
}