1use fugue::{addr, ChoiceValue, Trace};
7use rand::Rng;
8use serde::{Deserialize, Serialize};
9use std::fmt;
10
11use crate::error::GenomeError;
12use crate::genome::bounds::MultiBounds;
13use crate::genome::traits::EvolutionaryGenome;
14
/// A node of a genetic-programming expression tree.
///
/// Leaves hold a terminal `T`; internal nodes hold a function `F` together with its
/// ordered operand subtrees. The type itself does not tie the number of children to
/// `F::arity()`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
// Empty bound: suppress serde's inferred `T`/`F` bounds on the derived impls; the
// `Terminal`/`Function` supertraits already require `Serialize` + `Deserialize`.
#[serde(bound = "")]
pub enum TreeNode<T: Terminal, F: Function> {
    /// A leaf holding a terminal value.
    Terminal(T),
    /// A function application: the function and its children in operand order.
    Function(F, Vec<TreeNode<T, F>>),
}
24
25impl<T: Terminal, F: Function> TreeNode<T, F> {
26 pub fn terminal(value: T) -> Self {
28 Self::Terminal(value)
29 }
30
31 pub fn function(func: F, children: Vec<Self>) -> Self {
33 Self::Function(func, children)
34 }
35
36 pub fn is_terminal(&self) -> bool {
38 matches!(self, Self::Terminal(_))
39 }
40
41 pub fn is_function(&self) -> bool {
43 matches!(self, Self::Function(_, _))
44 }
45
46 pub fn depth(&self) -> usize {
48 match self {
49 Self::Terminal(_) => 1,
50 Self::Function(_, children) => {
51 1 + children.iter().map(|c| c.depth()).max().unwrap_or(0)
52 }
53 }
54 }
55
56 pub fn size(&self) -> usize {
58 match self {
59 Self::Terminal(_) => 1,
60 Self::Function(_, children) => 1 + children.iter().map(|c| c.size()).sum::<usize>(),
61 }
62 }
63
64 pub fn positions(&self) -> Vec<Vec<usize>> {
66 let mut positions = Vec::new();
67 self.collect_positions(&[], &mut positions);
68 positions
69 }
70
71 fn collect_positions(&self, path: &[usize], positions: &mut Vec<Vec<usize>>) {
72 positions.push(path.to_vec());
73 if let Self::Function(_, children) = self {
74 for (i, child) in children.iter().enumerate() {
75 let mut child_path = path.to_vec();
76 child_path.push(i);
77 child.collect_positions(&child_path, positions);
78 }
79 }
80 }
81
82 pub fn get_subtree(&self, path: &[usize]) -> Option<&Self> {
84 if path.is_empty() {
85 return Some(self);
86 }
87
88 if let Self::Function(_, children) = self {
89 let idx = path[0];
90 if idx < children.len() {
91 children[idx].get_subtree(&path[1..])
92 } else {
93 None
94 }
95 } else {
96 None
97 }
98 }
99
100 pub fn get_subtree_mut(&mut self, path: &[usize]) -> Option<&mut Self> {
102 if path.is_empty() {
103 return Some(self);
104 }
105
106 if let Self::Function(_, children) = self {
107 let idx = path[0];
108 if idx < children.len() {
109 children[idx].get_subtree_mut(&path[1..])
110 } else {
111 None
112 }
113 } else {
114 None
115 }
116 }
117
118 pub fn replace_subtree(&mut self, path: &[usize], new_subtree: Self) -> bool {
120 if path.is_empty() {
121 *self = new_subtree;
122 return true;
123 }
124
125 if let Self::Function(_, children) = self {
126 let idx = path[0];
127 if idx < children.len() {
128 if path.len() == 1 {
129 children[idx] = new_subtree;
130 true
131 } else {
132 children[idx].replace_subtree(&path[1..], new_subtree)
133 }
134 } else {
135 false
136 }
137 } else {
138 false
139 }
140 }
141
142 pub fn terminal_positions(&self) -> Vec<Vec<usize>> {
144 let mut positions = Vec::new();
145 self.collect_terminal_positions(&[], &mut positions);
146 positions
147 }
148
149 fn collect_terminal_positions(&self, path: &[usize], positions: &mut Vec<Vec<usize>>) {
150 match self {
151 Self::Terminal(_) => positions.push(path.to_vec()),
152 Self::Function(_, children) => {
153 for (i, child) in children.iter().enumerate() {
154 let mut child_path = path.to_vec();
155 child_path.push(i);
156 child.collect_terminal_positions(&child_path, positions);
157 }
158 }
159 }
160 }
161
162 pub fn function_positions(&self) -> Vec<Vec<usize>> {
164 let mut positions = Vec::new();
165 self.collect_function_positions(&[], &mut positions);
166 positions
167 }
168
169 fn collect_function_positions(&self, path: &[usize], positions: &mut Vec<Vec<usize>>) {
170 if let Self::Function(_, children) = self {
171 positions.push(path.to_vec());
172 for (i, child) in children.iter().enumerate() {
173 let mut child_path = path.to_vec();
174 child_path.push(i);
175 child.collect_function_positions(&child_path, positions);
176 }
177 }
178 }
179}
180
/// Leaf values of an expression tree.
///
/// The supertraits let trees be cloned, compared, debug-printed, shared across
/// threads and (de)serialised with serde.
pub trait Terminal:
    Clone + Send + Sync + PartialEq + fmt::Debug + Serialize + for<'de> Deserialize<'de> + 'static
{
    /// Samples a random terminal; used when generating trees and when decoding traces.
    fn random<R: Rng>(rng: &mut R) -> Self;

    /// A fixed catalogue of terminals, if the type has one.
    ///
    /// May be empty — `ArithmeticTerminal` returns `&[]` because its variants carry
    /// arbitrary values.
    fn terminals() -> &'static [Self];

    /// Value of this terminal for the given input vector.
    fn evaluate(&self, variables: &[f64]) -> f64;

    /// Text used when rendering a tree as an S-expression.
    fn to_string(&self) -> String;
}
197
/// Internal-node operators of an expression tree.
///
/// Same supertraits as [`Terminal`], so whole trees can be cloned, compared,
/// shared across threads and (de)serialised.
pub trait Function:
    Clone + Send + Sync + PartialEq + fmt::Debug + Serialize + for<'de> Deserialize<'de> + 'static
{
    /// Number of operands this function consumes; tree generators create exactly
    /// this many children.
    fn arity(&self) -> usize;

    /// Samples a random function; used when generating trees and when decoding traces.
    fn random<R: Rng>(rng: &mut R) -> Self;

    /// The catalogue of available functions. Trace decoding indexes into this slice.
    fn functions() -> &'static [Self];

    /// Applies the function to already-evaluated operand values.
    fn apply(&self, args: &[f64]) -> f64;

    /// Operator name used when rendering a tree as an S-expression.
    fn to_string(&self) -> String;
}
217
/// Terminals for symbolic regression over `f64` input vectors.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum ArithmeticTerminal {
    /// The input variable at this index (rendered `x<i>`); evaluates to 0.0 when
    /// the index is outside the supplied variables.
    Variable(usize),
    /// A numeric constant.
    Constant(f64),
    /// Presumably an ephemeral random constant (ERC); evaluates exactly like `Constant`.
    Erc(f64),
}
228
229impl Terminal for ArithmeticTerminal {
230 fn random<R: Rng>(rng: &mut R) -> Self {
231 let choice: u8 = rng.gen_range(0..3);
232 match choice {
233 0 => Self::Variable(rng.gen_range(0..10)),
234 1 => Self::Constant(rng.gen_range(-10.0..10.0)),
235 _ => Self::Erc(rng.gen_range(-1.0..1.0)),
236 }
237 }
238
239 fn terminals() -> &'static [Self] {
240 &[]
242 }
243
244 fn evaluate(&self, variables: &[f64]) -> f64 {
245 match self {
246 Self::Variable(i) => variables.get(*i).copied().unwrap_or(0.0),
247 Self::Constant(c) | Self::Erc(c) => *c,
248 }
249 }
250
251 fn to_string(&self) -> String {
252 match self {
253 Self::Variable(i) => format!("x{}", i),
254 Self::Constant(c) | Self::Erc(c) => format!("{:.4}", c),
255 }
256 }
257}
258
/// Arithmetic primitives for symbolic regression.
///
/// Several of them are "protected" in `apply`, substituting a fixed value for
/// inputs outside the mathematical domain.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum ArithmeticFunction {
    /// Binary addition.
    Add,
    /// Binary subtraction (first operand minus second).
    Sub,
    /// Binary multiplication.
    Mul,
    /// Protected binary division (returns 1.0 for a near-zero divisor).
    Div,
    /// Unary sine.
    Sin,
    /// Unary cosine.
    Cos,
    /// Unary exponential, saturating for large inputs.
    Exp,
    /// Protected natural logarithm (returns 0.0 for non-positive inputs).
    Log,
    /// Protected square root of the input's magnitude.
    Sqrt,
    /// Protected binary power (base, exponent).
    Pow,
    /// Unary negation.
    Neg,
    /// Unary absolute value.
    Abs,
}
287
288impl Function for ArithmeticFunction {
289 fn arity(&self) -> usize {
290 match self {
291 Self::Add | Self::Sub | Self::Mul | Self::Div | Self::Pow => 2,
292 Self::Sin | Self::Cos | Self::Exp | Self::Log | Self::Sqrt | Self::Neg | Self::Abs => 1,
293 }
294 }
295
296 fn random<R: Rng>(rng: &mut R) -> Self {
297 let funcs = Self::functions();
298 funcs[rng.gen_range(0..funcs.len())].clone()
299 }
300
301 fn functions() -> &'static [Self] {
302 &[
303 Self::Add,
304 Self::Sub,
305 Self::Mul,
306 Self::Div,
307 Self::Sin,
308 Self::Cos,
309 Self::Exp,
310 Self::Log,
311 Self::Sqrt,
312 Self::Neg,
313 Self::Abs,
314 ]
315 }
316
317 fn apply(&self, args: &[f64]) -> f64 {
318 match self {
319 Self::Add => args.get(0).unwrap_or(&0.0) + args.get(1).unwrap_or(&0.0),
320 Self::Sub => args.get(0).unwrap_or(&0.0) - args.get(1).unwrap_or(&0.0),
321 Self::Mul => args.get(0).unwrap_or(&1.0) * args.get(1).unwrap_or(&1.0),
322 Self::Div => {
323 let a = args.get(0).unwrap_or(&0.0);
324 let b = args.get(1).unwrap_or(&1.0);
325 if b.abs() < 1e-10 {
326 1.0 } else {
328 a / b
329 }
330 }
331 Self::Sin => args.get(0).unwrap_or(&0.0).sin(),
332 Self::Cos => args.get(0).unwrap_or(&0.0).cos(),
333 Self::Exp => {
334 let x = args.get(0).unwrap_or(&0.0);
335 if *x > 700.0 {
336 f64::MAX } else {
338 x.exp()
339 }
340 }
341 Self::Log => {
342 let x = args.get(0).unwrap_or(&1.0);
343 if *x <= 0.0 {
344 0.0 } else {
346 x.ln()
347 }
348 }
349 Self::Sqrt => {
350 let x = args.get(0).unwrap_or(&0.0);
351 if *x < 0.0 {
352 (-x).sqrt() } else {
354 x.sqrt()
355 }
356 }
357 Self::Pow => {
358 let base = args.get(0).unwrap_or(&1.0);
359 let exp = args.get(1).unwrap_or(&1.0);
360 if base.abs() < 1e-10 && *exp < 0.0 {
362 0.0
363 } else {
364 base.powf(*exp).clamp(-1e10, 1e10)
365 }
366 }
367 Self::Neg => -args.get(0).unwrap_or(&0.0),
368 Self::Abs => args.get(0).unwrap_or(&0.0).abs(),
369 }
370 }
371
372 fn to_string(&self) -> String {
373 match self {
374 Self::Add => "+".to_string(),
375 Self::Sub => "-".to_string(),
376 Self::Mul => "*".to_string(),
377 Self::Div => "/".to_string(),
378 Self::Sin => "sin".to_string(),
379 Self::Cos => "cos".to_string(),
380 Self::Exp => "exp".to_string(),
381 Self::Log => "log".to_string(),
382 Self::Sqrt => "sqrt".to_string(),
383 Self::Pow => "pow".to_string(),
384 Self::Neg => "neg".to_string(),
385 Self::Abs => "abs".to_string(),
386 }
387 }
388}
389
/// A tree-shaped genome: an expression tree plus the depth limit recorded for it.
///
/// The type parameters default to the arithmetic primitives, i.e. symbolic regression.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
// Empty bound: suppress serde's inferred `T`/`F` bounds; the traits already require serde.
#[serde(bound = "")]
pub struct TreeGenome<T: Terminal = ArithmeticTerminal, F: Function = ArithmeticFunction> {
    /// Root node of the expression tree.
    pub root: TreeNode<T, F>,
    /// Depth limit stored with the genome.
    ///
    /// NOTE(review): the generators in this file count depth in edges while
    /// `TreeNode::depth` counts levels, so a freshly generated tree can report
    /// `depth() == max_depth + 1` — confirm which convention consumers enforce.
    pub max_depth: usize,
}
399
400impl<T: Terminal, F: Function> TreeGenome<T, F> {
401 pub fn new(root: TreeNode<T, F>, max_depth: usize) -> Self {
403 Self { root, max_depth }
404 }
405
406 pub fn depth(&self) -> usize {
408 self.root.depth()
409 }
410
411 pub fn size(&self) -> usize {
413 self.root.size()
414 }
415
416 pub fn evaluate(&self, variables: &[f64]) -> f64 {
418 self.evaluate_node(&self.root, variables)
419 }
420
421 fn evaluate_node(&self, node: &TreeNode<T, F>, variables: &[f64]) -> f64 {
422 match node {
423 TreeNode::Terminal(t) => t.evaluate(variables),
424 TreeNode::Function(f, children) => {
425 let args: Vec<f64> = children
426 .iter()
427 .map(|c| self.evaluate_node(c, variables))
428 .collect();
429 f.apply(&args)
430 }
431 }
432 }
433
434 pub fn generate_full<R: Rng>(rng: &mut R, depth: usize, max_depth: usize) -> Self {
436 let root = Self::generate_full_node(rng, depth, 0);
437 Self { root, max_depth }
438 }
439
440 fn generate_full_node<R: Rng>(
441 rng: &mut R,
442 target_depth: usize,
443 current_depth: usize,
444 ) -> TreeNode<T, F> {
445 if current_depth >= target_depth {
446 TreeNode::Terminal(T::random(rng))
447 } else {
448 let func = F::random(rng);
449 let arity = func.arity();
450 let children: Vec<TreeNode<T, F>> = (0..arity)
451 .map(|_| Self::generate_full_node(rng, target_depth, current_depth + 1))
452 .collect();
453 TreeNode::Function(func, children)
454 }
455 }
456
457 pub fn generate_grow<R: Rng>(rng: &mut R, max_depth: usize, terminal_prob: f64) -> Self {
459 let root = Self::generate_grow_node(rng, max_depth, 0, terminal_prob);
460 Self { root, max_depth }
461 }
462
463 fn generate_grow_node<R: Rng>(
464 rng: &mut R,
465 max_depth: usize,
466 current_depth: usize,
467 terminal_prob: f64,
468 ) -> TreeNode<T, F> {
469 if current_depth >= max_depth {
470 TreeNode::Terminal(T::random(rng))
471 } else if rng.gen::<f64>() < terminal_prob {
472 TreeNode::Terminal(T::random(rng))
473 } else {
474 let func = F::random(rng);
475 let arity = func.arity();
476 let children: Vec<TreeNode<T, F>> = (0..arity)
477 .map(|_| Self::generate_grow_node(rng, max_depth, current_depth + 1, terminal_prob))
478 .collect();
479 TreeNode::Function(func, children)
480 }
481 }
482
483 pub fn generate_ramped_half_and_half<R: Rng>(
485 rng: &mut R,
486 min_depth: usize,
487 max_depth: usize,
488 ) -> Self {
489 let depth = rng.gen_range(min_depth..=max_depth);
490 if rng.gen() {
491 Self::generate_full(rng, depth, max_depth)
492 } else {
493 Self::generate_grow(rng, depth, 0.3)
494 }
495 }
496
497 pub fn to_sexpr(&self) -> String {
499 self.node_to_sexpr(&self.root)
500 }
501
502 fn node_to_sexpr(&self, node: &TreeNode<T, F>) -> String {
503 match node {
504 TreeNode::Terminal(t) => t.to_string(),
505 TreeNode::Function(f, children) => {
506 let child_strs: Vec<String> =
507 children.iter().map(|c| self.node_to_sexpr(c)).collect();
508 format!("({} {})", f.to_string(), child_strs.join(" "))
509 }
510 }
511 }
512
513 pub fn random_position<R: Rng>(&self, rng: &mut R) -> Vec<usize> {
515 let positions = self.root.positions();
516 positions[rng.gen_range(0..positions.len())].clone()
517 }
518
519 pub fn random_terminal_position<R: Rng>(&self, rng: &mut R) -> Option<Vec<usize>> {
521 let positions = self.root.terminal_positions();
522 if positions.is_empty() {
523 None
524 } else {
525 Some(positions[rng.gen_range(0..positions.len())].clone())
526 }
527 }
528
529 pub fn random_function_position<R: Rng>(&self, rng: &mut R) -> Option<Vec<usize>> {
531 let positions = self.root.function_positions();
532 if positions.is_empty() {
533 None
534 } else {
535 Some(positions[rng.gen_range(0..positions.len())].clone())
536 }
537 }
538}
539
540impl<T: Terminal, F: Function> EvolutionaryGenome for TreeGenome<T, F> {
541 type Allele = TreeNode<T, F>;
542 type Phenotype = Self;
543
544 fn to_trace(&self) -> Trace {
545 let mut trace = Trace::default();
546 let mut index = 0;
547 self.node_to_trace(&self.root, &mut trace, &mut index);
548 trace.insert_choice(
550 addr!("tree_max_depth"),
551 ChoiceValue::Usize(self.max_depth),
552 0.0,
553 );
554 trace.insert_choice(addr!("tree_size"), ChoiceValue::Usize(index), 0.0);
555 trace
556 }
557
558 fn from_trace(trace: &Trace) -> Result<Self, GenomeError> {
559 let max_depth = trace
560 .get_usize(&addr!("tree_max_depth"))
561 .ok_or_else(|| GenomeError::MissingAddress("tree_max_depth".to_string()))?;
562
563 let mut index = 0;
564 let root = Self::node_from_trace(trace, &mut index)?;
565 Ok(Self { root, max_depth })
566 }
567
568 fn decode(&self) -> Self::Phenotype {
569 self.clone()
570 }
571
572 fn dimension(&self) -> usize {
573 self.size()
574 }
575
576 fn generate<R: Rng>(rng: &mut R, bounds: &MultiBounds) -> Self {
577 let max_depth = bounds.dimension().max(3).min(10);
578 Self::generate_ramped_half_and_half(rng, 2, max_depth)
579 }
580
581 fn distance(&self, other: &Self) -> f64 {
582 let size_diff = (self.size() as f64 - other.size() as f64).abs();
584 let depth_diff = (self.depth() as f64 - other.depth() as f64).abs();
585 size_diff + depth_diff
586 }
587
588 fn trace_prefix() -> &'static str {
589 "tree"
590 }
591}
592
593impl<T: Terminal, F: Function> TreeGenome<T, F> {
594 fn node_to_trace(&self, node: &TreeNode<T, F>, trace: &mut Trace, index: &mut usize) {
595 let current_index = *index;
596 *index += 1;
597
598 match node {
599 TreeNode::Terminal(t) => {
600 trace.insert_choice(
602 addr!("tree_is_terminal", current_index),
603 ChoiceValue::Bool(true),
604 0.0,
605 );
606 let (term_type, term_val) = Self::encode_terminal(t);
609 trace.insert_choice(
610 addr!("tree_term_type", current_index),
611 ChoiceValue::F64(term_type),
612 0.0,
613 );
614 trace.insert_choice(
615 addr!("tree_term_val", current_index),
616 ChoiceValue::F64(term_val),
617 0.0,
618 );
619 }
620 TreeNode::Function(f, children) => {
621 trace.insert_choice(
623 addr!("tree_is_terminal", current_index),
624 ChoiceValue::Bool(false),
625 0.0,
626 );
627 let func_idx = Self::encode_function(f);
629 trace.insert_choice(
630 addr!("tree_func_idx", current_index),
631 ChoiceValue::Usize(func_idx),
632 0.0,
633 );
634 trace.insert_choice(
635 addr!("tree_arity", current_index),
636 ChoiceValue::Usize(children.len()),
637 0.0,
638 );
639 for child in children {
641 self.node_to_trace(child, trace, index);
642 }
643 }
644 }
645 }
646
647 fn node_from_trace(trace: &Trace, index: &mut usize) -> Result<TreeNode<T, F>, GenomeError> {
648 let current_index = *index;
649 *index += 1;
650
651 let is_terminal = trace
652 .get_bool(&addr!("tree_is_terminal", current_index))
653 .ok_or_else(|| {
654 GenomeError::MissingAddress(format!("tree_is_terminal#{}", current_index))
655 })?;
656
657 if is_terminal {
658 let term_type = trace
659 .get_f64(&addr!("tree_term_type", current_index))
660 .ok_or_else(|| {
661 GenomeError::MissingAddress(format!("tree_term_type#{}", current_index))
662 })?;
663 let term_val = trace
664 .get_f64(&addr!("tree_term_val", current_index))
665 .ok_or_else(|| {
666 GenomeError::MissingAddress(format!("tree_term_val#{}", current_index))
667 })?;
668
669 let terminal = Self::decode_terminal(term_type, term_val)?;
670 Ok(TreeNode::Terminal(terminal))
671 } else {
672 let func_idx = trace
673 .get_usize(&addr!("tree_func_idx", current_index))
674 .ok_or_else(|| {
675 GenomeError::MissingAddress(format!("tree_func_idx#{}", current_index))
676 })?;
677 let arity = trace
678 .get_usize(&addr!("tree_arity", current_index))
679 .ok_or_else(|| {
680 GenomeError::MissingAddress(format!("tree_arity#{}", current_index))
681 })?;
682
683 let func = Self::decode_function(func_idx)?;
684 let mut children = Vec::with_capacity(arity);
685 for _ in 0..arity {
686 children.push(Self::node_from_trace(trace, index)?);
687 }
688 Ok(TreeNode::Function(func, children))
689 }
690 }
691
692 fn encode_terminal(_terminal: &T) -> (f64, f64) {
695 (0.0, 0.0)
698 }
699
700 fn decode_terminal(term_type: f64, term_val: f64) -> Result<T, GenomeError> {
701 let mut rng = rand::thread_rng();
704 let _ = (term_type, term_val); Ok(T::random(&mut rng))
706 }
707
708 fn encode_function(_func: &F) -> usize {
709 0
711 }
712
713 fn decode_function(func_idx: usize) -> Result<F, GenomeError> {
714 let funcs = F::functions();
715 if func_idx < funcs.len() {
716 Ok(funcs[func_idx].clone())
717 } else {
718 let mut rng = rand::thread_rng();
720 Ok(F::random(&mut rng))
721 }
722 }
723}
724
725impl<T: Terminal, F: Function> fmt::Display for TreeGenome<T, F> {
726 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
727 write!(f, "{}", self.to_sexpr())
728 }
729}
730
/// Abstraction over genomes that are backed by a single expression tree.
///
/// Presumably lets tree-specific operators elsewhere in the crate (crossover,
/// mutation, …) work generically — confirm against callers.
pub trait TreeGenomeType: EvolutionaryGenome {
    /// Terminal type used at the leaves.
    type Term: Terminal;
    /// Function type used at internal nodes.
    type Func: Function;

    /// Borrows the root node.
    fn root(&self) -> &TreeNode<Self::Term, Self::Func>;

    /// Mutably borrows the root node, for in-place edits.
    fn root_mut(&mut self) -> &mut TreeNode<Self::Term, Self::Func>;

    /// Depth limit recorded on the genome.
    fn max_depth(&self) -> usize;

    /// Builds a genome from a root node and a depth limit.
    fn from_root(root: TreeNode<Self::Term, Self::Func>, max_depth: usize) -> Self;
}
750
751impl<T: Terminal, F: Function> TreeGenomeType for TreeGenome<T, F> {
752 type Term = T;
753 type Func = F;
754
755 fn root(&self) -> &TreeNode<T, F> {
756 &self.root
757 }
758
759 fn root_mut(&mut self) -> &mut TreeNode<T, F> {
760 &mut self.root
761 }
762
763 fn max_depth(&self) -> usize {
764 self.max_depth
765 }
766
767 fn from_root(root: TreeNode<T, F>, max_depth: usize) -> Self {
768 Self { root, max_depth }
769 }
770}
771
#[cfg(test)]
mod tests {
    use super::*;

    // A lone terminal is one level deep and one node large.
    #[test]
    fn test_tree_node_terminal() {
        let node: TreeNode<ArithmeticTerminal, ArithmeticFunction> =
            TreeNode::terminal(ArithmeticTerminal::Variable(0));
        assert!(node.is_terminal());
        assert!(!node.is_function());
        assert_eq!(node.depth(), 1);
        assert_eq!(node.size(), 1);
    }

    // (+ x0 1.0): two levels, three nodes.
    #[test]
    fn test_tree_node_function() {
        let left = TreeNode::terminal(ArithmeticTerminal::Variable(0));
        let right = TreeNode::terminal(ArithmeticTerminal::Constant(1.0));
        let node = TreeNode::function(ArithmeticFunction::Add, vec![left, right]);

        assert!(!node.is_terminal());
        assert!(node.is_function());
        assert_eq!(node.depth(), 2);
        assert_eq!(node.size(), 3);
    }

    #[test]
    fn test_tree_node_positions() {
        // Tree under test: (+ x0 (* 1.0 x1))
        let x0 = TreeNode::terminal(ArithmeticTerminal::Variable(0));
        let c1 = TreeNode::terminal(ArithmeticTerminal::Constant(1.0));
        let x1 = TreeNode::terminal(ArithmeticTerminal::Variable(1));
        let mul = TreeNode::function(ArithmeticFunction::Mul, vec![c1, x1]);
        let add = TreeNode::function(ArithmeticFunction::Add, vec![x0, mul]);

        let positions = add.positions();
        assert_eq!(positions.len(), 5); // one path per node, root included
        assert!(positions.contains(&vec![])); // root `+`
        assert!(positions.contains(&vec![0])); // x0
        assert!(positions.contains(&vec![1])); // `*`
        assert!(positions.contains(&vec![1, 0])); // 1.0
        assert!(positions.contains(&vec![1, 1])); // x1
    }

    // Valid paths resolve; an out-of-range child index yields None.
    #[test]
    fn test_tree_node_get_subtree() {
        let x0 = TreeNode::terminal(ArithmeticTerminal::Variable(0));
        let c1 = TreeNode::terminal(ArithmeticTerminal::Constant(1.0));
        let add: TreeNode<ArithmeticTerminal, ArithmeticFunction> =
            TreeNode::function(ArithmeticFunction::Add, vec![x0.clone(), c1]);

        assert_eq!(add.get_subtree(&[0]), Some(&x0));
        assert!(add.get_subtree(&[2]).is_none());
    }

    // x0 + x1 with x0 = 3, x1 = 4.
    #[test]
    fn test_tree_genome_evaluate() {
        let x0 = TreeNode::terminal(ArithmeticTerminal::Variable(0));
        let x1 = TreeNode::terminal(ArithmeticTerminal::Variable(1));
        let add = TreeNode::function(ArithmeticFunction::Add, vec![x0, x1]);
        let tree = TreeGenome::new(add, 5);

        assert_eq!(tree.evaluate(&[3.0, 4.0]), 7.0);
    }

    // (x0 + 1) * x1 with x0 = 2, x1 = 3.
    #[test]
    fn test_tree_genome_evaluate_complex() {
        let x0 = TreeNode::terminal(ArithmeticTerminal::Variable(0));
        let c1 = TreeNode::terminal(ArithmeticTerminal::Constant(1.0));
        let x1 = TreeNode::terminal(ArithmeticTerminal::Variable(1));
        let add = TreeNode::function(ArithmeticFunction::Add, vec![x0, c1]);
        let mul = TreeNode::function(ArithmeticFunction::Mul, vec![add, x1]);
        let tree = TreeGenome::new(mul, 5);

        assert_eq!(tree.evaluate(&[2.0, 3.0]), 9.0); // (2 + 1) * 3
    }

    #[test]
    fn test_tree_genome_generate_full() {
        let mut rng = rand::thread_rng();
        let tree: TreeGenome<ArithmeticTerminal, ArithmeticFunction> =
            TreeGenome::generate_full(&mut rng, 3, 5);

        // The generator counts depth in edges while `depth()` counts levels, so a
        // depth-3 request yields depth() == 4; `>= 3` is a deliberately loose bound.
        assert!(tree.depth() >= 3);
        assert!(tree.size() >= 1);
    }

    #[test]
    fn test_tree_genome_generate_grow() {
        let mut rng = rand::thread_rng();
        let tree: TreeGenome<ArithmeticTerminal, ArithmeticFunction> =
            TreeGenome::generate_grow(&mut rng, 5, 0.3);

        // max_depth 5 is counted in edges, i.e. at most 6 levels as `depth()` counts them.
        assert!(tree.depth() <= 6);
        assert!(tree.size() >= 1);
    }

    #[test]
    fn test_tree_genome_to_sexpr() {
        let x0 = TreeNode::terminal(ArithmeticTerminal::Variable(0));
        let c1 = TreeNode::terminal(ArithmeticTerminal::Constant(1.0));
        let add: TreeNode<ArithmeticTerminal, ArithmeticFunction> =
            TreeNode::function(ArithmeticFunction::Add, vec![x0, c1]);
        let tree = TreeGenome::new(add, 5);

        let sexpr = tree.to_sexpr();
        assert!(sexpr.contains('+'));
        assert!(sexpr.contains("x0"));
        // Constants print with four decimals ("1.0000"), which contains "1.0".
        assert!(sexpr.contains("1.0"));
    }

    #[test]
    fn test_tree_genome_trace_roundtrip() {
        let x0 = TreeNode::terminal(ArithmeticTerminal::Variable(0));
        let c1 = TreeNode::terminal(ArithmeticTerminal::Constant(2.5));
        let add = TreeNode::function(ArithmeticFunction::Add, vec![x0, c1]);
        let original: TreeGenome<ArithmeticTerminal, ArithmeticFunction> = TreeGenome::new(add, 5);

        let trace = original.to_trace();
        let recovered: TreeGenome<ArithmeticTerminal, ArithmeticFunction> =
            TreeGenome::from_trace(&trace).unwrap();

        // Arithmetic terminal values are not carried by the trace encoding, so only
        // structural properties are compared here.
        assert_eq!(original.max_depth, recovered.max_depth);
        assert_eq!(original.size(), recovered.size());
        assert!(recovered.evaluate(&[3.0]).is_finite());
    }

    #[test]
    fn test_tree_node_replace_subtree() {
        let x0 = TreeNode::terminal(ArithmeticTerminal::Variable(0));
        let x1 = TreeNode::terminal(ArithmeticTerminal::Variable(1));
        let mut add: TreeNode<ArithmeticTerminal, ArithmeticFunction> =
            TreeNode::function(ArithmeticFunction::Add, vec![x0, x1]);

        // Replace x0 with the constant 5.0, giving (+ 5.0 x1).
        let c5 = TreeNode::terminal(ArithmeticTerminal::Constant(5.0));
        add.replace_subtree(&[0], c5);

        let tree = TreeGenome::new(add, 5);
        assert_eq!(tree.evaluate(&[0.0, 3.0]), 8.0); // 5 + 3
    }

    // Division by ~0 is protected and yields 1.0.
    #[test]
    fn test_arithmetic_function_protected_div() {
        assert_eq!(ArithmeticFunction::Div.apply(&[1.0, 0.0]), 1.0);
        assert_eq!(ArithmeticFunction::Div.apply(&[6.0, 2.0]), 3.0);
    }

    // Log of a non-positive input is protected and yields 0.0.
    #[test]
    fn test_arithmetic_function_protected_log() {
        assert_eq!(ArithmeticFunction::Log.apply(&[-1.0]), 0.0);
        assert!((ArithmeticFunction::Log.apply(&[std::f64::consts::E]) - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_arithmetic_function_protected_sqrt() {
        assert_eq!(ArithmeticFunction::Sqrt.apply(&[4.0]), 2.0);
        assert_eq!(ArithmeticFunction::Sqrt.apply(&[-4.0]), 2.0); // sqrt of the magnitude
    }

    #[test]
    fn test_tree_genome_evolutionary_genome_trait() {
        let mut rng = rand::thread_rng();
        // Presumably 5 dimensions bounded to [-5, 5]; `generate` only reads the dimension.
        let bounds = MultiBounds::symmetric(5.0, 5);
        let tree: TreeGenome<ArithmeticTerminal, ArithmeticFunction> =
            TreeGenome::generate(&mut rng, &bounds);

        assert!(tree.dimension() >= 1);
        let decoded = tree.decode();
        assert_eq!(decoded.size(), tree.size());
    }

    #[test]
    fn test_tree_terminal_and_function_positions() {
        // Tree under test: (+ x0 (* 1.0 x1))
        let x0 = TreeNode::terminal(ArithmeticTerminal::Variable(0));
        let c1 = TreeNode::terminal(ArithmeticTerminal::Constant(1.0));
        let x1 = TreeNode::terminal(ArithmeticTerminal::Variable(1));
        let mul = TreeNode::function(ArithmeticFunction::Mul, vec![c1, x1]);
        let add: TreeNode<ArithmeticTerminal, ArithmeticFunction> =
            TreeNode::function(ArithmeticFunction::Add, vec![x0, mul]);

        let terminal_positions = add.terminal_positions();
        assert_eq!(terminal_positions.len(), 3); // x0, 1.0, x1

        let function_positions = add.function_positions();
        assert_eq!(function_positions.len(), 2); // `+`, `*`
    }

    #[test]
    fn test_tree_genome_display() {
        let x0 = TreeNode::terminal(ArithmeticTerminal::Variable(0));
        let c1 = TreeNode::terminal(ArithmeticTerminal::Constant(1.0));
        let add: TreeNode<ArithmeticTerminal, ArithmeticFunction> =
            TreeNode::function(ArithmeticFunction::Add, vec![x0, c1]);
        let tree = TreeGenome::new(add, 5);

        let display = format!("{}", tree);
        assert!(!display.is_empty());
    }
}