1use std::collections::HashMap;
10use std::sync::{Arc, Mutex};
11
12#[cfg(test)]
13use fugue::addr;
14use fugue::{Address, ChoiceValue, Trace};
15use rand::Rng;
16
17use super::trace_operators::{CrossoverMask, MutationSelector};
18use crate::error::GenomeError;
19use crate::genome::traits::EvolutionaryGenome;
20
21#[derive(Clone, Debug)]
23pub struct MutationRecord {
24 pub generation: usize,
26 pub mutated_addresses: Vec<Address>,
28 pub original_values: HashMap<Address, ChoiceValue>,
30 pub new_values: HashMap<Address, ChoiceValue>,
32}
33
34#[derive(Clone, Debug)]
36pub struct CrossoverRecord {
37 pub generation: usize,
39 pub from_parent1: Vec<Address>,
41 pub from_parent2: Vec<Address>,
43}
44
/// Snapshot of one selection step, as passed to `SelectionHandler::after_selection`.
#[derive(Debug, Clone)]
pub struct SelectionRecord {
    /// Generation number the selection was performed in.
    pub generation: usize,
    /// Indices chosen by the selection step (presumably positions in the population —
    /// confirm against the caller that fills this in).
    pub selected_indices: Vec<usize>,
    /// Fitness values reported alongside the selected indices.
    pub selected_fitness: Vec<f64>,
}
55
56pub trait MutationHandler: Send + Sync {
58 fn before_mutation(&self, trace: &Trace, generation: usize) -> bool;
61
62 fn after_mutation(&self, original: &Trace, mutated: &Trace, record: &MutationRecord);
64
65 fn modify_sites(
67 &self,
68 sites: std::collections::HashSet<Address>,
69 _trace: &Trace,
70 ) -> std::collections::HashSet<Address> {
71 sites }
73}
74
75pub trait CrossoverHandler: Send + Sync {
77 fn before_crossover(&self, parent1: &Trace, parent2: &Trace, generation: usize) -> bool;
80
81 fn after_crossover(
83 &self,
84 parent1: &Trace,
85 parent2: &Trace,
86 child1: &Trace,
87 child2: &Trace,
88 record: &CrossoverRecord,
89 );
90}
91
92pub trait SelectionHandler: Send + Sync {
94 fn before_selection(&self, population_size: usize, generation: usize);
96
97 fn after_selection(&self, record: &SelectionRecord);
99
100 fn modify_probabilities(&self, probabilities: Vec<f64>) -> Vec<f64> {
102 probabilities }
104}
105
106#[derive(Clone, Debug, Default)]
108pub struct LoggingHandler {
109 pub mutations: Arc<Mutex<Vec<MutationRecord>>>,
111 pub crossovers: Arc<Mutex<Vec<CrossoverRecord>>>,
113 pub selections: Arc<Mutex<Vec<SelectionRecord>>>,
115}
116
117impl LoggingHandler {
118 pub fn new() -> Self {
120 Self::default()
121 }
122
123 pub fn get_mutations(&self) -> Vec<MutationRecord> {
125 self.mutations.lock().unwrap().clone()
126 }
127
128 pub fn get_crossovers(&self) -> Vec<CrossoverRecord> {
130 self.crossovers.lock().unwrap().clone()
131 }
132
133 pub fn get_selections(&self) -> Vec<SelectionRecord> {
135 self.selections.lock().unwrap().clone()
136 }
137
138 pub fn clear(&self) {
140 self.mutations.lock().unwrap().clear();
141 self.crossovers.lock().unwrap().clear();
142 self.selections.lock().unwrap().clear();
143 }
144}
145
146impl MutationHandler for LoggingHandler {
147 fn before_mutation(&self, _trace: &Trace, _generation: usize) -> bool {
148 true }
150
151 fn after_mutation(&self, _original: &Trace, _mutated: &Trace, record: &MutationRecord) {
152 self.mutations.lock().unwrap().push(record.clone());
153 }
154}
155
156impl CrossoverHandler for LoggingHandler {
157 fn before_crossover(&self, _parent1: &Trace, _parent2: &Trace, _generation: usize) -> bool {
158 true }
160
161 fn after_crossover(
162 &self,
163 _parent1: &Trace,
164 _parent2: &Trace,
165 _child1: &Trace,
166 _child2: &Trace,
167 record: &CrossoverRecord,
168 ) {
169 self.crossovers.lock().unwrap().push(record.clone());
170 }
171}
172
173impl SelectionHandler for LoggingHandler {
174 fn before_selection(&self, _population_size: usize, _generation: usize) {}
175
176 fn after_selection(&self, record: &SelectionRecord) {
177 self.selections.lock().unwrap().push(record.clone());
178 }
179}
180
181#[derive(Clone, Debug)]
183pub struct RateLimitingHandler {
184 pub max_mutations: usize,
186 pub max_crossovers: usize,
188 mutation_count: Arc<Mutex<(usize, usize)>>, crossover_count: Arc<Mutex<(usize, usize)>>,
192}
193
194impl RateLimitingHandler {
195 pub fn new(max_mutations: usize, max_crossovers: usize) -> Self {
197 Self {
198 max_mutations,
199 max_crossovers,
200 mutation_count: Arc::new(Mutex::new((0, 0))),
201 crossover_count: Arc::new(Mutex::new((0, 0))),
202 }
203 }
204
205 pub fn reset(&self, generation: usize) {
207 *self.mutation_count.lock().unwrap() = (generation, 0);
208 *self.crossover_count.lock().unwrap() = (generation, 0);
209 }
210}
211
212impl MutationHandler for RateLimitingHandler {
213 fn before_mutation(&self, _trace: &Trace, generation: usize) -> bool {
214 let mut count = self.mutation_count.lock().unwrap();
215 if count.0 != generation {
216 *count = (generation, 0);
217 }
218 if count.1 < self.max_mutations {
219 count.1 += 1;
220 true
221 } else {
222 false
223 }
224 }
225
226 fn after_mutation(&self, _original: &Trace, _mutated: &Trace, _record: &MutationRecord) {}
227}
228
229impl CrossoverHandler for RateLimitingHandler {
230 fn before_crossover(&self, _parent1: &Trace, _parent2: &Trace, generation: usize) -> bool {
231 let mut count = self.crossover_count.lock().unwrap();
232 if count.0 != generation {
233 *count = (generation, 0);
234 }
235 if count.1 < self.max_crossovers {
236 count.1 += 1;
237 true
238 } else {
239 false
240 }
241 }
242
243 fn after_crossover(
244 &self,
245 _parent1: &Trace,
246 _parent2: &Trace,
247 _child1: &Trace,
248 _child2: &Trace,
249 _record: &CrossoverRecord,
250 ) {
251 }
252}
253
254pub struct ConditionalHandler<F>
256where
257 F: Fn(usize) -> bool + Send + Sync,
258{
259 pub predicate: F,
261}
262
263impl<F> ConditionalHandler<F>
264where
265 F: Fn(usize) -> bool + Send + Sync,
266{
267 pub fn new(predicate: F) -> Self {
269 Self { predicate }
270 }
271}
272
273impl<F> MutationHandler for ConditionalHandler<F>
274where
275 F: Fn(usize) -> bool + Send + Sync,
276{
277 fn before_mutation(&self, _trace: &Trace, generation: usize) -> bool {
278 (self.predicate)(generation)
279 }
280
281 fn after_mutation(&self, _original: &Trace, _mutated: &Trace, _record: &MutationRecord) {}
282}
283
284impl<F> CrossoverHandler for ConditionalHandler<F>
285where
286 F: Fn(usize) -> bool + Send + Sync,
287{
288 fn before_crossover(&self, _parent1: &Trace, _parent2: &Trace, generation: usize) -> bool {
289 (self.predicate)(generation)
290 }
291
292 fn after_crossover(
293 &self,
294 _parent1: &Trace,
295 _parent2: &Trace,
296 _child1: &Trace,
297 _child2: &Trace,
298 _record: &CrossoverRecord,
299 ) {
300 }
301}
302
303pub struct ComposedMutationHandler {
305 handlers: Vec<Box<dyn MutationHandler>>,
306}
307
308impl ComposedMutationHandler {
309 pub fn new() -> Self {
311 Self {
312 handlers: Vec::new(),
313 }
314 }
315
316 pub fn add<H: MutationHandler + 'static>(mut self, handler: H) -> Self {
318 self.handlers.push(Box::new(handler));
319 self
320 }
321}
322
323impl Default for ComposedMutationHandler {
324 fn default() -> Self {
325 Self::new()
326 }
327}
328
329impl MutationHandler for ComposedMutationHandler {
330 fn before_mutation(&self, trace: &Trace, generation: usize) -> bool {
331 self.handlers
333 .iter()
334 .all(|h| h.before_mutation(trace, generation))
335 }
336
337 fn after_mutation(&self, original: &Trace, mutated: &Trace, record: &MutationRecord) {
338 for handler in &self.handlers {
339 handler.after_mutation(original, mutated, record);
340 }
341 }
342
343 fn modify_sites(
344 &self,
345 mut sites: std::collections::HashSet<Address>,
346 trace: &Trace,
347 ) -> std::collections::HashSet<Address> {
348 for handler in &self.handlers {
349 sites = handler.modify_sites(sites, trace);
350 }
351 sites
352 }
353}
354
355pub struct ComposedCrossoverHandler {
357 handlers: Vec<Box<dyn CrossoverHandler>>,
358}
359
360impl ComposedCrossoverHandler {
361 pub fn new() -> Self {
363 Self {
364 handlers: Vec::new(),
365 }
366 }
367
368 pub fn add<H: CrossoverHandler + 'static>(mut self, handler: H) -> Self {
370 self.handlers.push(Box::new(handler));
371 self
372 }
373}
374
375impl Default for ComposedCrossoverHandler {
376 fn default() -> Self {
377 Self::new()
378 }
379}
380
381impl CrossoverHandler for ComposedCrossoverHandler {
382 fn before_crossover(&self, parent1: &Trace, parent2: &Trace, generation: usize) -> bool {
383 self.handlers
384 .iter()
385 .all(|h| h.before_crossover(parent1, parent2, generation))
386 }
387
388 fn after_crossover(
389 &self,
390 parent1: &Trace,
391 parent2: &Trace,
392 child1: &Trace,
393 child2: &Trace,
394 record: &CrossoverRecord,
395 ) {
396 for handler in &self.handlers {
397 handler.after_crossover(parent1, parent2, child1, child2, record);
398 }
399 }
400}
401
402pub fn handled_mutate_trace<G, S, H, R>(
404 genome: &G,
405 selector: &S,
406 mutation_fn: impl Fn(&Address, &ChoiceValue, &mut R) -> ChoiceValue,
407 handler: &H,
408 generation: usize,
409 rng: &mut R,
410) -> Result<G, GenomeError>
411where
412 G: EvolutionaryGenome,
413 S: MutationSelector,
414 H: MutationHandler,
415 R: Rng,
416{
417 let trace = genome.to_trace();
418
419 if !handler.before_mutation(&trace, generation) {
421 return G::from_trace(&trace);
422 }
423
424 let mut mutation_sites = selector.select_sites(&trace, rng);
426 mutation_sites = handler.modify_sites(mutation_sites, &trace);
427
428 let mut new_trace = Trace::default();
429 let mut original_values = HashMap::new();
430 let mut new_values = HashMap::new();
431
432 for (addr, choice) in &trace.choices {
433 let new_value = if mutation_sites.contains(addr) {
434 original_values.insert(addr.clone(), choice.value.clone());
435 let mutated = mutation_fn(addr, &choice.value, rng);
436 new_values.insert(addr.clone(), mutated.clone());
437 mutated
438 } else {
439 choice.value.clone()
440 };
441 new_trace.insert_choice(addr.clone(), new_value, choice.logp);
442 }
443
444 let record = MutationRecord {
446 generation,
447 mutated_addresses: mutation_sites.into_iter().collect(),
448 original_values,
449 new_values,
450 };
451
452 handler.after_mutation(&trace, &new_trace, &record);
453
454 G::from_trace(&new_trace)
455}
456
457pub fn handled_crossover_traces<G, M, H, R>(
459 parent1: &G,
460 parent2: &G,
461 mask: &M,
462 handler: &H,
463 generation: usize,
464 _rng: &mut R,
465) -> Result<(G, G), GenomeError>
466where
467 G: EvolutionaryGenome,
468 M: CrossoverMask,
469 H: CrossoverHandler,
470 R: Rng,
471{
472 let trace1 = parent1.to_trace();
473 let trace2 = parent2.to_trace();
474
475 if !handler.before_crossover(&trace1, &trace2, generation) {
477 return Ok((G::from_trace(&trace1)?, G::from_trace(&trace2)?));
478 }
479
480 let mut child1_trace = Trace::default();
481 let mut child2_trace = Trace::default();
482 let mut from_parent1 = Vec::new();
483 let mut from_parent2 = Vec::new();
484
485 let all_addresses: std::collections::HashSet<Address> = trace1
487 .choices
488 .keys()
489 .chain(trace2.choices.keys())
490 .cloned()
491 .collect();
492
493 for addr in all_addresses {
494 let (val_for_child1, val_for_child2) = if mask.from_parent1(&addr) {
495 from_parent1.push(addr.clone());
496 (
497 trace1
498 .choices
499 .get(&addr)
500 .map(|c| c.value.clone())
501 .unwrap_or(ChoiceValue::F64(0.0)),
502 trace2
503 .choices
504 .get(&addr)
505 .map(|c| c.value.clone())
506 .unwrap_or(ChoiceValue::F64(0.0)),
507 )
508 } else {
509 from_parent2.push(addr.clone());
510 (
511 trace2
512 .choices
513 .get(&addr)
514 .map(|c| c.value.clone())
515 .unwrap_or(ChoiceValue::F64(0.0)),
516 trace1
517 .choices
518 .get(&addr)
519 .map(|c| c.value.clone())
520 .unwrap_or(ChoiceValue::F64(0.0)),
521 )
522 };
523
524 child1_trace.insert_choice(addr.clone(), val_for_child1, 0.0);
525 child2_trace.insert_choice(addr, val_for_child2, 0.0);
526 }
527
528 let record = CrossoverRecord {
530 generation,
531 from_parent1,
532 from_parent2,
533 };
534
535 handler.after_crossover(&trace1, &trace2, &child1_trace, &child2_trace, &record);
536
537 let child1 = G::from_trace(&child1_trace)?;
538 let child2 = G::from_trace(&child2_trace)?;
539
540 Ok((child1, child2))
541}
542
543#[derive(Clone, Debug, Default)]
545pub struct OperationStatistics {
546 pub total_mutations: usize,
548 pub total_crossovers: usize,
550 pub avg_mutation_sites: f64,
552 pub avg_parent1_contribution: f64,
554}
555
556impl OperationStatistics {
557 pub fn from_handler(handler: &LoggingHandler) -> Self {
559 let mutations = handler.get_mutations();
560 let crossovers = handler.get_crossovers();
561
562 let total_mutations = mutations.len();
563 let total_crossovers = crossovers.len();
564
565 let avg_mutation_sites = if total_mutations > 0 {
566 mutations
567 .iter()
568 .map(|r| r.mutated_addresses.len())
569 .sum::<usize>() as f64
570 / total_mutations as f64
571 } else {
572 0.0
573 };
574
575 let avg_parent1_contribution = if total_crossovers > 0 {
576 let total: f64 = crossovers
577 .iter()
578 .map(|r| {
579 let total = r.from_parent1.len() + r.from_parent2.len();
580 if total > 0 {
581 r.from_parent1.len() as f64 / total as f64
582 } else {
583 0.5
584 }
585 })
586 .sum();
587 total / total_crossovers as f64
588 } else {
589 0.0
590 };
591
592 Self {
593 total_mutations,
594 total_crossovers,
595 avg_mutation_sites,
596 avg_parent1_contribution,
597 }
598 }
599}
600
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fugue_integration::trace_operators::{
        gaussian_mutation, UniformCrossoverMask, UniformMutationSelector,
    };
    use crate::genome::real_vector::RealVector;

    #[test]
    fn test_logging_handler_mutation() {
        let mut rng = rand::thread_rng();
        let genome = RealVector::new(vec![1.0, 2.0, 3.0]);

        let handler = LoggingHandler::new();
        let selector = UniformMutationSelector::new(1.0);
        let mutation_fn = gaussian_mutation(0.1);

        let _mutated =
            handled_mutate_trace(&genome, &selector, mutation_fn, &handler, 0, &mut rng).unwrap();

        let records = handler.get_mutations();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].generation, 0);
        assert!(!records[0].mutated_addresses.is_empty());
    }

    #[test]
    fn test_logging_handler_crossover() {
        let mut rng = rand::thread_rng();
        let parent1 = RealVector::new(vec![1.0, 2.0, 3.0]);
        let parent2 = RealVector::new(vec![4.0, 5.0, 6.0]);

        let handler = LoggingHandler::new();
        let trace1 = parent1.to_trace();
        let mask = UniformCrossoverMask::balanced(&trace1, &mut rng);

        let (_child1, _child2) =
            handled_crossover_traces(&parent1, &parent2, &mask, &handler, 0, &mut rng).unwrap();

        let records = handler.get_crossovers();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].generation, 0);
    }

    /// The limiter must actually refuse once the per-generation quota is spent, restart
    /// the quota on a new generation or `reset`, and keep the crossover quota separate.
    #[test]
    fn test_rate_limiting_handler() {
        let genome = RealVector::new(vec![1.0, 2.0, 3.0]);
        let trace = genome.to_trace();

        let handler = RateLimitingHandler::new(2, 2);
        assert!(handler.before_mutation(&trace, 0));
        assert!(handler.before_mutation(&trace, 0));
        assert!(
            !handler.before_mutation(&trace, 0),
            "third mutation in the same generation must be refused"
        );
        assert!(
            handler.before_mutation(&trace, 1),
            "a new generation must restart the mutation quota"
        );

        assert!(handler.before_crossover(&trace, &trace, 0));
        assert!(handler.before_crossover(&trace, &trace, 0));
        assert!(
            !handler.before_crossover(&trace, &trace, 0),
            "third crossover in the same generation must be refused"
        );

        handler.reset(0);
        assert!(handler.before_mutation(&trace, 0), "reset must restore the quota");
    }

    /// End to end: with the limiter first in a composition, only admitted mutations are
    /// logged.
    #[test]
    fn test_rate_limiting_handler_in_pipeline() {
        let mut rng = rand::thread_rng();
        let genome = RealVector::new(vec![1.0, 2.0, 3.0]);

        let logging = LoggingHandler::new();
        let composed = ComposedMutationHandler::new()
            .add(RateLimitingHandler::new(2, 2))
            .add(logging.clone());
        let selector = UniformMutationSelector::new(1.0);
        let mutation_fn = gaussian_mutation(0.1);

        for _ in 0..3 {
            let result =
                handled_mutate_trace(&genome, &selector, &mutation_fn, &composed, 0, &mut rng);
            assert!(result.is_ok());
        }

        assert_eq!(logging.get_mutations().len(), 2);
    }

    #[test]
    fn test_composed_handler() {
        let mut rng = rand::thread_rng();
        let genome = RealVector::new(vec![1.0, 2.0, 3.0]);

        let logging = LoggingHandler::new();
        let rate_limit = RateLimitingHandler::new(5, 5);

        let composed = ComposedMutationHandler::new()
            .add(logging.clone())
            .add(rate_limit);

        let selector = UniformMutationSelector::new(1.0);
        let mutation_fn = gaussian_mutation(0.1);

        let _ = handled_mutate_trace(&genome, &selector, mutation_fn, &composed, 0, &mut rng);

        assert_eq!(logging.get_mutations().len(), 1);
    }

    #[test]
    fn test_operation_statistics() {
        let handler = LoggingHandler::new();

        handler.mutations.lock().unwrap().push(MutationRecord {
            generation: 0,
            mutated_addresses: vec![addr!("test", 0), addr!("test", 1)],
            original_values: HashMap::new(),
            new_values: HashMap::new(),
        });

        // One crossover taking 1 of 4 addresses from parent 1 (0.25) and one empty
        // crossover (counted as 0.5): mean 0.375.
        {
            let mut crossovers = handler.crossovers.lock().unwrap();
            crossovers.push(CrossoverRecord {
                generation: 0,
                from_parent1: vec![addr!("x", 0)],
                from_parent2: vec![addr!("x", 1), addr!("x", 2), addr!("x", 3)],
            });
            crossovers.push(CrossoverRecord {
                generation: 0,
                from_parent1: Vec::new(),
                from_parent2: Vec::new(),
            });
        }

        let stats = OperationStatistics::from_handler(&handler);
        assert_eq!(stats.total_mutations, 1);
        assert_eq!(stats.avg_mutation_sites, 2.0);
        assert_eq!(stats.total_crossovers, 2);
        assert_eq!(stats.avg_parent1_contribution, 0.375);
    }
}