Skip to main content

fugue_evo/fugue_integration/
trace_operators.rs

1//! Trace-based genetic operators
2//!
3//! These operators work directly on Fugue traces, enabling probabilistic
4//! interpretations of mutation and crossover.
5
6use std::collections::HashSet;
7
8use fugue::{Address, ChoiceValue, Trace};
9use rand::Rng;
10
11use crate::error::GenomeError;
12use crate::genome::traits::EvolutionaryGenome;
13
14/// Trait for selecting which addresses to mutate
15pub trait MutationSelector: Send + Sync {
16    /// Select addresses that should be mutated
17    fn select_sites<R: Rng>(&self, trace: &Trace, rng: &mut R) -> HashSet<Address>;
18}
19
/// Selector that flips an independent biased coin for every address.
///
/// An address is chosen for mutation with probability `mutation_probability`,
/// independently of every other address.
#[derive(Clone, Debug)]
pub struct UniformMutationSelector {
    /// Per-address chance of being selected for mutation
    pub mutation_probability: f64,
}

impl UniformMutationSelector {
    /// Builds a selector from `probability`, clamping it into `[0.0, 1.0]`.
    pub fn new(probability: f64) -> Self {
        let mutation_probability = probability.clamp(0.0, 1.0);
        Self {
            mutation_probability,
        }
    }

    /// Builds a selector with the conventional `1/n` mutation probability.
    ///
    /// For `n == 0` the division yields infinity, which [`Self::new`] clamps
    /// to `1.0`.
    pub fn one_over_n(n: usize) -> Self {
        let genome_len = n as f64;
        Self::new(genome_len.recip())
    }
}
42
43impl MutationSelector for UniformMutationSelector {
44    fn select_sites<R: Rng>(&self, trace: &Trace, rng: &mut R) -> HashSet<Address> {
45        trace
46            .choices
47            .keys()
48            .filter(|_| rng.gen::<f64>() < self.mutation_probability)
49            .cloned()
50            .collect()
51    }
52}
53
/// Selector that picks a single address.
///
/// Exactly one address is chosen at random for mutation (none when the trace
/// has no choices).
#[derive(Clone, Debug, Default)]
pub struct SingleSiteMutationSelector;

impl SingleSiteMutationSelector {
    /// Builds the selector; it carries no configuration.
    pub fn new() -> Self {
        Self::default()
    }
}
66
67impl MutationSelector for SingleSiteMutationSelector {
68    fn select_sites<R: Rng>(&self, trace: &Trace, rng: &mut R) -> HashSet<Address> {
69        let addresses: Vec<_> = trace.choices.keys().collect();
70        if addresses.is_empty() {
71            return HashSet::new();
72        }
73
74        let idx = rng.gen_range(0..addresses.len());
75        let mut sites = HashSet::new();
76        sites.insert(addresses[idx].clone());
77        sites
78    }
79}
80
/// Selector that picks a fixed number of distinct addresses.
///
/// Exactly `num_sites` random addresses are chosen for mutation, capped at
/// the number of addresses the trace actually contains.
#[derive(Clone, Debug)]
pub struct MultiSiteMutationSelector {
    /// How many addresses to mutate per call
    pub num_sites: usize,
}

impl MultiSiteMutationSelector {
    /// Builds a selector that targets `num_sites` addresses.
    pub fn new(num_sites: usize) -> Self {
        MultiSiteMutationSelector { num_sites }
    }
}
96
97impl MutationSelector for MultiSiteMutationSelector {
98    fn select_sites<R: Rng>(&self, trace: &Trace, rng: &mut R) -> HashSet<Address> {
99        let addresses: Vec<_> = trace.choices.keys().collect();
100        if addresses.is_empty() {
101            return HashSet::new();
102        }
103
104        let k = self.num_sites.min(addresses.len());
105        let mut selected = HashSet::new();
106        let mut indices: Vec<usize> = (0..addresses.len()).collect();
107
108        // Fisher-Yates partial shuffle
109        for i in 0..k {
110            let j = rng.gen_range(i..addresses.len());
111            indices.swap(i, j);
112            selected.insert(addresses[indices[i]].clone());
113        }
114
115        selected
116    }
117}
118
/// Trait for determining which parent contributes to each address during crossover.
///
/// A mask is queried once per address; the answer decides which parent's
/// value the first child receives (the second child receives the other one).
pub trait CrossoverMask: Send + Sync {
    /// Returns `true` if parent1's value should be used at the given address,
    /// `false` if parent2's value should be used instead.
    fn from_parent1(&self, addr: &Address) -> bool;
}
124
125/// Uniform crossover mask
126///
127/// Each address independently chosen from either parent.
128#[derive(Clone, Debug)]
129pub struct UniformCrossoverMask {
130    /// Probability of choosing parent1's value
131    pub bias: f64,
132    /// Set of addresses that should come from parent1
133    selected: HashSet<Address>,
134}
135
136impl UniformCrossoverMask {
137    /// Create a new uniform crossover mask
138    pub fn new<R: Rng>(bias: f64, trace: &Trace, rng: &mut R) -> Self {
139        let selected = trace
140            .choices
141            .keys()
142            .filter(|_| rng.gen::<f64>() < bias)
143            .cloned()
144            .collect();
145
146        Self { bias, selected }
147    }
148
149    /// Create with 50/50 probability
150    pub fn balanced<R: Rng>(trace: &Trace, rng: &mut R) -> Self {
151        Self::new(0.5, trace, rng)
152    }
153}
154
155impl CrossoverMask for UniformCrossoverMask {
156    fn from_parent1(&self, addr: &Address) -> bool {
157        self.selected.contains(addr)
158    }
159}
160
161/// Single-point crossover mask
162///
163/// All addresses before the crossover point come from parent1,
164/// all after come from parent2.
165#[derive(Clone, Debug)]
166pub struct SinglePointCrossoverMask {
167    /// Addresses from parent1 (before crossover point)
168    parent1_addresses: HashSet<Address>,
169}
170
171impl SinglePointCrossoverMask {
172    /// Create a new single-point crossover mask
173    pub fn new<R: Rng>(trace: &Trace, rng: &mut R) -> Self {
174        let addresses: Vec<_> = trace.choices.keys().cloned().collect();
175        if addresses.is_empty() {
176            return Self {
177                parent1_addresses: HashSet::new(),
178            };
179        }
180
181        let crossover_point = rng.gen_range(0..=addresses.len());
182        let parent1_addresses: HashSet<_> = addresses.into_iter().take(crossover_point).collect();
183
184        Self { parent1_addresses }
185    }
186}
187
188impl CrossoverMask for SinglePointCrossoverMask {
189    fn from_parent1(&self, addr: &Address) -> bool {
190        self.parent1_addresses.contains(addr)
191    }
192}
193
194/// Two-point crossover mask
195///
196/// Addresses between the two points come from parent2,
197/// outside comes from parent1.
198#[derive(Clone, Debug)]
199pub struct TwoPointCrossoverMask {
200    /// Addresses from parent1 (outside crossover segment)
201    parent1_addresses: HashSet<Address>,
202}
203
204impl TwoPointCrossoverMask {
205    /// Create a new two-point crossover mask
206    pub fn new<R: Rng>(trace: &Trace, rng: &mut R) -> Self {
207        let addresses: Vec<_> = trace.choices.keys().cloned().collect();
208        if addresses.is_empty() {
209            return Self {
210                parent1_addresses: HashSet::new(),
211            };
212        }
213
214        let mut point1 = rng.gen_range(0..addresses.len());
215        let mut point2 = rng.gen_range(0..addresses.len());
216        if point1 > point2 {
217            std::mem::swap(&mut point1, &mut point2);
218        }
219
220        // Parent1 gets addresses outside [point1, point2)
221        let parent1_addresses: HashSet<_> = addresses
222            .iter()
223            .enumerate()
224            .filter(|(i, _)| *i < point1 || *i >= point2)
225            .map(|(_, addr)| addr.clone())
226            .collect();
227
228        Self { parent1_addresses }
229    }
230}
231
232impl CrossoverMask for TwoPointCrossoverMask {
233    fn from_parent1(&self, addr: &Address) -> bool {
234        self.parent1_addresses.contains(addr)
235    }
236}
237
238/// Trace-based mutation operator
239///
240/// Mutates a genome by selectively resampling addresses in its trace representation.
241pub fn mutate_trace<G, S, R>(
242    genome: &G,
243    selector: &S,
244    mutation_fn: impl Fn(&Address, &ChoiceValue, &mut R) -> ChoiceValue,
245    rng: &mut R,
246) -> Result<G, GenomeError>
247where
248    G: EvolutionaryGenome,
249    S: MutationSelector,
250    R: Rng,
251{
252    let trace = genome.to_trace();
253    let mutation_sites = selector.select_sites(&trace, rng);
254
255    let mut new_trace = Trace::default();
256
257    for (addr, choice) in &trace.choices {
258        let new_value = if mutation_sites.contains(addr) {
259            mutation_fn(addr, &choice.value, rng)
260        } else {
261            choice.value.clone()
262        };
263        new_trace.insert_choice(addr.clone(), new_value, choice.logp);
264    }
265
266    G::from_trace(&new_trace)
267}
268
269/// Trace-based crossover operator
270///
271/// Creates offspring by merging parent traces according to a crossover mask.
272pub fn crossover_traces<G, M, R>(
273    parent1: &G,
274    parent2: &G,
275    mask: &M,
276    _rng: &mut R,
277) -> Result<(G, G), GenomeError>
278where
279    G: EvolutionaryGenome,
280    M: CrossoverMask,
281    R: Rng,
282{
283    let trace1 = parent1.to_trace();
284    let trace2 = parent2.to_trace();
285
286    let mut child1_trace = Trace::default();
287    let mut child2_trace = Trace::default();
288
289    // Collect all addresses from both parents
290    let all_addresses: HashSet<Address> = trace1
291        .choices
292        .keys()
293        .chain(trace2.choices.keys())
294        .cloned()
295        .collect();
296
297    for addr in all_addresses {
298        let (val_for_child1, val_for_child2) = if mask.from_parent1(&addr) {
299            // Child1 gets parent1, child2 gets parent2
300            (
301                trace1
302                    .choices
303                    .get(&addr)
304                    .map(|c| c.value.clone())
305                    .unwrap_or(ChoiceValue::F64(0.0)),
306                trace2
307                    .choices
308                    .get(&addr)
309                    .map(|c| c.value.clone())
310                    .unwrap_or(ChoiceValue::F64(0.0)),
311            )
312        } else {
313            // Child1 gets parent2, child2 gets parent1
314            (
315                trace2
316                    .choices
317                    .get(&addr)
318                    .map(|c| c.value.clone())
319                    .unwrap_or(ChoiceValue::F64(0.0)),
320                trace1
321                    .choices
322                    .get(&addr)
323                    .map(|c| c.value.clone())
324                    .unwrap_or(ChoiceValue::F64(0.0)),
325            )
326        };
327
328        child1_trace.insert_choice(addr.clone(), val_for_child1, 0.0);
329        child2_trace.insert_choice(addr, val_for_child2, 0.0);
330    }
331
332    let child1 = G::from_trace(&child1_trace)?;
333    let child2 = G::from_trace(&child2_trace)?;
334
335    Ok((child1, child2))
336}
337
338/// Gaussian mutation function for f64 values
339pub fn gaussian_mutation<R: Rng>(
340    sigma: f64,
341) -> impl Fn(&Address, &ChoiceValue, &mut R) -> ChoiceValue {
342    move |_addr, value, rng| {
343        if let ChoiceValue::F64(v) = value {
344            let noise: f64 = rng.gen::<f64>() * 2.0 - 1.0; // Simple uniform noise
345            let mutated = v + sigma * noise * 2.0_f64.sqrt(); // Scale to approximate gaussian
346            ChoiceValue::F64(mutated)
347        } else {
348            value.clone()
349        }
350    }
351}
352
353/// Bit flip mutation function for boolean values
354pub fn bit_flip_mutation<R: Rng>() -> impl Fn(&Address, &ChoiceValue, &mut R) -> ChoiceValue {
355    move |_addr, value, _rng| {
356        if let ChoiceValue::Bool(b) = value {
357            ChoiceValue::Bool(!b)
358        } else {
359            value.clone()
360        }
361    }
362}
363
364/// Bounded mutation function that respects bounds
365pub fn bounded_mutation<R: Rng>(
366    sigma: f64,
367    lower: f64,
368    upper: f64,
369) -> impl Fn(&Address, &ChoiceValue, &mut R) -> ChoiceValue {
370    move |_addr, value, rng| {
371        if let ChoiceValue::F64(v) = value {
372            let noise: f64 = rng.gen::<f64>() * 2.0 - 1.0;
373            let mutated = (v + sigma * noise * 2.0_f64.sqrt()).clamp(lower, upper);
374            ChoiceValue::F64(mutated)
375        } else {
376            value.clone()
377        }
378    }
379}
380
/// Unit tests for the selectors, masks and trace operators, exercised through
/// `RealVector` genomes.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::genome::real_vector::RealVector;
    use crate::genome::traits::RealValuedGenome;

    #[test]
    fn test_uniform_mutation_selector() {
        let mut rng = rand::thread_rng();
        let genome = RealVector::new(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let trace = genome.to_trace();
        let total = trace.choices.keys().count();

        // Probability 1.0: every draw lies in [0, 1), so every site is selected
        let all_sites = UniformMutationSelector::new(1.0).select_sites(&trace, &mut rng);
        assert_eq!(all_sites.len(), total);

        // Probability 0.0: no draw can fall below it, so nothing is selected
        let no_sites = UniformMutationSelector::new(0.0).select_sites(&trace, &mut rng);
        assert!(no_sites.is_empty());

        // Intermediate probabilities are random; only structural properties hold
        let selector = UniformMutationSelector::new(0.9);
        let sites = selector.select_sites(&trace, &mut rng);
        assert!(sites.len() <= total);
        assert!(sites.iter().all(|addr| trace.choices.get(addr).is_some()));

        let selector_low = UniformMutationSelector::new(0.1);
        let sites_low = selector_low.select_sites(&trace, &mut rng);
        assert!(sites_low.len() <= total);
    }

    #[test]
    fn test_single_site_mutation_selector() {
        let mut rng = rand::thread_rng();
        let genome = RealVector::new(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let trace = genome.to_trace();

        let selector = SingleSiteMutationSelector::new();
        let sites = selector.select_sites(&trace, &mut rng);

        assert_eq!(sites.len(), 1);
        // The chosen site must be an address of the trace
        assert!(sites.iter().all(|addr| trace.choices.get(addr).is_some()));
    }

    #[test]
    fn test_multi_site_mutation_selector() {
        let mut rng = rand::thread_rng();
        let genome = RealVector::new(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let trace = genome.to_trace();
        let total = trace.choices.keys().count();

        let selector = MultiSiteMutationSelector::new(3);
        let sites = selector.select_sites(&trace, &mut rng);

        assert_eq!(sites.len(), 3);
        assert!(sites.iter().all(|addr| trace.choices.get(addr).is_some()));

        // Asking for more sites than exist is capped at the trace size
        let greedy = MultiSiteMutationSelector::new(100);
        let capped = greedy.select_sites(&trace, &mut rng);
        assert_eq!(capped.len(), total);
    }

    #[test]
    fn test_mutate_trace() {
        let mut rng = rand::thread_rng();
        let genome = RealVector::new(vec![1.0, 2.0, 3.0]);

        let selector = UniformMutationSelector::new(1.0); // Mutate all
        let mutation_fn = gaussian_mutation(0.1);

        let mutated = mutate_trace(&genome, &selector, mutation_fn, &mut rng).unwrap();

        // Should have same dimension
        assert_eq!(mutated.dimension(), genome.dimension());

        // With nothing selected the genome must come back unchanged
        let frozen = UniformMutationSelector::new(0.0);
        let untouched =
            mutate_trace(&genome, &frozen, gaussian_mutation(0.1), &mut rng).unwrap();
        assert_eq!(untouched.dimension(), genome.dimension());
        for i in 0..3 {
            assert!(
                (untouched.genes()[i] - genome.genes()[i]).abs() < 1e-10,
                "Gene {} changed although no site was selected",
                i
            );
        }
    }

    #[test]
    fn test_crossover_traces() {
        let mut rng = rand::thread_rng();
        let parent1 = RealVector::new(vec![1.0, 2.0, 3.0]);
        let parent2 = RealVector::new(vec![4.0, 5.0, 6.0]);

        let trace1 = parent1.to_trace();
        let mask = UniformCrossoverMask::balanced(&trace1, &mut rng);

        let (child1, child2) = crossover_traces(&parent1, &parent2, &mask, &mut rng).unwrap();

        assert_eq!(child1.dimension(), 3);
        assert_eq!(child2.dimension(), 3);

        // Children should have values from either parent
        for i in 0..3 {
            let c1_val = child1.genes()[i];
            let c2_val = child2.genes()[i];
            let p1_val = parent1.genes()[i];
            let p2_val = parent2.genes()[i];

            assert!(
                (c1_val - p1_val).abs() < 1e-10 || (c1_val - p2_val).abs() < 1e-10,
                "Child1 value {} not from either parent ({} or {})",
                c1_val,
                p1_val,
                p2_val
            );
            assert!(
                (c2_val - p1_val).abs() < 1e-10 || (c2_val - p2_val).abs() < 1e-10,
                "Child2 value {} not from either parent ({} or {})",
                c2_val,
                p1_val,
                p2_val
            );
        }
    }

    #[test]
    fn test_single_point_crossover_mask() {
        let mut rng = rand::thread_rng();
        let genome = RealVector::new(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let trace = genome.to_trace();

        for _ in 0..10 {
            let mask = SinglePointCrossoverMask::new(&trace, &mut rng);

            // Check that it's a valid partition: a parent1 prefix followed by
            // a parent2 suffix, with at most one switch between them
            let addresses: Vec<_> = trace.choices.keys().collect();
            let mut found_split = false;

            for i in 1..addresses.len() {
                let prev_from_p1 = mask.from_parent1(addresses[i - 1]);
                let curr_from_p1 = mask.from_parent1(addresses[i]);

                if prev_from_p1 && !curr_from_p1 {
                    assert!(!found_split, "Multiple splits found");
                    found_split = true;
                }

                // Once parent2 has taken over, parent1 must not reappear
                assert!(
                    prev_from_p1 || !curr_from_p1,
                    "Parent1 address found after the crossover point"
                );
            }
        }
    }

    #[test]
    fn test_two_point_crossover_mask() {
        let mut rng = rand::thread_rng();
        let genome = RealVector::new(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let trace = genome.to_trace();
        let addresses: Vec<_> = trace.choices.keys().collect();

        for _ in 0..10 {
            let mask = TwoPointCrossoverMask::new(&trace, &mut rng);

            // The parent2 segment is contiguous, so walking the addresses in
            // order can switch parents at most twice
            let switches = addresses
                .windows(2)
                .filter(|pair| mask.from_parent1(pair[0]) != mask.from_parent1(pair[1]))
                .count();

            assert!(switches <= 2, "Parent2 segment is not contiguous");
        }
    }
}
505}