Skip to main content

fugue_evo/genome/
tree.rs

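//! Tree genomes for genetic programming.
//!
//! This module provides tree-based genomes for symbolic regression and
//! genetic programming applications. [`TreeNode`] is the recursive node type;
//! [`Terminal`] and [`Function`] describe leaf and internal-node values; and
//! [`TreeGenome`] pairs a root node with a maximum-depth limit.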
5
6use fugue::{addr, ChoiceValue, Trace};
7use rand::Rng;
8use serde::{Deserialize, Serialize};
9use std::fmt;
10
11use crate::error::GenomeError;
12use crate::genome::bounds::MultiBounds;
13use crate::genome::traits::EvolutionaryGenome;
14
/// A node in a GP tree.
///
/// Leaves are [`TreeNode::Terminal`] values. Internal nodes are [`TreeNode::Function`]
/// values that carry their child subtrees in argument order. Nothing in the type
/// enforces that the number of children matches [`Function::arity`]. The generators
/// in this module build `arity()` children, but hand-built trees (for example via
/// [`TreeNode::function`]) may differ.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
// `bound = ""` stops serde from inferring its own bounds on `T`/`F`; the
// `Terminal`/`Function` supertraits already require `Serialize + Deserialize`.
#[serde(bound = "")]
pub enum TreeNode<T: Terminal, F: Function> {
    /// Terminal node (leaf)
    Terminal(T),
    /// Function node (internal): the function plus its ordered child subtrees
    Function(F, Vec<TreeNode<T, F>>),
}
24
25impl<T: Terminal, F: Function> TreeNode<T, F> {
26    /// Create a new terminal node
27    pub fn terminal(value: T) -> Self {
28        Self::Terminal(value)
29    }
30
31    /// Create a new function node
32    pub fn function(func: F, children: Vec<Self>) -> Self {
33        Self::Function(func, children)
34    }
35
36    /// Check if this is a terminal node
37    pub fn is_terminal(&self) -> bool {
38        matches!(self, Self::Terminal(_))
39    }
40
41    /// Check if this is a function node
42    pub fn is_function(&self) -> bool {
43        matches!(self, Self::Function(_, _))
44    }
45
46    /// Get the depth of this subtree
47    pub fn depth(&self) -> usize {
48        match self {
49            Self::Terminal(_) => 1,
50            Self::Function(_, children) => {
51                1 + children.iter().map(|c| c.depth()).max().unwrap_or(0)
52            }
53        }
54    }
55
56    /// Get the number of nodes in this subtree
57    pub fn size(&self) -> usize {
58        match self {
59            Self::Terminal(_) => 1,
60            Self::Function(_, children) => 1 + children.iter().map(|c| c.size()).sum::<usize>(),
61        }
62    }
63
64    /// Get all node positions (preorder traversal indices)
65    pub fn positions(&self) -> Vec<Vec<usize>> {
66        let mut positions = Vec::new();
67        self.collect_positions(&[], &mut positions);
68        positions
69    }
70
71    fn collect_positions(&self, path: &[usize], positions: &mut Vec<Vec<usize>>) {
72        positions.push(path.to_vec());
73        if let Self::Function(_, children) = self {
74            for (i, child) in children.iter().enumerate() {
75                let mut child_path = path.to_vec();
76                child_path.push(i);
77                child.collect_positions(&child_path, positions);
78            }
79        }
80    }
81
82    /// Get a subtree at the given path
83    pub fn get_subtree(&self, path: &[usize]) -> Option<&Self> {
84        if path.is_empty() {
85            return Some(self);
86        }
87
88        if let Self::Function(_, children) = self {
89            let idx = path[0];
90            if idx < children.len() {
91                children[idx].get_subtree(&path[1..])
92            } else {
93                None
94            }
95        } else {
96            None
97        }
98    }
99
100    /// Get a mutable subtree at the given path
101    pub fn get_subtree_mut(&mut self, path: &[usize]) -> Option<&mut Self> {
102        if path.is_empty() {
103            return Some(self);
104        }
105
106        if let Self::Function(_, children) = self {
107            let idx = path[0];
108            if idx < children.len() {
109                children[idx].get_subtree_mut(&path[1..])
110            } else {
111                None
112            }
113        } else {
114            None
115        }
116    }
117
118    /// Replace a subtree at the given path
119    pub fn replace_subtree(&mut self, path: &[usize], new_subtree: Self) -> bool {
120        if path.is_empty() {
121            *self = new_subtree;
122            return true;
123        }
124
125        if let Self::Function(_, children) = self {
126            let idx = path[0];
127            if idx < children.len() {
128                if path.len() == 1 {
129                    children[idx] = new_subtree;
130                    true
131                } else {
132                    children[idx].replace_subtree(&path[1..], new_subtree)
133                }
134            } else {
135                false
136            }
137        } else {
138            false
139        }
140    }
141
142    /// Get all terminal positions
143    pub fn terminal_positions(&self) -> Vec<Vec<usize>> {
144        let mut positions = Vec::new();
145        self.collect_terminal_positions(&[], &mut positions);
146        positions
147    }
148
149    fn collect_terminal_positions(&self, path: &[usize], positions: &mut Vec<Vec<usize>>) {
150        match self {
151            Self::Terminal(_) => positions.push(path.to_vec()),
152            Self::Function(_, children) => {
153                for (i, child) in children.iter().enumerate() {
154                    let mut child_path = path.to_vec();
155                    child_path.push(i);
156                    child.collect_terminal_positions(&child_path, positions);
157                }
158            }
159        }
160    }
161
162    /// Get all function positions
163    pub fn function_positions(&self) -> Vec<Vec<usize>> {
164        let mut positions = Vec::new();
165        self.collect_function_positions(&[], &mut positions);
166        positions
167    }
168
169    fn collect_function_positions(&self, path: &[usize], positions: &mut Vec<Vec<usize>>) {
170        if let Self::Function(_, children) = self {
171            positions.push(path.to_vec());
172            for (i, child) in children.iter().enumerate() {
173                let mut child_path = path.to_vec();
174                child_path.push(i);
175                child.collect_function_positions(&child_path, positions);
176            }
177        }
178    }
179}
180
/// Trait for terminal (leaf) values in GP trees.
///
/// The supertraits let trees be cloned, sent and shared across threads, compared,
/// debug-printed and (de)serialized with serde.
pub trait Terminal:
    Clone + Send + Sync + PartialEq + fmt::Debug + Serialize + for<'de> Deserialize<'de> + 'static
{
    /// Generate a random terminal
    fn random<R: Rng>(rng: &mut R) -> Self;

    /// Get the set of available terminals.
    ///
    /// This may be empty (the `ArithmeticTerminal` implementation returns `&[]`), so
    /// callers should not assume they can sample from it.
    fn terminals() -> &'static [Self];

    /// Evaluate this terminal with the given variable bindings
    fn evaluate(&self, variables: &[f64]) -> f64;

    /// Convert to string representation (used when rendering trees as S-expressions).
    ///
    /// NOTE(review): this shares its name with `ToString::to_string`. An implementor
    /// that also implements `Display` may need fully-qualified calls to disambiguate.
    fn to_string(&self) -> String;
}
197
/// Trait for function (internal-node) values in GP trees.
///
/// The supertraits match those of [`Terminal`]: clonable, thread-safe, comparable,
/// debug-printable and serde-(de)serializable.
pub trait Function:
    Clone + Send + Sync + PartialEq + fmt::Debug + Serialize + for<'de> Deserialize<'de> + 'static
{
    /// Get the arity (number of arguments) of this function.
    ///
    /// Tree generators in this module create exactly this many children per node.
    fn arity(&self) -> usize;

    /// Generate a random function
    fn random<R: Rng>(rng: &mut R) -> Self;

    /// Get the set of available functions.
    ///
    /// The trace encoding in this module identifies a function by its index in this slice.
    fn functions() -> &'static [Self];

    /// Apply this function to the given arguments.
    ///
    /// During tree evaluation, `args` holds the evaluated child values in child order.
    fn apply(&self, args: &[f64]) -> f64;

    /// Convert to string representation (the operator name in S-expressions).
    ///
    /// NOTE(review): this shares its name with `ToString::to_string`; see the note on
    /// `Terminal::to_string`.
    fn to_string(&self) -> String;
}
217
/// Standard arithmetic terminals for symbolic regression.
///
/// `Constant` and `Erc` evaluate identically. They differ only in the ranges that
/// `Terminal::random` draws them from.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum ArithmeticTerminal {
    /// Variable x_i: index into the `variables` slice passed to `evaluate`
    /// (an out-of-range index evaluates to 0.0)
    Variable(usize),
    /// Constant value
    Constant(f64),
    /// Ephemeral random constant (ERC)
    Erc(f64),
}
228
229impl Terminal for ArithmeticTerminal {
230    fn random<R: Rng>(rng: &mut R) -> Self {
231        let choice: u8 = rng.gen_range(0..3);
232        match choice {
233            0 => Self::Variable(rng.gen_range(0..10)),
234            1 => Self::Constant(rng.gen_range(-10.0..10.0)),
235            _ => Self::Erc(rng.gen_range(-1.0..1.0)),
236        }
237    }
238
239    fn terminals() -> &'static [Self] {
240        // Return a representative set; actual terminals depend on context
241        &[]
242    }
243
244    fn evaluate(&self, variables: &[f64]) -> f64 {
245        match self {
246            Self::Variable(i) => variables.get(*i).copied().unwrap_or(0.0),
247            Self::Constant(c) | Self::Erc(c) => *c,
248        }
249    }
250
251    fn to_string(&self) -> String {
252        match self {
253            Self::Variable(i) => format!("x{}", i),
254            Self::Constant(c) | Self::Erc(c) => format!("{:.4}", c),
255        }
256    }
257}
258
/// Standard arithmetic functions for symbolic regression.
///
/// The "protected" variants return a fixed finite value on inputs where the plain
/// operation is undefined, so evaluating a random tree does not immediately fail.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum ArithmeticFunction {
    /// Addition
    Add,
    /// Subtraction
    Sub,
    /// Multiplication
    Mul,
    /// Protected division (returns 1.0 when the divisor's magnitude is below 1e-10)
    Div,
    /// Sine
    Sin,
    /// Cosine
    Cos,
    /// Exponential (returns `f64::MAX` for inputs above 700.0)
    Exp,
    /// Natural logarithm (protected: returns 0.0 for non-positive inputs)
    Log,
    /// Square root (protected: negative inputs use their magnitude)
    Sqrt,
    /// Power (protected: a near-zero base with a negative exponent yields 0.0, and
    /// results are clamped to [-1e10, 1e10]).
    ///
    /// Not listed in `Function::functions()`, so `Function::random` never produces it.
    Pow,
    /// Negation (unary)
    Neg,
    /// Absolute value (unary)
    Abs,
}
287
288impl Function for ArithmeticFunction {
289    fn arity(&self) -> usize {
290        match self {
291            Self::Add | Self::Sub | Self::Mul | Self::Div | Self::Pow => 2,
292            Self::Sin | Self::Cos | Self::Exp | Self::Log | Self::Sqrt | Self::Neg | Self::Abs => 1,
293        }
294    }
295
296    fn random<R: Rng>(rng: &mut R) -> Self {
297        let funcs = Self::functions();
298        funcs[rng.gen_range(0..funcs.len())].clone()
299    }
300
301    fn functions() -> &'static [Self] {
302        &[
303            Self::Add,
304            Self::Sub,
305            Self::Mul,
306            Self::Div,
307            Self::Sin,
308            Self::Cos,
309            Self::Exp,
310            Self::Log,
311            Self::Sqrt,
312            Self::Neg,
313            Self::Abs,
314        ]
315    }
316
317    fn apply(&self, args: &[f64]) -> f64 {
318        match self {
319            Self::Add => args.get(0).unwrap_or(&0.0) + args.get(1).unwrap_or(&0.0),
320            Self::Sub => args.get(0).unwrap_or(&0.0) - args.get(1).unwrap_or(&0.0),
321            Self::Mul => args.get(0).unwrap_or(&1.0) * args.get(1).unwrap_or(&1.0),
322            Self::Div => {
323                let a = args.get(0).unwrap_or(&0.0);
324                let b = args.get(1).unwrap_or(&1.0);
325                if b.abs() < 1e-10 {
326                    1.0 // Protected division
327                } else {
328                    a / b
329                }
330            }
331            Self::Sin => args.get(0).unwrap_or(&0.0).sin(),
332            Self::Cos => args.get(0).unwrap_or(&0.0).cos(),
333            Self::Exp => {
334                let x = args.get(0).unwrap_or(&0.0);
335                if *x > 700.0 {
336                    f64::MAX // Overflow protection
337                } else {
338                    x.exp()
339                }
340            }
341            Self::Log => {
342                let x = args.get(0).unwrap_or(&1.0);
343                if *x <= 0.0 {
344                    0.0 // Protected log
345                } else {
346                    x.ln()
347                }
348            }
349            Self::Sqrt => {
350                let x = args.get(0).unwrap_or(&0.0);
351                if *x < 0.0 {
352                    (-x).sqrt() // Protected sqrt
353                } else {
354                    x.sqrt()
355                }
356            }
357            Self::Pow => {
358                let base = args.get(0).unwrap_or(&1.0);
359                let exp = args.get(1).unwrap_or(&1.0);
360                // Protected power
361                if base.abs() < 1e-10 && *exp < 0.0 {
362                    0.0
363                } else {
364                    base.powf(*exp).clamp(-1e10, 1e10)
365                }
366            }
367            Self::Neg => -args.get(0).unwrap_or(&0.0),
368            Self::Abs => args.get(0).unwrap_or(&0.0).abs(),
369        }
370    }
371
372    fn to_string(&self) -> String {
373        match self {
374            Self::Add => "+".to_string(),
375            Self::Sub => "-".to_string(),
376            Self::Mul => "*".to_string(),
377            Self::Div => "/".to_string(),
378            Self::Sin => "sin".to_string(),
379            Self::Cos => "cos".to_string(),
380            Self::Exp => "exp".to_string(),
381            Self::Log => "log".to_string(),
382            Self::Sqrt => "sqrt".to_string(),
383            Self::Pow => "pow".to_string(),
384            Self::Neg => "neg".to_string(),
385            Self::Abs => "abs".to_string(),
386        }
387    }
388}
389
/// Tree genome for genetic programming.
///
/// The type parameters default to the arithmetic terminal and function sets, so
/// `TreeGenome` on its own is a symbolic-regression genome.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
// Bounds come from the `Terminal`/`Function` supertraits, not serde's inference.
#[serde(bound = "")]
pub struct TreeGenome<T: Terminal = ArithmeticTerminal, F: Function = ArithmeticFunction> {
    /// Root node of the tree
    pub root: TreeNode<T, F>,
    /// Maximum allowed depth.
    ///
    /// This is stored only; nothing in this type enforces it on `root`. It is presumably
    /// consulted by variation operators (TODO confirm).
    pub max_depth: usize,
}
399
400impl<T: Terminal, F: Function> TreeGenome<T, F> {
401    /// Create a new tree genome
402    pub fn new(root: TreeNode<T, F>, max_depth: usize) -> Self {
403        Self { root, max_depth }
404    }
405
406    /// Get the depth of the tree
407    pub fn depth(&self) -> usize {
408        self.root.depth()
409    }
410
411    /// Get the number of nodes in the tree
412    pub fn size(&self) -> usize {
413        self.root.size()
414    }
415
416    /// Evaluate the tree with given variable bindings
417    pub fn evaluate(&self, variables: &[f64]) -> f64 {
418        self.evaluate_node(&self.root, variables)
419    }
420
421    fn evaluate_node(&self, node: &TreeNode<T, F>, variables: &[f64]) -> f64 {
422        match node {
423            TreeNode::Terminal(t) => t.evaluate(variables),
424            TreeNode::Function(f, children) => {
425                let args: Vec<f64> = children
426                    .iter()
427                    .map(|c| self.evaluate_node(c, variables))
428                    .collect();
429                f.apply(&args)
430            }
431        }
432    }
433
434    /// Generate a random tree using the "full" method
435    pub fn generate_full<R: Rng>(rng: &mut R, depth: usize, max_depth: usize) -> Self {
436        let root = Self::generate_full_node(rng, depth, 0);
437        Self { root, max_depth }
438    }
439
440    fn generate_full_node<R: Rng>(
441        rng: &mut R,
442        target_depth: usize,
443        current_depth: usize,
444    ) -> TreeNode<T, F> {
445        if current_depth >= target_depth {
446            TreeNode::Terminal(T::random(rng))
447        } else {
448            let func = F::random(rng);
449            let arity = func.arity();
450            let children: Vec<TreeNode<T, F>> = (0..arity)
451                .map(|_| Self::generate_full_node(rng, target_depth, current_depth + 1))
452                .collect();
453            TreeNode::Function(func, children)
454        }
455    }
456
457    /// Generate a random tree using the "grow" method
458    pub fn generate_grow<R: Rng>(rng: &mut R, max_depth: usize, terminal_prob: f64) -> Self {
459        let root = Self::generate_grow_node(rng, max_depth, 0, terminal_prob);
460        Self { root, max_depth }
461    }
462
463    fn generate_grow_node<R: Rng>(
464        rng: &mut R,
465        max_depth: usize,
466        current_depth: usize,
467        terminal_prob: f64,
468    ) -> TreeNode<T, F> {
469        if current_depth >= max_depth {
470            TreeNode::Terminal(T::random(rng))
471        } else if rng.gen::<f64>() < terminal_prob {
472            TreeNode::Terminal(T::random(rng))
473        } else {
474            let func = F::random(rng);
475            let arity = func.arity();
476            let children: Vec<TreeNode<T, F>> = (0..arity)
477                .map(|_| Self::generate_grow_node(rng, max_depth, current_depth + 1, terminal_prob))
478                .collect();
479            TreeNode::Function(func, children)
480        }
481    }
482
483    /// Generate using ramped half-and-half
484    pub fn generate_ramped_half_and_half<R: Rng>(
485        rng: &mut R,
486        min_depth: usize,
487        max_depth: usize,
488    ) -> Self {
489        let depth = rng.gen_range(min_depth..=max_depth);
490        if rng.gen() {
491            Self::generate_full(rng, depth, max_depth)
492        } else {
493            Self::generate_grow(rng, depth, 0.3)
494        }
495    }
496
497    /// Convert tree to S-expression string
498    pub fn to_sexpr(&self) -> String {
499        self.node_to_sexpr(&self.root)
500    }
501
502    fn node_to_sexpr(&self, node: &TreeNode<T, F>) -> String {
503        match node {
504            TreeNode::Terminal(t) => t.to_string(),
505            TreeNode::Function(f, children) => {
506                let child_strs: Vec<String> =
507                    children.iter().map(|c| self.node_to_sexpr(c)).collect();
508                format!("({} {})", f.to_string(), child_strs.join(" "))
509            }
510        }
511    }
512
513    /// Get a random node position
514    pub fn random_position<R: Rng>(&self, rng: &mut R) -> Vec<usize> {
515        let positions = self.root.positions();
516        positions[rng.gen_range(0..positions.len())].clone()
517    }
518
519    /// Get a random terminal position
520    pub fn random_terminal_position<R: Rng>(&self, rng: &mut R) -> Option<Vec<usize>> {
521        let positions = self.root.terminal_positions();
522        if positions.is_empty() {
523            None
524        } else {
525            Some(positions[rng.gen_range(0..positions.len())].clone())
526        }
527    }
528
529    /// Get a random function position
530    pub fn random_function_position<R: Rng>(&self, rng: &mut R) -> Option<Vec<usize>> {
531        let positions = self.root.function_positions();
532        if positions.is_empty() {
533            None
534        } else {
535            Some(positions[rng.gen_range(0..positions.len())].clone())
536        }
537    }
538}
539
540impl<T: Terminal, F: Function> EvolutionaryGenome for TreeGenome<T, F> {
541    type Allele = TreeNode<T, F>;
542    type Phenotype = Self;
543
544    fn to_trace(&self) -> Trace {
545        let mut trace = Trace::default();
546        let mut index = 0;
547        self.node_to_trace(&self.root, &mut trace, &mut index);
548        // Store max_depth and total size
549        trace.insert_choice(
550            addr!("tree_max_depth"),
551            ChoiceValue::Usize(self.max_depth),
552            0.0,
553        );
554        trace.insert_choice(addr!("tree_size"), ChoiceValue::Usize(index), 0.0);
555        trace
556    }
557
558    fn from_trace(trace: &Trace) -> Result<Self, GenomeError> {
559        let max_depth = trace
560            .get_usize(&addr!("tree_max_depth"))
561            .ok_or_else(|| GenomeError::MissingAddress("tree_max_depth".to_string()))?;
562
563        let mut index = 0;
564        let root = Self::node_from_trace(trace, &mut index)?;
565        Ok(Self { root, max_depth })
566    }
567
568    fn decode(&self) -> Self::Phenotype {
569        self.clone()
570    }
571
572    fn dimension(&self) -> usize {
573        self.size()
574    }
575
576    fn generate<R: Rng>(rng: &mut R, bounds: &MultiBounds) -> Self {
577        let max_depth = bounds.dimension().max(3).min(10);
578        Self::generate_ramped_half_and_half(rng, 2, max_depth)
579    }
580
581    fn distance(&self, other: &Self) -> f64 {
582        // Tree edit distance approximation based on size difference
583        let size_diff = (self.size() as f64 - other.size() as f64).abs();
584        let depth_diff = (self.depth() as f64 - other.depth() as f64).abs();
585        size_diff + depth_diff
586    }
587
588    fn trace_prefix() -> &'static str {
589        "tree"
590    }
591}
592
593impl<T: Terminal, F: Function> TreeGenome<T, F> {
594    fn node_to_trace(&self, node: &TreeNode<T, F>, trace: &mut Trace, index: &mut usize) {
595        let current_index = *index;
596        *index += 1;
597
598        match node {
599            TreeNode::Terminal(t) => {
600                // Store is_terminal flag (true = terminal)
601                trace.insert_choice(
602                    addr!("tree_is_terminal", current_index),
603                    ChoiceValue::Bool(true),
604                    0.0,
605                );
606                // For ArithmeticTerminal, store the variant type and value
607                // We encode using f64 for simplicity
608                let (term_type, term_val) = Self::encode_terminal(t);
609                trace.insert_choice(
610                    addr!("tree_term_type", current_index),
611                    ChoiceValue::F64(term_type),
612                    0.0,
613                );
614                trace.insert_choice(
615                    addr!("tree_term_val", current_index),
616                    ChoiceValue::F64(term_val),
617                    0.0,
618                );
619            }
620            TreeNode::Function(f, children) => {
621                // Store is_terminal flag (false = function)
622                trace.insert_choice(
623                    addr!("tree_is_terminal", current_index),
624                    ChoiceValue::Bool(false),
625                    0.0,
626                );
627                // Store function type as index and arity
628                let func_idx = Self::encode_function(f);
629                trace.insert_choice(
630                    addr!("tree_func_idx", current_index),
631                    ChoiceValue::Usize(func_idx),
632                    0.0,
633                );
634                trace.insert_choice(
635                    addr!("tree_arity", current_index),
636                    ChoiceValue::Usize(children.len()),
637                    0.0,
638                );
639                // Recurse into children
640                for child in children {
641                    self.node_to_trace(child, trace, index);
642                }
643            }
644        }
645    }
646
647    fn node_from_trace(trace: &Trace, index: &mut usize) -> Result<TreeNode<T, F>, GenomeError> {
648        let current_index = *index;
649        *index += 1;
650
651        let is_terminal = trace
652            .get_bool(&addr!("tree_is_terminal", current_index))
653            .ok_or_else(|| {
654                GenomeError::MissingAddress(format!("tree_is_terminal#{}", current_index))
655            })?;
656
657        if is_terminal {
658            let term_type = trace
659                .get_f64(&addr!("tree_term_type", current_index))
660                .ok_or_else(|| {
661                    GenomeError::MissingAddress(format!("tree_term_type#{}", current_index))
662                })?;
663            let term_val = trace
664                .get_f64(&addr!("tree_term_val", current_index))
665                .ok_or_else(|| {
666                    GenomeError::MissingAddress(format!("tree_term_val#{}", current_index))
667                })?;
668
669            let terminal = Self::decode_terminal(term_type, term_val)?;
670            Ok(TreeNode::Terminal(terminal))
671        } else {
672            let func_idx = trace
673                .get_usize(&addr!("tree_func_idx", current_index))
674                .ok_or_else(|| {
675                    GenomeError::MissingAddress(format!("tree_func_idx#{}", current_index))
676                })?;
677            let arity = trace
678                .get_usize(&addr!("tree_arity", current_index))
679                .ok_or_else(|| {
680                    GenomeError::MissingAddress(format!("tree_arity#{}", current_index))
681                })?;
682
683            let func = Self::decode_function(func_idx)?;
684            let mut children = Vec::with_capacity(arity);
685            for _ in 0..arity {
686                children.push(Self::node_from_trace(trace, index)?);
687            }
688            Ok(TreeNode::Function(func, children))
689        }
690    }
691
692    // Encode terminal as (type_code, value)
693    // type_code: 0 = Variable, 1 = Constant, 2 = ERC
694    fn encode_terminal(_terminal: &T) -> (f64, f64) {
695        // Default implementation for generic terminals
696        // Concrete implementations would need specialization
697        (0.0, 0.0)
698    }
699
700    fn decode_terminal(term_type: f64, term_val: f64) -> Result<T, GenomeError> {
701        // This requires runtime generation; use random for now
702        // A full implementation would need type-specific decoding
703        let mut rng = rand::thread_rng();
704        let _ = (term_type, term_val); // Acknowledge parameters
705        Ok(T::random(&mut rng))
706    }
707
708    fn encode_function(_func: &F) -> usize {
709        // Default returns 0; specialized for ArithmeticFunction
710        0
711    }
712
713    fn decode_function(func_idx: usize) -> Result<F, GenomeError> {
714        let funcs = F::functions();
715        if func_idx < funcs.len() {
716            Ok(funcs[func_idx].clone())
717        } else {
718            // Fall back to random
719            let mut rng = rand::thread_rng();
720            Ok(F::random(&mut rng))
721        }
722    }
723}
724
725impl<T: Terminal, F: Function> fmt::Display for TreeGenome<T, F> {
726    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
727        write!(f, "{}", self.to_sexpr())
728    }
729}
730
/// Trait for tree genome types (marker trait for operators).
///
/// It exposes the root node and depth limit so that tree operators can be written
/// against any tree genome, not just [`TreeGenome`].
pub trait TreeGenomeType: EvolutionaryGenome {
    /// The terminal type
    type Term: Terminal;
    /// The function type
    type Func: Function;

    /// Get the root of the tree
    fn root(&self) -> &TreeNode<Self::Term, Self::Func>;

    /// Get a mutable reference to the root
    fn root_mut(&mut self) -> &mut TreeNode<Self::Term, Self::Func>;

    /// Get the maximum depth limit recorded on the genome
    fn max_depth(&self) -> usize;

    /// Create a new tree from a root node and a maximum depth limit
    fn from_root(root: TreeNode<Self::Term, Self::Func>, max_depth: usize) -> Self;
}
750
751impl<T: Terminal, F: Function> TreeGenomeType for TreeGenome<T, F> {
752    type Term = T;
753    type Func = F;
754
755    fn root(&self) -> &TreeNode<T, F> {
756        &self.root
757    }
758
759    fn root_mut(&mut self) -> &mut TreeNode<T, F> {
760        &mut self.root
761    }
762
763    fn max_depth(&self) -> usize {
764        self.max_depth
765    }
766
767    fn from_root(root: TreeNode<T, F>, max_depth: usize) -> Self {
768        Self { root, max_depth }
769    }
770}
771
// Unit tests for tree nodes, tree genomes and the arithmetic terminal/function sets.
// Tests that use `rand::thread_rng()` only assert loose structural bounds, so they
// stay deterministic in outcome despite random input.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tree_node_terminal() {
        let node: TreeNode<ArithmeticTerminal, ArithmeticFunction> =
            TreeNode::terminal(ArithmeticTerminal::Variable(0));
        assert!(node.is_terminal());
        assert!(!node.is_function());
        assert_eq!(node.depth(), 1);
        assert_eq!(node.size(), 1);
    }

    #[test]
    fn test_tree_node_function() {
        let left = TreeNode::terminal(ArithmeticTerminal::Variable(0));
        let right = TreeNode::terminal(ArithmeticTerminal::Constant(1.0));
        let node = TreeNode::function(ArithmeticFunction::Add, vec![left, right]);

        assert!(!node.is_terminal());
        assert!(node.is_function());
        assert_eq!(node.depth(), 2);
        assert_eq!(node.size(), 3);
    }

    #[test]
    fn test_tree_node_positions() {
        // Create: (+ x0 (* 1.0 x1))
        let x0 = TreeNode::terminal(ArithmeticTerminal::Variable(0));
        let c1 = TreeNode::terminal(ArithmeticTerminal::Constant(1.0));
        let x1 = TreeNode::terminal(ArithmeticTerminal::Variable(1));
        let mul = TreeNode::function(ArithmeticFunction::Mul, vec![c1, x1]);
        let add = TreeNode::function(ArithmeticFunction::Add, vec![x0, mul]);

        let positions = add.positions();
        assert_eq!(positions.len(), 5); // root, left, right, right-left, right-right
        assert!(positions.contains(&vec![])); // root
        assert!(positions.contains(&vec![0])); // left child (x0)
        assert!(positions.contains(&vec![1])); // right child (mul)
        assert!(positions.contains(&vec![1, 0])); // mul's left child
        assert!(positions.contains(&vec![1, 1])); // mul's right child
    }

    #[test]
    fn test_tree_node_get_subtree() {
        let x0 = TreeNode::terminal(ArithmeticTerminal::Variable(0));
        let c1 = TreeNode::terminal(ArithmeticTerminal::Constant(1.0));
        let add: TreeNode<ArithmeticTerminal, ArithmeticFunction> =
            TreeNode::function(ArithmeticFunction::Add, vec![x0.clone(), c1]);

        assert_eq!(add.get_subtree(&[0]), Some(&x0));
        // Index 2 is out of range for a two-child node
        assert!(add.get_subtree(&[2]).is_none());
    }

    #[test]
    fn test_tree_genome_evaluate() {
        // Create: (+ x0 x1)
        let x0 = TreeNode::terminal(ArithmeticTerminal::Variable(0));
        let x1 = TreeNode::terminal(ArithmeticTerminal::Variable(1));
        let add = TreeNode::function(ArithmeticFunction::Add, vec![x0, x1]);
        let tree = TreeGenome::new(add, 5);

        assert_eq!(tree.evaluate(&[3.0, 4.0]), 7.0);
    }

    #[test]
    fn test_tree_genome_evaluate_complex() {
        // Create: (* (+ x0 1) x1) = (x0 + 1) * x1
        let x0 = TreeNode::terminal(ArithmeticTerminal::Variable(0));
        let c1 = TreeNode::terminal(ArithmeticTerminal::Constant(1.0));
        let x1 = TreeNode::terminal(ArithmeticTerminal::Variable(1));
        let add = TreeNode::function(ArithmeticFunction::Add, vec![x0, c1]);
        let mul = TreeNode::function(ArithmeticFunction::Mul, vec![add, x1]);
        let tree = TreeGenome::new(mul, 5);

        assert_eq!(tree.evaluate(&[2.0, 3.0]), 9.0); // (2 + 1) * 3 = 9
    }

    #[test]
    fn test_tree_genome_generate_full() {
        let mut rng = rand::thread_rng();
        let tree: TreeGenome<ArithmeticTerminal, ArithmeticFunction> =
            TreeGenome::generate_full(&mut rng, 3, 5);

        // Full tree with target depth 3 creates: Function -> Function -> Function -> Terminal
        // Which has depth 4 (counting levels from root to leaf)
        assert!(tree.depth() >= 3);
        assert!(tree.size() >= 1);
    }

    #[test]
    fn test_tree_genome_generate_grow() {
        let mut rng = rand::thread_rng();
        let tree: TreeGenome<ArithmeticTerminal, ArithmeticFunction> =
            TreeGenome::generate_grow(&mut rng, 5, 0.3);

        // Grow can create trees up to max_depth + 1 levels (due to counting from 0)
        assert!(tree.depth() <= 6);
        assert!(tree.size() >= 1);
    }

    #[test]
    fn test_tree_genome_to_sexpr() {
        let x0 = TreeNode::terminal(ArithmeticTerminal::Variable(0));
        let c1 = TreeNode::terminal(ArithmeticTerminal::Constant(1.0));
        let add: TreeNode<ArithmeticTerminal, ArithmeticFunction> =
            TreeNode::function(ArithmeticFunction::Add, vec![x0, c1]);
        let tree = TreeGenome::new(add, 5);

        let sexpr = tree.to_sexpr();
        assert!(sexpr.contains('+'));
        assert!(sexpr.contains("x0"));
        // Constants render with four decimals ("1.0000"), which contains "1.0"
        assert!(sexpr.contains("1.0"));
    }

    #[test]
    fn test_tree_genome_trace_roundtrip() {
        let x0 = TreeNode::terminal(ArithmeticTerminal::Variable(0));
        let c1 = TreeNode::terminal(ArithmeticTerminal::Constant(2.5));
        let add = TreeNode::function(ArithmeticFunction::Add, vec![x0, c1]);
        let original: TreeGenome<ArithmeticTerminal, ArithmeticFunction> = TreeGenome::new(add, 5);

        let trace = original.to_trace();
        let recovered: TreeGenome<ArithmeticTerminal, ArithmeticFunction> =
            TreeGenome::from_trace(&trace).unwrap();

        // The trace encoding preserves structure (max_depth, size), but terminal values
        // are regenerated at random on decode because the simple encoding does not store them
        assert_eq!(original.max_depth, recovered.max_depth);
        assert_eq!(original.size(), recovered.size());
        // Both should be valid trees of the same structure
        assert!(recovered.evaluate(&[3.0]).is_finite());
    }

    #[test]
    fn test_tree_node_replace_subtree() {
        let x0 = TreeNode::terminal(ArithmeticTerminal::Variable(0));
        let x1 = TreeNode::terminal(ArithmeticTerminal::Variable(1));
        let mut add: TreeNode<ArithmeticTerminal, ArithmeticFunction> =
            TreeNode::function(ArithmeticFunction::Add, vec![x0, x1]);

        let c5 = TreeNode::terminal(ArithmeticTerminal::Constant(5.0));
        add.replace_subtree(&[0], c5);

        // Now tree should be (+ 5.0 x1)
        let tree = TreeGenome::new(add, 5);
        assert_eq!(tree.evaluate(&[0.0, 3.0]), 8.0); // 5 + 3 = 8
    }

    #[test]
    fn test_arithmetic_function_protected_div() {
        assert_eq!(ArithmeticFunction::Div.apply(&[1.0, 0.0]), 1.0);
        assert_eq!(ArithmeticFunction::Div.apply(&[6.0, 2.0]), 3.0);
    }

    #[test]
    fn test_arithmetic_function_protected_log() {
        assert_eq!(ArithmeticFunction::Log.apply(&[-1.0]), 0.0);
        assert!((ArithmeticFunction::Log.apply(&[std::f64::consts::E]) - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_arithmetic_function_protected_sqrt() {
        assert_eq!(ArithmeticFunction::Sqrt.apply(&[4.0]), 2.0);
        assert_eq!(ArithmeticFunction::Sqrt.apply(&[-4.0]), 2.0); // Protected
    }

    #[test]
    fn test_tree_genome_evolutionary_genome_trait() {
        let mut rng = rand::thread_rng();
        let bounds = MultiBounds::symmetric(5.0, 5);
        let tree: TreeGenome<ArithmeticTerminal, ArithmeticFunction> =
            TreeGenome::generate(&mut rng, &bounds);

        assert!(tree.dimension() >= 1);
        let decoded = tree.decode();
        assert_eq!(decoded.size(), tree.size());
    }

    #[test]
    fn test_tree_terminal_and_function_positions() {
        // Create: (+ x0 (* 1.0 x1))
        let x0 = TreeNode::terminal(ArithmeticTerminal::Variable(0));
        let c1 = TreeNode::terminal(ArithmeticTerminal::Constant(1.0));
        let x1 = TreeNode::terminal(ArithmeticTerminal::Variable(1));
        let mul = TreeNode::function(ArithmeticFunction::Mul, vec![c1, x1]);
        let add: TreeNode<ArithmeticTerminal, ArithmeticFunction> =
            TreeNode::function(ArithmeticFunction::Add, vec![x0, mul]);

        let terminal_positions = add.terminal_positions();
        assert_eq!(terminal_positions.len(), 3); // x0, 1.0, x1

        let function_positions = add.function_positions();
        assert_eq!(function_positions.len(), 2); // add, mul
    }

    #[test]
    fn test_tree_genome_display() {
        let x0 = TreeNode::terminal(ArithmeticTerminal::Variable(0));
        let c1 = TreeNode::terminal(ArithmeticTerminal::Constant(1.0));
        let add: TreeNode<ArithmeticTerminal, ArithmeticFunction> =
            TreeNode::function(ArithmeticFunction::Add, vec![x0, c1]);
        let tree = TreeGenome::new(add, 5);

        let display = format!("{}", tree);
        assert!(!display.is_empty());
    }
}