fugue_evo/fugue_integration/
trace_operators.rs1use std::collections::HashSet;
7
8use fugue::{Address, ChoiceValue, Trace};
9use rand::Rng;
10
11use crate::error::GenomeError;
12use crate::genome::traits::EvolutionaryGenome;
13
14pub trait MutationSelector: Send + Sync {
16 fn select_sites<R: Rng>(&self, trace: &Trace, rng: &mut R) -> HashSet<Address>;
18}
19
20#[derive(Clone, Debug)]
24pub struct UniformMutationSelector {
25 pub mutation_probability: f64,
27}
28
29impl UniformMutationSelector {
30 pub fn new(probability: f64) -> Self {
32 Self {
33 mutation_probability: probability.clamp(0.0, 1.0),
34 }
35 }
36
37 pub fn one_over_n(n: usize) -> Self {
39 Self::new(1.0 / n as f64)
40 }
41}
42
43impl MutationSelector for UniformMutationSelector {
44 fn select_sites<R: Rng>(&self, trace: &Trace, rng: &mut R) -> HashSet<Address> {
45 trace
46 .choices
47 .keys()
48 .filter(|_| rng.gen::<f64>() < self.mutation_probability)
49 .cloned()
50 .collect()
51 }
52}
53
54#[derive(Clone, Debug, Default)]
58pub struct SingleSiteMutationSelector;
59
60impl SingleSiteMutationSelector {
61 pub fn new() -> Self {
63 Self
64 }
65}
66
67impl MutationSelector for SingleSiteMutationSelector {
68 fn select_sites<R: Rng>(&self, trace: &Trace, rng: &mut R) -> HashSet<Address> {
69 let addresses: Vec<_> = trace.choices.keys().collect();
70 if addresses.is_empty() {
71 return HashSet::new();
72 }
73
74 let idx = rng.gen_range(0..addresses.len());
75 let mut sites = HashSet::new();
76 sites.insert(addresses[idx].clone());
77 sites
78 }
79}
80
81#[derive(Clone, Debug)]
85pub struct MultiSiteMutationSelector {
86 pub num_sites: usize,
88}
89
90impl MultiSiteMutationSelector {
91 pub fn new(num_sites: usize) -> Self {
93 Self { num_sites }
94 }
95}
96
97impl MutationSelector for MultiSiteMutationSelector {
98 fn select_sites<R: Rng>(&self, trace: &Trace, rng: &mut R) -> HashSet<Address> {
99 let addresses: Vec<_> = trace.choices.keys().collect();
100 if addresses.is_empty() {
101 return HashSet::new();
102 }
103
104 let k = self.num_sites.min(addresses.len());
105 let mut selected = HashSet::new();
106 let mut indices: Vec<usize> = (0..addresses.len()).collect();
107
108 for i in 0..k {
110 let j = rng.gen_range(i..addresses.len());
111 indices.swap(i, j);
112 selected.insert(addresses[indices[i]].clone());
113 }
114
115 selected
116 }
117}
118
119pub trait CrossoverMask: Send + Sync {
121 fn from_parent1(&self, addr: &Address) -> bool;
123}
124
125#[derive(Clone, Debug)]
129pub struct UniformCrossoverMask {
130 pub bias: f64,
132 selected: HashSet<Address>,
134}
135
136impl UniformCrossoverMask {
137 pub fn new<R: Rng>(bias: f64, trace: &Trace, rng: &mut R) -> Self {
139 let selected = trace
140 .choices
141 .keys()
142 .filter(|_| rng.gen::<f64>() < bias)
143 .cloned()
144 .collect();
145
146 Self { bias, selected }
147 }
148
149 pub fn balanced<R: Rng>(trace: &Trace, rng: &mut R) -> Self {
151 Self::new(0.5, trace, rng)
152 }
153}
154
155impl CrossoverMask for UniformCrossoverMask {
156 fn from_parent1(&self, addr: &Address) -> bool {
157 self.selected.contains(addr)
158 }
159}
160
161#[derive(Clone, Debug)]
166pub struct SinglePointCrossoverMask {
167 parent1_addresses: HashSet<Address>,
169}
170
171impl SinglePointCrossoverMask {
172 pub fn new<R: Rng>(trace: &Trace, rng: &mut R) -> Self {
174 let addresses: Vec<_> = trace.choices.keys().cloned().collect();
175 if addresses.is_empty() {
176 return Self {
177 parent1_addresses: HashSet::new(),
178 };
179 }
180
181 let crossover_point = rng.gen_range(0..=addresses.len());
182 let parent1_addresses: HashSet<_> = addresses.into_iter().take(crossover_point).collect();
183
184 Self { parent1_addresses }
185 }
186}
187
188impl CrossoverMask for SinglePointCrossoverMask {
189 fn from_parent1(&self, addr: &Address) -> bool {
190 self.parent1_addresses.contains(addr)
191 }
192}
193
194#[derive(Clone, Debug)]
199pub struct TwoPointCrossoverMask {
200 parent1_addresses: HashSet<Address>,
202}
203
204impl TwoPointCrossoverMask {
205 pub fn new<R: Rng>(trace: &Trace, rng: &mut R) -> Self {
207 let addresses: Vec<_> = trace.choices.keys().cloned().collect();
208 if addresses.is_empty() {
209 return Self {
210 parent1_addresses: HashSet::new(),
211 };
212 }
213
214 let mut point1 = rng.gen_range(0..addresses.len());
215 let mut point2 = rng.gen_range(0..addresses.len());
216 if point1 > point2 {
217 std::mem::swap(&mut point1, &mut point2);
218 }
219
220 let parent1_addresses: HashSet<_> = addresses
222 .iter()
223 .enumerate()
224 .filter(|(i, _)| *i < point1 || *i >= point2)
225 .map(|(_, addr)| addr.clone())
226 .collect();
227
228 Self { parent1_addresses }
229 }
230}
231
232impl CrossoverMask for TwoPointCrossoverMask {
233 fn from_parent1(&self, addr: &Address) -> bool {
234 self.parent1_addresses.contains(addr)
235 }
236}
237
238pub fn mutate_trace<G, S, R>(
242 genome: &G,
243 selector: &S,
244 mutation_fn: impl Fn(&Address, &ChoiceValue, &mut R) -> ChoiceValue,
245 rng: &mut R,
246) -> Result<G, GenomeError>
247where
248 G: EvolutionaryGenome,
249 S: MutationSelector,
250 R: Rng,
251{
252 let trace = genome.to_trace();
253 let mutation_sites = selector.select_sites(&trace, rng);
254
255 let mut new_trace = Trace::default();
256
257 for (addr, choice) in &trace.choices {
258 let new_value = if mutation_sites.contains(addr) {
259 mutation_fn(addr, &choice.value, rng)
260 } else {
261 choice.value.clone()
262 };
263 new_trace.insert_choice(addr.clone(), new_value, choice.logp);
264 }
265
266 G::from_trace(&new_trace)
267}
268
269pub fn crossover_traces<G, M, R>(
273 parent1: &G,
274 parent2: &G,
275 mask: &M,
276 _rng: &mut R,
277) -> Result<(G, G), GenomeError>
278where
279 G: EvolutionaryGenome,
280 M: CrossoverMask,
281 R: Rng,
282{
283 let trace1 = parent1.to_trace();
284 let trace2 = parent2.to_trace();
285
286 let mut child1_trace = Trace::default();
287 let mut child2_trace = Trace::default();
288
289 let all_addresses: HashSet<Address> = trace1
291 .choices
292 .keys()
293 .chain(trace2.choices.keys())
294 .cloned()
295 .collect();
296
297 for addr in all_addresses {
298 let (val_for_child1, val_for_child2) = if mask.from_parent1(&addr) {
299 (
301 trace1
302 .choices
303 .get(&addr)
304 .map(|c| c.value.clone())
305 .unwrap_or(ChoiceValue::F64(0.0)),
306 trace2
307 .choices
308 .get(&addr)
309 .map(|c| c.value.clone())
310 .unwrap_or(ChoiceValue::F64(0.0)),
311 )
312 } else {
313 (
315 trace2
316 .choices
317 .get(&addr)
318 .map(|c| c.value.clone())
319 .unwrap_or(ChoiceValue::F64(0.0)),
320 trace1
321 .choices
322 .get(&addr)
323 .map(|c| c.value.clone())
324 .unwrap_or(ChoiceValue::F64(0.0)),
325 )
326 };
327
328 child1_trace.insert_choice(addr.clone(), val_for_child1, 0.0);
329 child2_trace.insert_choice(addr, val_for_child2, 0.0);
330 }
331
332 let child1 = G::from_trace(&child1_trace)?;
333 let child2 = G::from_trace(&child2_trace)?;
334
335 Ok((child1, child2))
336}
337
338pub fn gaussian_mutation<R: Rng>(
340 sigma: f64,
341) -> impl Fn(&Address, &ChoiceValue, &mut R) -> ChoiceValue {
342 move |_addr, value, rng| {
343 if let ChoiceValue::F64(v) = value {
344 let noise: f64 = rng.gen::<f64>() * 2.0 - 1.0; let mutated = v + sigma * noise * 2.0_f64.sqrt(); ChoiceValue::F64(mutated)
347 } else {
348 value.clone()
349 }
350 }
351}
352
353pub fn bit_flip_mutation<R: Rng>() -> impl Fn(&Address, &ChoiceValue, &mut R) -> ChoiceValue {
355 move |_addr, value, _rng| {
356 if let ChoiceValue::Bool(b) = value {
357 ChoiceValue::Bool(!b)
358 } else {
359 value.clone()
360 }
361 }
362}
363
364pub fn bounded_mutation<R: Rng>(
366 sigma: f64,
367 lower: f64,
368 upper: f64,
369) -> impl Fn(&Address, &ChoiceValue, &mut R) -> ChoiceValue {
370 move |_addr, value, rng| {
371 if let ChoiceValue::F64(v) = value {
372 let noise: f64 = rng.gen::<f64>() * 2.0 - 1.0;
373 let mutated = (v + sigma * noise * 2.0_f64.sqrt()).clamp(lower, upper);
374 ChoiceValue::F64(mutated)
375 } else {
376 value.clone()
377 }
378 }
379}
380
#[cfg(test)]
mod tests {
    use super::*;
    use crate::genome::real_vector::RealVector;
    use crate::genome::traits::RealValuedGenome;

    // Five-gene genome shared by the selector and mask tests.
    fn five_gene_genome() -> RealVector {
        RealVector::new(vec![1.0, 2.0, 3.0, 4.0, 5.0])
    }

    // A uniform selector can never return more sites than the trace has.
    #[test]
    fn test_uniform_mutation_selector() {
        let trace = five_gene_genome().to_trace();
        let mut rng = rand::thread_rng();

        let high = UniformMutationSelector::new(0.9);
        let picked = high.select_sites(&trace, &mut rng);
        assert!(picked.len() <= 5);

        // Smoke check only: a low probability must still run cleanly.
        let low = UniformMutationSelector::new(0.1);
        let _sites_low = low.select_sites(&trace, &mut rng);
    }

    // The single-site selector returns exactly one address for a non-empty trace.
    #[test]
    fn test_single_site_mutation_selector() {
        let trace = five_gene_genome().to_trace();
        let mut rng = rand::thread_rng();

        let picked = SingleSiteMutationSelector::new().select_sites(&trace, &mut rng);

        assert_eq!(picked.len(), 1);
    }

    // Asking for three of five sites yields three distinct addresses.
    #[test]
    fn test_multi_site_mutation_selector() {
        let trace = five_gene_genome().to_trace();
        let mut rng = rand::thread_rng();

        let picked = MultiSiteMutationSelector::new(3).select_sites(&trace, &mut rng);

        assert_eq!(picked.len(), 3);
    }

    // Mutating every site keeps the genome's dimension.
    #[test]
    fn test_mutate_trace() {
        let genome = RealVector::new(vec![1.0, 2.0, 3.0]);
        let mut rng = rand::thread_rng();

        let every_site = UniformMutationSelector::new(1.0);
        let perturb = gaussian_mutation(0.1);

        let mutated = mutate_trace(&genome, &every_site, perturb, &mut rng).unwrap();

        assert_eq!(mutated.dimension(), genome.dimension());
    }

    // Each child gene must equal the corresponding gene of one of the parents.
    #[test]
    fn test_crossover_traces() {
        let parent1 = RealVector::new(vec![1.0, 2.0, 3.0]);
        let parent2 = RealVector::new(vec![4.0, 5.0, 6.0]);
        let mut rng = rand::thread_rng();

        let mask = UniformCrossoverMask::balanced(&parent1.to_trace(), &mut rng);

        let (child1, child2) = crossover_traces(&parent1, &parent2, &mask, &mut rng).unwrap();

        assert_eq!(child1.dimension(), 3);
        assert_eq!(child2.dimension(), 3);

        for i in 0..3 {
            let p1_val = parent1.genes()[i];
            let p2_val = parent2.genes()[i];

            let c1_val = child1.genes()[i];
            assert!(
                (c1_val - p1_val).abs() < 1e-10 || (c1_val - p2_val).abs() < 1e-10,
                "Child1 value {} not from either parent ({} or {})",
                c1_val,
                p1_val,
                p2_val
            );

            let c2_val = child2.genes()[i];
            assert!(
                (c2_val - p1_val).abs() < 1e-10 || (c2_val - p2_val).abs() < 1e-10,
                "Child2 value {} not from either parent ({} or {})",
                c2_val,
                p1_val,
                p2_val
            );
        }
    }

    // Walking the addresses in order, the mask may switch from parent 1 to
    // parent 2 at most once.
    #[test]
    fn test_single_point_crossover_mask() {
        let trace = five_gene_genome().to_trace();
        let mut rng = rand::thread_rng();

        for _ in 0..10 {
            let mask = SinglePointCrossoverMask::new(&trace, &mut rng);

            let membership: Vec<bool> = trace
                .choices
                .keys()
                .map(|addr| mask.from_parent1(addr))
                .collect();

            let mut found_split = false;
            for pair in membership.windows(2) {
                if pair[0] && !pair[1] {
                    assert!(!found_split, "Multiple splits found");
                    found_split = true;
                }
            }
        }
    }
}