Skip to main content

fugue_evo/fugue_integration/
effect_handlers.rs

1//! Poutine-style effect handlers for evolutionary operations
2//!
3//! Effect handlers intercept and transform evolutionary operations, enabling:
4//! - Logging and tracing of all genetic operations
5//! - Conditional modification of operator behavior
6//! - Composition of effects (e.g., logging + rate limiting)
7//! - Replay of evolutionary traces for debugging
8
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex};
11
12#[cfg(test)]
13use fugue::addr;
14use fugue::{Address, ChoiceValue, Trace};
15use rand::Rng;
16
17use super::trace_operators::{CrossoverMask, MutationSelector};
18use crate::error::GenomeError;
19use crate::genome::traits::EvolutionaryGenome;
20
21/// Record of a mutation operation
22#[derive(Clone, Debug)]
23pub struct MutationRecord {
24    /// Generation when mutation occurred
25    pub generation: usize,
26    /// Addresses that were mutated
27    pub mutated_addresses: Vec<Address>,
28    /// Original values before mutation
29    pub original_values: HashMap<Address, ChoiceValue>,
30    /// New values after mutation
31    pub new_values: HashMap<Address, ChoiceValue>,
32}
33
34/// Record of a crossover operation
35#[derive(Clone, Debug)]
36pub struct CrossoverRecord {
37    /// Generation when crossover occurred
38    pub generation: usize,
39    /// Addresses from parent1
40    pub from_parent1: Vec<Address>,
41    /// Addresses from parent2
42    pub from_parent2: Vec<Address>,
43}
44
/// Summary of a single selection step, as reported to a [`SelectionHandler`].
#[derive(Debug, Clone)]
pub struct SelectionRecord {
    /// Generation index at which selection ran.
    pub generation: usize,
    /// Population indices of the chosen individuals.
    pub selected_indices: Vec<usize>,
    /// Fitness of each chosen individual.
    pub selected_fitness: Vec<f64>,
}
55
56/// Effect handler trait for mutation operations
57pub trait MutationHandler: Send + Sync {
58    /// Called before mutation is applied
59    /// Returns true if mutation should proceed, false to skip
60    fn before_mutation(&self, trace: &Trace, generation: usize) -> bool;
61
62    /// Called after mutation is applied
63    fn after_mutation(&self, original: &Trace, mutated: &Trace, record: &MutationRecord);
64
65    /// Optionally modify the mutation sites before mutation occurs
66    fn modify_sites(
67        &self,
68        sites: std::collections::HashSet<Address>,
69        _trace: &Trace,
70    ) -> std::collections::HashSet<Address> {
71        sites // Default: no modification
72    }
73}
74
75/// Effect handler trait for crossover operations
76pub trait CrossoverHandler: Send + Sync {
77    /// Called before crossover is applied
78    /// Returns true if crossover should proceed, false to skip
79    fn before_crossover(&self, parent1: &Trace, parent2: &Trace, generation: usize) -> bool;
80
81    /// Called after crossover is applied
82    fn after_crossover(
83        &self,
84        parent1: &Trace,
85        parent2: &Trace,
86        child1: &Trace,
87        child2: &Trace,
88        record: &CrossoverRecord,
89    );
90}
91
92/// Effect handler trait for selection operations
93pub trait SelectionHandler: Send + Sync {
94    /// Called before selection
95    fn before_selection(&self, population_size: usize, generation: usize);
96
97    /// Called after selection
98    fn after_selection(&self, record: &SelectionRecord);
99
100    /// Optionally modify selection probabilities
101    fn modify_probabilities(&self, probabilities: Vec<f64>) -> Vec<f64> {
102        probabilities // Default: no modification
103    }
104}
105
106/// A handler that logs all evolutionary operations
107#[derive(Clone, Debug, Default)]
108pub struct LoggingHandler {
109    /// Mutation records
110    pub mutations: Arc<Mutex<Vec<MutationRecord>>>,
111    /// Crossover records
112    pub crossovers: Arc<Mutex<Vec<CrossoverRecord>>>,
113    /// Selection records
114    pub selections: Arc<Mutex<Vec<SelectionRecord>>>,
115}
116
117impl LoggingHandler {
118    /// Create a new logging handler
119    pub fn new() -> Self {
120        Self::default()
121    }
122
123    /// Get all mutation records
124    pub fn get_mutations(&self) -> Vec<MutationRecord> {
125        self.mutations.lock().unwrap().clone()
126    }
127
128    /// Get all crossover records
129    pub fn get_crossovers(&self) -> Vec<CrossoverRecord> {
130        self.crossovers.lock().unwrap().clone()
131    }
132
133    /// Get all selection records
134    pub fn get_selections(&self) -> Vec<SelectionRecord> {
135        self.selections.lock().unwrap().clone()
136    }
137
138    /// Clear all records
139    pub fn clear(&self) {
140        self.mutations.lock().unwrap().clear();
141        self.crossovers.lock().unwrap().clear();
142        self.selections.lock().unwrap().clear();
143    }
144}
145
146impl MutationHandler for LoggingHandler {
147    fn before_mutation(&self, _trace: &Trace, _generation: usize) -> bool {
148        true // Always allow
149    }
150
151    fn after_mutation(&self, _original: &Trace, _mutated: &Trace, record: &MutationRecord) {
152        self.mutations.lock().unwrap().push(record.clone());
153    }
154}
155
156impl CrossoverHandler for LoggingHandler {
157    fn before_crossover(&self, _parent1: &Trace, _parent2: &Trace, _generation: usize) -> bool {
158        true // Always allow
159    }
160
161    fn after_crossover(
162        &self,
163        _parent1: &Trace,
164        _parent2: &Trace,
165        _child1: &Trace,
166        _child2: &Trace,
167        record: &CrossoverRecord,
168    ) {
169        self.crossovers.lock().unwrap().push(record.clone());
170    }
171}
172
173impl SelectionHandler for LoggingHandler {
174    fn before_selection(&self, _population_size: usize, _generation: usize) {}
175
176    fn after_selection(&self, record: &SelectionRecord) {
177        self.selections.lock().unwrap().push(record.clone());
178    }
179}
180
/// Handler that caps how many mutations and crossovers may run per generation.
///
/// Each counter holds `(generation, admitted_so_far)`. The counters live
/// behind `Arc<Mutex<_>>`, so clones of a handler share one budget.
#[derive(Clone, Debug)]
pub struct RateLimitingHandler {
    /// Mutations allowed per generation.
    pub max_mutations: usize,
    /// Crossovers allowed per generation.
    pub max_crossovers: usize,
    /// `(generation, mutations admitted so far in that generation)`.
    mutation_count: Arc<Mutex<(usize, usize)>>,
    /// `(generation, crossovers admitted so far in that generation)`.
    crossover_count: Arc<Mutex<(usize, usize)>>,
}

impl RateLimitingHandler {
    /// Create a limiter with the given per-generation budgets.
    ///
    /// Both counters start at `(0, 0)`: generation 0, nothing admitted.
    pub fn new(max_mutations: usize, max_crossovers: usize) -> Self {
        let fresh_counter = || Arc::new(Mutex::new((0, 0)));
        Self {
            max_mutations,
            max_crossovers,
            mutation_count: fresh_counter(),
            crossover_count: fresh_counter(),
        }
    }

    /// Zero both counters and mark them as belonging to `generation`.
    ///
    /// # Panics
    /// Panics if a counter's mutex is poisoned.
    pub fn reset(&self, generation: usize) {
        for counter in &[&self.mutation_count, &self.crossover_count] {
            *counter.lock().unwrap() = (generation, 0);
        }
    }
}
211
212impl MutationHandler for RateLimitingHandler {
213    fn before_mutation(&self, _trace: &Trace, generation: usize) -> bool {
214        let mut count = self.mutation_count.lock().unwrap();
215        if count.0 != generation {
216            *count = (generation, 0);
217        }
218        if count.1 < self.max_mutations {
219            count.1 += 1;
220            true
221        } else {
222            false
223        }
224    }
225
226    fn after_mutation(&self, _original: &Trace, _mutated: &Trace, _record: &MutationRecord) {}
227}
228
229impl CrossoverHandler for RateLimitingHandler {
230    fn before_crossover(&self, _parent1: &Trace, _parent2: &Trace, generation: usize) -> bool {
231        let mut count = self.crossover_count.lock().unwrap();
232        if count.0 != generation {
233            *count = (generation, 0);
234        }
235        if count.1 < self.max_crossovers {
236            count.1 += 1;
237            true
238        } else {
239            false
240        }
241    }
242
243    fn after_crossover(
244        &self,
245        _parent1: &Trace,
246        _parent2: &Trace,
247        _child1: &Trace,
248        _child2: &Trace,
249        _record: &CrossoverRecord,
250    ) {
251    }
252}
253
/// Handler that admits or blocks operations based on the generation number.
///
/// `predicate(generation)` decides each operation: `true` lets it proceed.
/// The struct itself places no bounds on `F`. The constructor and the handler
/// impls require `F: Fn(usize) -> bool + Send + Sync`.
#[derive(Clone)]
pub struct ConditionalHandler<F> {
    /// Decides, per generation, whether an operation may proceed.
    pub predicate: F,
}

impl<F> ConditionalHandler<F>
where
    F: Fn(usize) -> bool + Send + Sync,
{
    /// Wrap `predicate` in a handler.
    ///
    /// The bound is kept here (rather than on the struct) so that a closure
    /// passed directly to `new` still gets its `usize` argument type inferred.
    pub fn new(predicate: F) -> Self {
        Self { predicate }
    }
}
272
273impl<F> MutationHandler for ConditionalHandler<F>
274where
275    F: Fn(usize) -> bool + Send + Sync,
276{
277    fn before_mutation(&self, _trace: &Trace, generation: usize) -> bool {
278        (self.predicate)(generation)
279    }
280
281    fn after_mutation(&self, _original: &Trace, _mutated: &Trace, _record: &MutationRecord) {}
282}
283
284impl<F> CrossoverHandler for ConditionalHandler<F>
285where
286    F: Fn(usize) -> bool + Send + Sync,
287{
288    fn before_crossover(&self, _parent1: &Trace, _parent2: &Trace, generation: usize) -> bool {
289        (self.predicate)(generation)
290    }
291
292    fn after_crossover(
293        &self,
294        _parent1: &Trace,
295        _parent2: &Trace,
296        _child1: &Trace,
297        _child2: &Trace,
298        _record: &CrossoverRecord,
299    ) {
300    }
301}
302
303/// Composition of multiple mutation handlers
304pub struct ComposedMutationHandler {
305    handlers: Vec<Box<dyn MutationHandler>>,
306}
307
308impl ComposedMutationHandler {
309    /// Create a new composed handler
310    pub fn new() -> Self {
311        Self {
312            handlers: Vec::new(),
313        }
314    }
315
316    /// Add a handler to the composition
317    pub fn add<H: MutationHandler + 'static>(mut self, handler: H) -> Self {
318        self.handlers.push(Box::new(handler));
319        self
320    }
321}
322
323impl Default for ComposedMutationHandler {
324    fn default() -> Self {
325        Self::new()
326    }
327}
328
329impl MutationHandler for ComposedMutationHandler {
330    fn before_mutation(&self, trace: &Trace, generation: usize) -> bool {
331        // All handlers must agree
332        self.handlers
333            .iter()
334            .all(|h| h.before_mutation(trace, generation))
335    }
336
337    fn after_mutation(&self, original: &Trace, mutated: &Trace, record: &MutationRecord) {
338        for handler in &self.handlers {
339            handler.after_mutation(original, mutated, record);
340        }
341    }
342
343    fn modify_sites(
344        &self,
345        mut sites: std::collections::HashSet<Address>,
346        trace: &Trace,
347    ) -> std::collections::HashSet<Address> {
348        for handler in &self.handlers {
349            sites = handler.modify_sites(sites, trace);
350        }
351        sites
352    }
353}
354
355/// Composition of multiple crossover handlers
356pub struct ComposedCrossoverHandler {
357    handlers: Vec<Box<dyn CrossoverHandler>>,
358}
359
360impl ComposedCrossoverHandler {
361    /// Create a new composed handler
362    pub fn new() -> Self {
363        Self {
364            handlers: Vec::new(),
365        }
366    }
367
368    /// Add a handler to the composition
369    pub fn add<H: CrossoverHandler + 'static>(mut self, handler: H) -> Self {
370        self.handlers.push(Box::new(handler));
371        self
372    }
373}
374
375impl Default for ComposedCrossoverHandler {
376    fn default() -> Self {
377        Self::new()
378    }
379}
380
381impl CrossoverHandler for ComposedCrossoverHandler {
382    fn before_crossover(&self, parent1: &Trace, parent2: &Trace, generation: usize) -> bool {
383        self.handlers
384            .iter()
385            .all(|h| h.before_crossover(parent1, parent2, generation))
386    }
387
388    fn after_crossover(
389        &self,
390        parent1: &Trace,
391        parent2: &Trace,
392        child1: &Trace,
393        child2: &Trace,
394        record: &CrossoverRecord,
395    ) {
396        for handler in &self.handlers {
397            handler.after_crossover(parent1, parent2, child1, child2, record);
398        }
399    }
400}
401
402/// Handled mutation operator that integrates with effect handlers
403pub fn handled_mutate_trace<G, S, H, R>(
404    genome: &G,
405    selector: &S,
406    mutation_fn: impl Fn(&Address, &ChoiceValue, &mut R) -> ChoiceValue,
407    handler: &H,
408    generation: usize,
409    rng: &mut R,
410) -> Result<G, GenomeError>
411where
412    G: EvolutionaryGenome,
413    S: MutationSelector,
414    H: MutationHandler,
415    R: Rng,
416{
417    let trace = genome.to_trace();
418
419    // Check if mutation should proceed
420    if !handler.before_mutation(&trace, generation) {
421        return G::from_trace(&trace);
422    }
423
424    // Select and potentially modify mutation sites
425    let mut mutation_sites = selector.select_sites(&trace, rng);
426    mutation_sites = handler.modify_sites(mutation_sites, &trace);
427
428    let mut new_trace = Trace::default();
429    let mut original_values = HashMap::new();
430    let mut new_values = HashMap::new();
431
432    for (addr, choice) in &trace.choices {
433        let new_value = if mutation_sites.contains(addr) {
434            original_values.insert(addr.clone(), choice.value.clone());
435            let mutated = mutation_fn(addr, &choice.value, rng);
436            new_values.insert(addr.clone(), mutated.clone());
437            mutated
438        } else {
439            choice.value.clone()
440        };
441        new_trace.insert_choice(addr.clone(), new_value, choice.logp);
442    }
443
444    // Create mutation record
445    let record = MutationRecord {
446        generation,
447        mutated_addresses: mutation_sites.into_iter().collect(),
448        original_values,
449        new_values,
450    };
451
452    handler.after_mutation(&trace, &new_trace, &record);
453
454    G::from_trace(&new_trace)
455}
456
457/// Handled crossover operator that integrates with effect handlers
458pub fn handled_crossover_traces<G, M, H, R>(
459    parent1: &G,
460    parent2: &G,
461    mask: &M,
462    handler: &H,
463    generation: usize,
464    _rng: &mut R,
465) -> Result<(G, G), GenomeError>
466where
467    G: EvolutionaryGenome,
468    M: CrossoverMask,
469    H: CrossoverHandler,
470    R: Rng,
471{
472    let trace1 = parent1.to_trace();
473    let trace2 = parent2.to_trace();
474
475    // Check if crossover should proceed
476    if !handler.before_crossover(&trace1, &trace2, generation) {
477        return Ok((G::from_trace(&trace1)?, G::from_trace(&trace2)?));
478    }
479
480    let mut child1_trace = Trace::default();
481    let mut child2_trace = Trace::default();
482    let mut from_parent1 = Vec::new();
483    let mut from_parent2 = Vec::new();
484
485    // Collect all addresses from both parents
486    let all_addresses: std::collections::HashSet<Address> = trace1
487        .choices
488        .keys()
489        .chain(trace2.choices.keys())
490        .cloned()
491        .collect();
492
493    for addr in all_addresses {
494        let (val_for_child1, val_for_child2) = if mask.from_parent1(&addr) {
495            from_parent1.push(addr.clone());
496            (
497                trace1
498                    .choices
499                    .get(&addr)
500                    .map(|c| c.value.clone())
501                    .unwrap_or(ChoiceValue::F64(0.0)),
502                trace2
503                    .choices
504                    .get(&addr)
505                    .map(|c| c.value.clone())
506                    .unwrap_or(ChoiceValue::F64(0.0)),
507            )
508        } else {
509            from_parent2.push(addr.clone());
510            (
511                trace2
512                    .choices
513                    .get(&addr)
514                    .map(|c| c.value.clone())
515                    .unwrap_or(ChoiceValue::F64(0.0)),
516                trace1
517                    .choices
518                    .get(&addr)
519                    .map(|c| c.value.clone())
520                    .unwrap_or(ChoiceValue::F64(0.0)),
521            )
522        };
523
524        child1_trace.insert_choice(addr.clone(), val_for_child1, 0.0);
525        child2_trace.insert_choice(addr, val_for_child2, 0.0);
526    }
527
528    // Create crossover record
529    let record = CrossoverRecord {
530        generation,
531        from_parent1,
532        from_parent2,
533    };
534
535    handler.after_crossover(&trace1, &trace2, &child1_trace, &child2_trace, &record);
536
537    let child1 = G::from_trace(&child1_trace)?;
538    let child2 = G::from_trace(&child2_trace)?;
539
540    Ok((child1, child2))
541}
542
543/// Statistics computed from handler records
544#[derive(Clone, Debug, Default)]
545pub struct OperationStatistics {
546    /// Total number of mutations
547    pub total_mutations: usize,
548    /// Total number of crossovers
549    pub total_crossovers: usize,
550    /// Average mutation sites per operation
551    pub avg_mutation_sites: f64,
552    /// Distribution of addresses from parent1 in crossovers
553    pub avg_parent1_contribution: f64,
554}
555
556impl OperationStatistics {
557    /// Compute statistics from a logging handler
558    pub fn from_handler(handler: &LoggingHandler) -> Self {
559        let mutations = handler.get_mutations();
560        let crossovers = handler.get_crossovers();
561
562        let total_mutations = mutations.len();
563        let total_crossovers = crossovers.len();
564
565        let avg_mutation_sites = if total_mutations > 0 {
566            mutations
567                .iter()
568                .map(|r| r.mutated_addresses.len())
569                .sum::<usize>() as f64
570                / total_mutations as f64
571        } else {
572            0.0
573        };
574
575        let avg_parent1_contribution = if total_crossovers > 0 {
576            let total: f64 = crossovers
577                .iter()
578                .map(|r| {
579                    let total = r.from_parent1.len() + r.from_parent2.len();
580                    if total > 0 {
581                        r.from_parent1.len() as f64 / total as f64
582                    } else {
583                        0.5
584                    }
585                })
586                .sum();
587            total / total_crossovers as f64
588        } else {
589            0.0
590        };
591
592        Self {
593            total_mutations,
594            total_crossovers,
595            avg_mutation_sites,
596            avg_parent1_contribution,
597        }
598    }
599}
600
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fugue_integration::trace_operators::{
        gaussian_mutation, UniformCrossoverMask, UniformMutationSelector,
    };
    use crate::genome::real_vector::RealVector;

    #[test]
    fn test_logging_handler_mutation() {
        let mut rng = rand::thread_rng();
        let genome = RealVector::new(vec![1.0, 2.0, 3.0]);

        let handler = LoggingHandler::new();
        let selector = UniformMutationSelector::new(1.0); // Mutate all
        let mutation_fn = gaussian_mutation(0.1);

        let _mutated =
            handled_mutate_trace(&genome, &selector, mutation_fn, &handler, 0, &mut rng).unwrap();

        let records = handler.get_mutations();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].generation, 0);
        assert!(!records[0].mutated_addresses.is_empty());
    }

    #[test]
    fn test_logging_handler_crossover() {
        let mut rng = rand::thread_rng();
        let parent1 = RealVector::new(vec![1.0, 2.0, 3.0]);
        let parent2 = RealVector::new(vec![4.0, 5.0, 6.0]);

        let handler = LoggingHandler::new();
        let trace1 = parent1.to_trace();
        let mask = UniformCrossoverMask::balanced(&trace1, &mut rng);

        let (_child1, _child2) =
            handled_crossover_traces(&parent1, &parent2, &mask, &handler, 0, &mut rng).unwrap();

        let records = handler.get_crossovers();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].generation, 0);
        // Equal-length parents: every address is attributed to exactly one parent.
        assert_eq!(
            records[0].from_parent1.len() + records[0].from_parent2.len(),
            trace1.choices.len()
        );
    }

    #[test]
    fn test_rate_limiting_handler() {
        let mut rng = rand::thread_rng();
        let genome = RealVector::new(vec![1.0, 2.0, 3.0]);
        let selector = UniformMutationSelector::new(1.0);
        let mutation_fn = gaussian_mutation(0.1);

        // The logger only sees mutations the limiter lets through, so its
        // record count reveals which calls were actually skipped.
        let logging = LoggingHandler::new();
        let handler = ComposedMutationHandler::new()
            .add(RateLimitingHandler::new(2, 2))
            .add(logging.clone());

        // All three calls succeed, but only the first two mutate.
        for _ in 0..3 {
            let result =
                handled_mutate_trace(&genome, &selector, &mutation_fn, &handler, 0, &mut rng);
            assert!(result.is_ok());
        }
        assert_eq!(logging.get_mutations().len(), 2);

        // A new generation restores the budget.
        handled_mutate_trace(&genome, &selector, &mutation_fn, &handler, 1, &mut rng).unwrap();
        assert_eq!(logging.get_mutations().len(), 3);
    }

    #[test]
    fn test_rate_limiting_budget_per_generation() {
        let trace = RealVector::new(vec![1.0]).to_trace();
        let handler = RateLimitingHandler::new(1, 0);

        assert!(handler.before_mutation(&trace, 0));
        assert!(!handler.before_mutation(&trace, 0));
        assert!(handler.before_mutation(&trace, 1));
        assert!(!handler.before_crossover(&trace, &trace, 0));
    }

    #[test]
    fn test_conditional_handler_gates_by_generation() {
        let trace = RealVector::new(vec![1.0]).to_trace();
        let handler = ConditionalHandler::new(|generation| generation >= 1);

        assert!(!handler.before_mutation(&trace, 0));
        assert!(handler.before_mutation(&trace, 1));
        assert!(!handler.before_crossover(&trace, &trace, 0));
        assert!(handler.before_crossover(&trace, &trace, 2));
    }

    #[test]
    fn test_composed_handler() {
        let mut rng = rand::thread_rng();
        let genome = RealVector::new(vec![1.0, 2.0, 3.0]);

        let logging = LoggingHandler::new();
        let rate_limit = RateLimitingHandler::new(5, 5);

        let composed = ComposedMutationHandler::new()
            .add(logging.clone())
            .add(rate_limit);

        let selector = UniformMutationSelector::new(1.0);
        let mutation_fn = gaussian_mutation(0.1);

        let _ = handled_mutate_trace(&genome, &selector, mutation_fn, &composed, 0, &mut rng);

        // Logging handler should have recorded the mutation
        assert_eq!(logging.get_mutations().len(), 1);
    }

    #[test]
    fn test_operation_statistics() {
        let handler = LoggingHandler::new();

        // Manually add some records
        handler.mutations.lock().unwrap().push(MutationRecord {
            generation: 0,
            mutated_addresses: vec![addr!("test", 0), addr!("test", 1)],
            original_values: HashMap::new(),
            new_values: HashMap::new(),
        });

        let stats = OperationStatistics::from_handler(&handler);
        assert_eq!(stats.total_mutations, 1);
        assert_eq!(stats.avg_mutation_sites, 2.0);
    }
}