Skip to main content

fugue_evo/checkpoint/
state.rs

//! Checkpoint state structures.
//!
//! Defines the complete evolution state that is saved when checkpointing a run
//! and restored when recovering it.

5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8use crate::diagnostics::GenerationStats;
9use crate::genome::traits::EvolutionaryGenome;
10use crate::population::individual::Individual;
11
/// Format version stamped onto every checkpoint created by [`Checkpoint::new`].
///
/// [`Checkpoint::is_compatible`] reports any checkpoint whose version is
/// greater than this value as incompatible.
pub const CHECKPOINT_VERSION: u32 = 1;
14
15/// Complete evolution state for checkpointing
16#[derive(Clone, Debug, Serialize, Deserialize)]
17#[serde(bound = "")]
18pub struct Checkpoint<G>
19where
20    G: Clone + Serialize + EvolutionaryGenome,
21{
22    /// Schema version for forward compatibility
23    pub version: u32,
24    /// Current generation
25    pub generation: usize,
26    /// Total fitness evaluations
27    pub evaluations: usize,
28    /// Population with fitness values
29    pub population: Vec<Individual<G>>,
30    /// RNG state for reproducibility (serialized bytes)
31    pub rng_state: Option<Vec<u8>>,
32    /// Best individual found so far
33    pub best: Option<Individual<G>>,
34    /// Algorithm-specific state
35    pub algorithm_state: AlgorithmState,
36    /// Hyperparameter state if using adaptive learning
37    pub hyperparameter_state: Option<HyperparameterState>,
38    /// Statistics history
39    pub statistics: Vec<GenerationStats>,
40    /// Custom metadata
41    pub metadata: HashMap<String, String>,
42}
43
44impl<G> Checkpoint<G>
45where
46    G: Clone + Serialize + EvolutionaryGenome,
47{
48    /// Create a new checkpoint
49    pub fn new(generation: usize, population: Vec<Individual<G>>) -> Self {
50        Self {
51            version: CHECKPOINT_VERSION,
52            generation,
53            evaluations: 0,
54            population,
55            rng_state: None,
56            best: None,
57            algorithm_state: AlgorithmState::SimpleGA,
58            hyperparameter_state: None,
59            statistics: Vec::new(),
60            metadata: HashMap::new(),
61        }
62    }
63
64    /// Set the number of evaluations
65    pub fn with_evaluations(mut self, evaluations: usize) -> Self {
66        self.evaluations = evaluations;
67        self
68    }
69
70    /// Set the best individual
71    pub fn with_best(mut self, best: Individual<G>) -> Self {
72        self.best = Some(best);
73        self
74    }
75
76    /// Set algorithm-specific state
77    pub fn with_algorithm_state(mut self, state: AlgorithmState) -> Self {
78        self.algorithm_state = state;
79        self
80    }
81
82    /// Set hyperparameter state
83    pub fn with_hyperparameter_state(mut self, state: HyperparameterState) -> Self {
84        self.hyperparameter_state = Some(state);
85        self
86    }
87
88    /// Add statistics history
89    pub fn with_statistics(mut self, stats: Vec<GenerationStats>) -> Self {
90        self.statistics = stats;
91        self
92    }
93
94    /// Add custom metadata
95    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
96        self.metadata.insert(key.into(), value.into());
97        self
98    }
99
100    /// Set RNG state
101    pub fn with_rng_state(mut self, state: Vec<u8>) -> Self {
102        self.rng_state = Some(state);
103        self
104    }
105
106    /// Check if checkpoint is compatible with current version
107    pub fn is_compatible(&self) -> bool {
108        self.version <= CHECKPOINT_VERSION
109    }
110
111    /// Get the checkpoint version
112    pub fn version(&self) -> u32 {
113        self.version
114    }
115}
116
117/// Algorithm-specific state variants
118#[derive(Clone, Debug, Serialize, Deserialize)]
119pub enum AlgorithmState {
120    /// Simple generational GA (no additional state)
121    SimpleGA,
122    /// Steady-state GA
123    SteadyState { replacement_count: usize },
124    /// CMA-ES state
125    CmaEs(CmaEsCheckpointState),
126    /// NSGA-II state
127    Nsga2 { pareto_front_indices: Vec<usize> },
128    /// HBGA state
129    Hbga {
130        population_params: Vec<f64>,
131        temperature: f64,
132    },
133    /// Island model state
134    Island {
135        island_populations: Vec<Vec<usize>>,
136        migration_count: usize,
137    },
138    /// Interactive GA state
139    Interactive {
140        /// Serialized aggregator state (JSON)
141        aggregator_state: String,
142        /// Number of pending evaluations
143        pending_evaluations: usize,
144        /// Evaluation mode
145        evaluation_mode: String,
146    },
147    /// Custom algorithm state (JSON serialized)
148    Custom(String),
149}
150
151/// CMA-ES checkpoint state
152#[derive(Clone, Debug, Serialize, Deserialize)]
153pub struct CmaEsCheckpointState {
154    /// Mean vector
155    pub mean: Vec<f64>,
156    /// Global step size
157    pub sigma: f64,
158    /// Covariance matrix (flattened row-major)
159    pub covariance: Vec<f64>,
160    /// Evolution path for sigma
161    pub path_sigma: Vec<f64>,
162    /// Evolution path for covariance
163    pub path_c: Vec<f64>,
164    /// Dimension
165    pub dimension: usize,
166}
167
168/// Hyperparameter learning state
169#[derive(Clone, Debug, Serialize, Deserialize)]
170pub struct HyperparameterState {
171    /// Mutation rate posterior (alpha, beta for Beta distribution)
172    pub mutation_rate_posterior: Option<(f64, f64)>,
173    /// Crossover probability posterior
174    pub crossover_prob_posterior: Option<(f64, f64)>,
175    /// Selection temperature posterior (shape, rate for Gamma)
176    pub temperature_posterior: Option<(f64, f64)>,
177    /// Step size posteriors (mu, sigma_sq for LogNormal)
178    pub step_size_posteriors: Vec<(f64, f64)>,
179    /// Operator selection weights
180    pub operator_weights: Vec<f64>,
181    /// History window for learning
182    pub history_size: usize,
183}
184
185impl Default for HyperparameterState {
186    fn default() -> Self {
187        Self {
188            mutation_rate_posterior: None,
189            crossover_prob_posterior: None,
190            temperature_posterior: None,
191            step_size_posteriors: Vec::new(),
192            operator_weights: Vec::new(),
193            history_size: 100,
194        }
195    }
196}
197
198/// Builder for creating checkpoints
199pub struct CheckpointBuilder<G>
200where
201    G: Clone + Serialize + EvolutionaryGenome,
202{
203    checkpoint: Checkpoint<G>,
204}
205
206impl<G> CheckpointBuilder<G>
207where
208    G: Clone + Serialize + EvolutionaryGenome,
209{
210    /// Create a new checkpoint builder
211    pub fn new(generation: usize, population: Vec<Individual<G>>) -> Self {
212        Self {
213            checkpoint: Checkpoint::new(generation, population),
214        }
215    }
216
217    /// Set evaluations count
218    pub fn evaluations(mut self, count: usize) -> Self {
219        self.checkpoint.evaluations = count;
220        self
221    }
222
223    /// Set best individual
224    pub fn best(mut self, individual: Individual<G>) -> Self {
225        self.checkpoint.best = Some(individual);
226        self
227    }
228
229    /// Set algorithm state
230    pub fn algorithm_state(mut self, state: AlgorithmState) -> Self {
231        self.checkpoint.algorithm_state = state;
232        self
233    }
234
235    /// Set hyperparameter state
236    pub fn hyperparameters(mut self, state: HyperparameterState) -> Self {
237        self.checkpoint.hyperparameter_state = Some(state);
238        self
239    }
240
241    /// Set statistics
242    pub fn statistics(mut self, stats: Vec<GenerationStats>) -> Self {
243        self.checkpoint.statistics = stats;
244        self
245    }
246
247    /// Add metadata
248    pub fn metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
249        self.checkpoint.metadata.insert(key.into(), value.into());
250        self
251    }
252
253    /// Set RNG state
254    pub fn rng_state(mut self, state: Vec<u8>) -> Self {
255        self.checkpoint.rng_state = Some(state);
256        self
257    }
258
259    /// Build the checkpoint
260    pub fn build(self) -> Checkpoint<G> {
261        self.checkpoint
262    }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use crate::genome::real_vector::RealVector;

    #[test]
    fn test_checkpoint_creation() {
        let population: Vec<Individual<RealVector>> = vec![
            Individual::new(RealVector::new(vec![1.0, 2.0])),
            Individual::new(RealVector::new(vec![3.0, 4.0])),
        ];

        let checkpoint = Checkpoint::new(10, population);

        assert_eq!(checkpoint.version, CHECKPOINT_VERSION);
        assert_eq!(checkpoint.generation, 10);
        assert_eq!(checkpoint.population.len(), 2);
        // Everything not passed to `new` starts out empty.
        assert_eq!(checkpoint.evaluations, 0);
        assert!(checkpoint.best.is_none());
        assert!(checkpoint.rng_state.is_none());
        assert!(checkpoint.hyperparameter_state.is_none());
        assert!(checkpoint.statistics.is_empty());
        assert!(checkpoint.metadata.is_empty());
        assert!(matches!(checkpoint.algorithm_state, AlgorithmState::SimpleGA));
    }

    #[test]
    fn test_checkpoint_builder() {
        let population: Vec<Individual<RealVector>> =
            vec![Individual::new(RealVector::new(vec![1.0]))];

        let checkpoint = CheckpointBuilder::new(5, population)
            .evaluations(1000)
            .algorithm_state(AlgorithmState::SimpleGA)
            .metadata("experiment", "test_run")
            .build();

        assert_eq!(checkpoint.generation, 5);
        assert_eq!(checkpoint.evaluations, 1000);
        assert_eq!(
            checkpoint.metadata.get("experiment"),
            Some(&"test_run".to_string())
        );
    }

    #[test]
    fn test_checkpoint_compatibility() {
        let population: Vec<Individual<RealVector>> = vec![];
        let mut checkpoint = Checkpoint::new(0, population);

        assert!(checkpoint.is_compatible());

        // A checkpoint written by a newer format version must be rejected.
        checkpoint.version = CHECKPOINT_VERSION + 1;
        assert!(!checkpoint.is_compatible());
    }

    #[test]
    fn test_cmaes_checkpoint_state() {
        let state = CmaEsCheckpointState {
            mean: vec![0.0, 0.0],
            sigma: 1.0,
            covariance: vec![1.0, 0.0, 0.0, 1.0],
            path_sigma: vec![0.0, 0.0],
            path_c: vec![0.0, 0.0],
            dimension: 2,
        };

        let alg_state = AlgorithmState::CmaEs(state);
        if let AlgorithmState::CmaEs(s) = alg_state {
            assert_eq!(s.dimension, 2);
            assert_eq!(s.sigma, 1.0);
            assert_eq!(s.covariance.len(), s.dimension * s.dimension);
        } else {
            panic!("Expected CmaEs state");
        }
    }

    #[test]
    fn test_checkpoint_with_methods() {
        let population: Vec<Individual<RealVector>> =
            vec![Individual::new(RealVector::new(vec![1.0, 2.0]))];
        let best = Individual::with_fitness(RealVector::new(vec![0.5, 0.5]), 10.0);

        let checkpoint = Checkpoint::new(5, population)
            .with_evaluations(500)
            .with_best(best)
            .with_algorithm_state(AlgorithmState::SteadyState {
                replacement_count: 10,
            })
            .with_metadata("run_id", "test123")
            .with_rng_state(vec![1, 2, 3, 4]);

        assert_eq!(checkpoint.evaluations, 500);
        assert!(checkpoint.best.is_some());
        assert!(matches!(
            checkpoint.algorithm_state,
            AlgorithmState::SteadyState {
                replacement_count: 10
            }
        ));
        assert_eq!(
            checkpoint.metadata.get("run_id"),
            Some(&"test123".to_string())
        );
        assert_eq!(checkpoint.rng_state, Some(vec![1, 2, 3, 4]));
    }

    #[test]
    fn test_checkpoint_with_hyperparameters() {
        let population: Vec<Individual<RealVector>> = vec![];
        let hp_state = HyperparameterState {
            mutation_rate_posterior: Some((2.0, 8.0)),
            crossover_prob_posterior: Some((5.0, 5.0)),
            ..Default::default()
        };

        let checkpoint = Checkpoint::new(0, population).with_hyperparameter_state(hp_state);

        let hp = checkpoint
            .hyperparameter_state
            .expect("hyperparameter state should have been attached");
        assert_eq!(hp.mutation_rate_posterior, Some((2.0, 8.0)));
        assert_eq!(hp.crossover_prob_posterior, Some((5.0, 5.0)));
        // Fields not overridden keep their `Default` values.
        assert_eq!(hp.history_size, 100);
    }

    #[test]
    fn test_checkpoint_with_statistics() {
        use crate::diagnostics::{GenerationStats, TimingStats};

        let population: Vec<Individual<RealVector>> = vec![];
        let stats = vec![
            GenerationStats {
                generation: 0,
                evaluations: 100,
                best_fitness: 10.0,
                worst_fitness: 1.0,
                mean_fitness: 5.0,
                median_fitness: 5.0,
                fitness_std: 2.0,
                diversity: 0.5,
                timing: TimingStats::default(),
            },
            GenerationStats {
                generation: 1,
                evaluations: 200,
                best_fitness: 15.0,
                worst_fitness: 2.0,
                mean_fitness: 7.0,
                median_fitness: 7.0,
                fitness_std: 1.5,
                diversity: 0.4,
                timing: TimingStats::default(),
            },
        ];

        let checkpoint = Checkpoint::new(2, population).with_statistics(stats);

        assert_eq!(checkpoint.statistics.len(), 2);
        assert_eq!(checkpoint.statistics[0].generation, 0);
        assert_eq!(checkpoint.statistics[1].generation, 1);
        assert_eq!(checkpoint.statistics[1].evaluations, 200);

        // `with_statistics` replaces the history rather than appending to it.
        let checkpoint = checkpoint.with_statistics(Vec::new());
        assert!(checkpoint.statistics.is_empty());
    }

    #[test]
    fn test_checkpoint_version() {
        let population: Vec<Individual<RealVector>> = vec![];
        let checkpoint = Checkpoint::new(0, population);

        assert_eq!(checkpoint.version(), CHECKPOINT_VERSION);
    }

    #[test]
    fn test_checkpoint_builder_full() {
        use crate::diagnostics::{GenerationStats, TimingStats};

        let population: Vec<Individual<RealVector>> =
            vec![Individual::new(RealVector::new(vec![1.0]))];
        let best = Individual::with_fitness(RealVector::new(vec![0.0]), 100.0);
        let hp_state = HyperparameterState::default();
        let stats = vec![GenerationStats {
            generation: 0,
            evaluations: 100,
            best_fitness: 100.0,
            worst_fitness: 10.0,
            mean_fitness: 50.0,
            median_fitness: 50.0,
            fitness_std: 10.0,
            diversity: 0.5,
            timing: TimingStats::default(),
        }];

        let checkpoint = CheckpointBuilder::new(10, population)
            .evaluations(5000)
            .best(best)
            .algorithm_state(AlgorithmState::Nsga2 {
                pareto_front_indices: vec![0, 1, 2],
            })
            .hyperparameters(hp_state)
            .statistics(stats)
            .metadata("version", "1.0")
            .rng_state(vec![0, 1, 2, 3])
            .build();

        assert_eq!(checkpoint.generation, 10);
        assert_eq!(checkpoint.evaluations, 5000);
        assert!(checkpoint.best.is_some());
        assert!(checkpoint.hyperparameter_state.is_some());
        assert_eq!(checkpoint.statistics.len(), 1);
        assert!(matches!(
            &checkpoint.algorithm_state,
            AlgorithmState::Nsga2 { pareto_front_indices }
                if *pareto_front_indices == vec![0, 1, 2]
        ));
        assert_eq!(checkpoint.metadata.get("version"), Some(&"1.0".to_string()));
        assert_eq!(checkpoint.rng_state, Some(vec![0, 1, 2, 3]));
    }

    #[test]
    fn test_algorithm_state_variants() {
        // Test all algorithm state variants
        let simple_ga = AlgorithmState::SimpleGA;
        let steady_state = AlgorithmState::SteadyState {
            replacement_count: 5,
        };
        let nsga2 = AlgorithmState::Nsga2 {
            pareto_front_indices: vec![0, 1],
        };
        let hbga = AlgorithmState::Hbga {
            population_params: vec![1.0, 2.0],
            temperature: 1.5,
        };
        let island = AlgorithmState::Island {
            island_populations: vec![vec![0, 1], vec![2, 3]],
            migration_count: 3,
        };
        let interactive = AlgorithmState::Interactive {
            aggregator_state: "{}".to_string(),
            pending_evaluations: 10,
            evaluation_mode: "pairwise".to_string(),
        };
        let custom = AlgorithmState::Custom("custom_state".to_string());

        // Verify they're all different variants (pattern matching)
        assert!(matches!(simple_ga, AlgorithmState::SimpleGA));
        assert!(matches!(steady_state, AlgorithmState::SteadyState { .. }));
        assert!(matches!(nsga2, AlgorithmState::Nsga2 { .. }));
        assert!(matches!(hbga, AlgorithmState::Hbga { .. }));
        assert!(matches!(island, AlgorithmState::Island { .. }));
        assert!(matches!(interactive, AlgorithmState::Interactive { .. }));
        assert!(matches!(custom, AlgorithmState::Custom(_)));
    }

    #[test]
    fn test_hyperparameter_state_default() {
        let hp = HyperparameterState::default();

        assert!(hp.mutation_rate_posterior.is_none());
        assert!(hp.crossover_prob_posterior.is_none());
        assert!(hp.temperature_posterior.is_none());
        assert!(hp.step_size_posteriors.is_empty());
        assert!(hp.operator_weights.is_empty());
        assert_eq!(hp.history_size, 100);
    }
}