1use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8use crate::diagnostics::GenerationStats;
9use crate::genome::traits::EvolutionaryGenome;
10use crate::population::individual::Individual;
11
/// Current checkpoint format version.
///
/// [`Checkpoint::new`] stamps this value into every new checkpoint, and
/// [`Checkpoint::is_compatible`] rejects any checkpoint whose `version` is greater than it.
pub const CHECKPOINT_VERSION: u32 = 1;
14
/// Serializable snapshot of an evolutionary run's state at a particular generation.
///
/// Build one with [`Checkpoint::new`] plus the chained `with_*` methods, or with
/// [`CheckpointBuilder`]. Field names (and, for index-based serde formats, field order) are part
/// of the serialized representation.
#[derive(Clone, Debug, Serialize, Deserialize)]
// `bound = ""` disables serde's automatically inferred bounds on `G`; the generated impls still
// carry the `where` clause below. NOTE(review): `Deserialize` therefore relies on those bounds
// (presumably via `EvolutionaryGenome`) being enough for `Individual<G>: Deserialize` — confirm.
#[serde(bound = "")]
pub struct Checkpoint<G>
where
    G: Clone + Serialize + EvolutionaryGenome,
{
    /// Checkpoint format version; [`Checkpoint::new`] sets it to [`CHECKPOINT_VERSION`].
    pub version: u32,
    /// Generation number supplied to [`Checkpoint::new`].
    pub generation: usize,
    /// Evaluation count recorded for the run; `0` unless explicitly set.
    pub evaluations: usize,
    /// Individuals in the population at this generation.
    pub population: Vec<Individual<G>>,
    /// RNG state bytes, if captured. This module never interprets them; their encoding is
    /// whatever the producer chose.
    pub rng_state: Option<Vec<u8>>,
    /// Best individual recorded by the caller, if any.
    pub best: Option<Individual<G>>,
    /// Algorithm-specific state; [`AlgorithmState::SimpleGA`] after [`Checkpoint::new`].
    pub algorithm_state: AlgorithmState,
    /// Adaptive hyperparameter state, if one was recorded.
    pub hyperparameter_state: Option<HyperparameterState>,
    /// Sequence of per-generation statistics records.
    pub statistics: Vec<GenerationStats>,
    /// Free-form string key/value annotations.
    pub metadata: HashMap<String, String>,
}
43
44impl<G> Checkpoint<G>
45where
46 G: Clone + Serialize + EvolutionaryGenome,
47{
48 pub fn new(generation: usize, population: Vec<Individual<G>>) -> Self {
50 Self {
51 version: CHECKPOINT_VERSION,
52 generation,
53 evaluations: 0,
54 population,
55 rng_state: None,
56 best: None,
57 algorithm_state: AlgorithmState::SimpleGA,
58 hyperparameter_state: None,
59 statistics: Vec::new(),
60 metadata: HashMap::new(),
61 }
62 }
63
64 pub fn with_evaluations(mut self, evaluations: usize) -> Self {
66 self.evaluations = evaluations;
67 self
68 }
69
70 pub fn with_best(mut self, best: Individual<G>) -> Self {
72 self.best = Some(best);
73 self
74 }
75
76 pub fn with_algorithm_state(mut self, state: AlgorithmState) -> Self {
78 self.algorithm_state = state;
79 self
80 }
81
82 pub fn with_hyperparameter_state(mut self, state: HyperparameterState) -> Self {
84 self.hyperparameter_state = Some(state);
85 self
86 }
87
88 pub fn with_statistics(mut self, stats: Vec<GenerationStats>) -> Self {
90 self.statistics = stats;
91 self
92 }
93
94 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
96 self.metadata.insert(key.into(), value.into());
97 self
98 }
99
100 pub fn with_rng_state(mut self, state: Vec<u8>) -> Self {
102 self.rng_state = Some(state);
103 self
104 }
105
106 pub fn is_compatible(&self) -> bool {
108 self.version <= CHECKPOINT_VERSION
109 }
110
111 pub fn version(&self) -> u32 {
113 self.version
114 }
115}
116
/// Algorithm-specific state stored in a [`Checkpoint`].
///
/// Each variant carries the extra data one algorithm family keeps beyond the shared
/// population/RNG/statistics fields of the checkpoint. Variant names (and, for index-based serde
/// formats, declaration order) are part of the serialized form, so renaming or reordering
/// variants breaks previously written checkpoints.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum AlgorithmState {
    /// No extra state. This is the value [`Checkpoint::new`] starts with.
    SimpleGA,
    /// Steady-state GA. NOTE(review): `replacement_count` presumably counts individuals replaced
    /// per step — confirm against the algorithm implementation.
    SteadyState { replacement_count: usize },
    /// CMA-ES distribution parameters.
    CmaEs(CmaEsCheckpointState),
    /// NSGA-II. NOTE(review): `pareto_front_indices` presumably index into
    /// `Checkpoint::population` — confirm.
    Nsga2 { pareto_front_indices: Vec<usize> },
    /// HBGA state; the meaning of `population_params` and `temperature` is defined by the HBGA
    /// implementation, not in this module.
    Hbga {
        population_params: Vec<f64>,
        temperature: f64,
    },
    /// Island model: one index list per island plus a migration counter.
    /// NOTE(review): the indices presumably refer to `Checkpoint::population` — confirm.
    Island {
        island_populations: Vec<Vec<usize>>,
        migration_count: usize,
    },
    /// Interactive evolution state.
    Interactive {
        /// Aggregator state stored as a string; its encoding is chosen by the producer.
        aggregator_state: String,
        /// NOTE(review): presumably the number of evaluations still outstanding — confirm.
        pending_evaluations: usize,
        /// Evaluation mode name as free text (the tests in this file use `"pairwise"`).
        evaluation_mode: String,
    },
    /// Free-form state for algorithms not covered by the other variants.
    Custom(String),
}
150
/// Snapshot of a CMA-ES search distribution, stored inside [`AlgorithmState::CmaEs`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CmaEsCheckpointState {
    /// Distribution mean; presumably `dimension` entries long — confirm.
    pub mean: Vec<f64>,
    /// Presumably the CMA-ES global step size.
    pub sigma: f64,
    /// Covariance matrix flattened into a single vector.
    /// NOTE(review): presumably `dimension * dimension` entries (the tests in this file use 4
    /// entries for dimension 2); confirm the element layout with the CMA-ES implementation.
    pub covariance: Vec<f64>,
    /// Presumably the step-size (sigma) evolution path.
    pub path_sigma: Vec<f64>,
    /// Presumably the covariance evolution path.
    pub path_c: Vec<f64>,
    /// Dimensionality of the search space.
    pub dimension: usize,
}
167
/// Adaptive hyperparameter state saved alongside a [`Checkpoint`].
///
/// NOTE(review): the `(f64, f64)` pairs look like parameters of two-parameter posterior
/// distributions (e.g. Beta `(alpha, beta)`), but this module does not define them — confirm
/// with the adaptation code that produces them.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct HyperparameterState {
    /// Posterior parameters for the mutation rate, if recorded.
    pub mutation_rate_posterior: Option<(f64, f64)>,
    /// Posterior parameters for the crossover probability, if recorded.
    pub crossover_prob_posterior: Option<(f64, f64)>,
    /// Posterior parameters for the temperature, if recorded.
    pub temperature_posterior: Option<(f64, f64)>,
    /// Posterior parameters for step sizes, one pair per entry; what each entry indexes is
    /// defined by the producer.
    pub step_size_posteriors: Vec<(f64, f64)>,
    /// Weights associated with operators; semantics are defined by the producer.
    pub operator_weights: Vec<f64>,
    /// History size; `HyperparameterState::default()` sets it to 100.
    pub history_size: usize,
}
184
185impl Default for HyperparameterState {
186 fn default() -> Self {
187 Self {
188 mutation_rate_posterior: None,
189 crossover_prob_posterior: None,
190 temperature_posterior: None,
191 step_size_posteriors: Vec::new(),
192 operator_weights: Vec::new(),
193 history_size: 100,
194 }
195 }
196}
197
198pub struct CheckpointBuilder<G>
200where
201 G: Clone + Serialize + EvolutionaryGenome,
202{
203 checkpoint: Checkpoint<G>,
204}
205
206impl<G> CheckpointBuilder<G>
207where
208 G: Clone + Serialize + EvolutionaryGenome,
209{
210 pub fn new(generation: usize, population: Vec<Individual<G>>) -> Self {
212 Self {
213 checkpoint: Checkpoint::new(generation, population),
214 }
215 }
216
217 pub fn evaluations(mut self, count: usize) -> Self {
219 self.checkpoint.evaluations = count;
220 self
221 }
222
223 pub fn best(mut self, individual: Individual<G>) -> Self {
225 self.checkpoint.best = Some(individual);
226 self
227 }
228
229 pub fn algorithm_state(mut self, state: AlgorithmState) -> Self {
231 self.checkpoint.algorithm_state = state;
232 self
233 }
234
235 pub fn hyperparameters(mut self, state: HyperparameterState) -> Self {
237 self.checkpoint.hyperparameter_state = Some(state);
238 self
239 }
240
241 pub fn statistics(mut self, stats: Vec<GenerationStats>) -> Self {
243 self.checkpoint.statistics = stats;
244 self
245 }
246
247 pub fn metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
249 self.checkpoint.metadata.insert(key.into(), value.into());
250 self
251 }
252
253 pub fn rng_state(mut self, state: Vec<u8>) -> Self {
255 self.checkpoint.rng_state = Some(state);
256 self
257 }
258
259 pub fn build(self) -> Checkpoint<G> {
261 self.checkpoint
262 }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use crate::genome::real_vector::RealVector;

    /// `Checkpoint::new` stores its arguments and initializes every other field to its default.
    #[test]
    fn test_checkpoint_creation() {
        let population: Vec<Individual<RealVector>> = vec![
            Individual::new(RealVector::new(vec![1.0, 2.0])),
            Individual::new(RealVector::new(vec![3.0, 4.0])),
        ];

        let checkpoint = Checkpoint::new(10, population);

        assert_eq!(checkpoint.version, CHECKPOINT_VERSION);
        assert_eq!(checkpoint.generation, 10);
        assert_eq!(checkpoint.evaluations, 0);
        assert_eq!(checkpoint.population.len(), 2);
        assert!(checkpoint.rng_state.is_none());
        assert!(checkpoint.best.is_none());
        assert!(matches!(checkpoint.algorithm_state, AlgorithmState::SimpleGA));
        assert!(checkpoint.hyperparameter_state.is_none());
        assert!(checkpoint.statistics.is_empty());
        assert!(checkpoint.metadata.is_empty());
    }

    #[test]
    fn test_checkpoint_builder() {
        let population: Vec<Individual<RealVector>> =
            vec![Individual::new(RealVector::new(vec![1.0]))];

        let checkpoint = CheckpointBuilder::new(5, population)
            .evaluations(1000)
            .algorithm_state(AlgorithmState::SimpleGA)
            .metadata("experiment", "test_run")
            .build();

        assert_eq!(checkpoint.generation, 5);
        assert_eq!(checkpoint.evaluations, 1000);
        assert_eq!(
            checkpoint.metadata.get("experiment"),
            Some(&"test_run".to_string())
        );
    }

    #[test]
    fn test_checkpoint_compatibility() {
        let population: Vec<Individual<RealVector>> = vec![];
        let checkpoint = Checkpoint::new(0, population);

        assert!(checkpoint.is_compatible());
    }

    /// A checkpoint written by a newer format version must be rejected.
    #[test]
    fn test_checkpoint_from_newer_version_is_incompatible() {
        let population: Vec<Individual<RealVector>> = vec![];
        let mut checkpoint = Checkpoint::new(0, population);
        checkpoint.version = CHECKPOINT_VERSION + 1;

        assert!(!checkpoint.is_compatible());
        assert_eq!(checkpoint.version(), CHECKPOINT_VERSION + 1);
    }

    #[test]
    fn test_cmaes_checkpoint_state() {
        let state = CmaEsCheckpointState {
            mean: vec![0.0, 0.0],
            sigma: 1.0,
            covariance: vec![1.0, 0.0, 0.0, 1.0],
            path_sigma: vec![0.0, 0.0],
            path_c: vec![0.0, 0.0],
            dimension: 2,
        };

        let alg_state = AlgorithmState::CmaEs(state);
        if let AlgorithmState::CmaEs(s) = alg_state {
            assert_eq!(s.dimension, 2);
            assert_eq!(s.sigma, 1.0);
        } else {
            panic!("Expected CmaEs state");
        }
    }

    #[test]
    fn test_checkpoint_with_methods() {
        let population: Vec<Individual<RealVector>> =
            vec![Individual::new(RealVector::new(vec![1.0, 2.0]))];
        let best = Individual::with_fitness(RealVector::new(vec![0.5, 0.5]), 10.0);

        let checkpoint = Checkpoint::new(5, population)
            .with_evaluations(500)
            .with_best(best)
            .with_algorithm_state(AlgorithmState::SteadyState {
                replacement_count: 10,
            })
            .with_metadata("run_id", "test123")
            .with_rng_state(vec![1, 2, 3, 4]);

        assert_eq!(checkpoint.evaluations, 500);
        assert!(checkpoint.best.is_some());
        assert!(matches!(
            checkpoint.algorithm_state,
            AlgorithmState::SteadyState {
                replacement_count: 10
            }
        ));
        assert_eq!(
            checkpoint.metadata.get("run_id"),
            Some(&"test123".to_string())
        );
        assert_eq!(checkpoint.rng_state, Some(vec![1, 2, 3, 4]));
    }

    /// Setting the same metadata key twice keeps only the last value.
    #[test]
    fn test_metadata_last_write_wins() {
        let population: Vec<Individual<RealVector>> = vec![];
        let checkpoint = Checkpoint::new(0, population)
            .with_metadata("run_id", "first")
            .with_metadata("run_id", "second");

        assert_eq!(checkpoint.metadata.len(), 1);
        assert_eq!(
            checkpoint.metadata.get("run_id").map(String::as_str),
            Some("second")
        );
    }

    #[test]
    fn test_checkpoint_with_hyperparameters() {
        let population: Vec<Individual<RealVector>> = vec![];
        let hp_state = HyperparameterState {
            mutation_rate_posterior: Some((2.0, 8.0)),
            crossover_prob_posterior: Some((5.0, 5.0)),
            ..Default::default()
        };

        let checkpoint = Checkpoint::new(0, population).with_hyperparameter_state(hp_state);

        let hp = checkpoint
            .hyperparameter_state
            .expect("hyperparameter state should have been set");
        assert_eq!(hp.mutation_rate_posterior, Some((2.0, 8.0)));
        assert_eq!(hp.crossover_prob_posterior, Some((5.0, 5.0)));
        assert_eq!(hp.history_size, 100);
    }

    #[test]
    fn test_checkpoint_with_statistics() {
        use crate::diagnostics::{GenerationStats, TimingStats};

        let population: Vec<Individual<RealVector>> = vec![];
        let stats = vec![
            GenerationStats {
                generation: 0,
                evaluations: 100,
                best_fitness: 10.0,
                worst_fitness: 1.0,
                mean_fitness: 5.0,
                median_fitness: 5.0,
                fitness_std: 2.0,
                diversity: 0.5,
                timing: TimingStats::default(),
            },
            GenerationStats {
                generation: 1,
                evaluations: 200,
                best_fitness: 15.0,
                worst_fitness: 2.0,
                mean_fitness: 7.0,
                median_fitness: 7.0,
                fitness_std: 1.5,
                diversity: 0.4,
                timing: TimingStats::default(),
            },
        ];

        let checkpoint = Checkpoint::new(2, population).with_statistics(stats);

        // The history is stored as given, preserving order.
        assert_eq!(checkpoint.statistics.len(), 2);
        assert_eq!(checkpoint.statistics[0].generation, 0);
        assert_eq!(checkpoint.statistics[1].generation, 1);
    }

    #[test]
    fn test_checkpoint_version() {
        let population: Vec<Individual<RealVector>> = vec![];
        let checkpoint = Checkpoint::new(0, population);

        assert_eq!(checkpoint.version(), CHECKPOINT_VERSION);
    }

    #[test]
    fn test_checkpoint_builder_full() {
        use crate::diagnostics::{GenerationStats, TimingStats};

        let population: Vec<Individual<RealVector>> =
            vec![Individual::new(RealVector::new(vec![1.0]))];
        let best = Individual::with_fitness(RealVector::new(vec![0.0]), 100.0);
        let hp_state = HyperparameterState::default();
        let stats = vec![GenerationStats {
            generation: 0,
            evaluations: 100,
            best_fitness: 100.0,
            worst_fitness: 10.0,
            mean_fitness: 50.0,
            median_fitness: 50.0,
            fitness_std: 10.0,
            diversity: 0.5,
            timing: TimingStats::default(),
        }];

        let checkpoint = CheckpointBuilder::new(10, population)
            .evaluations(5000)
            .best(best)
            .algorithm_state(AlgorithmState::Nsga2 {
                pareto_front_indices: vec![0, 1, 2],
            })
            .hyperparameters(hp_state)
            .statistics(stats)
            .metadata("version", "1.0")
            .rng_state(vec![0, 1, 2, 3])
            .build();

        assert_eq!(checkpoint.generation, 10);
        assert_eq!(checkpoint.evaluations, 5000);
        assert!(checkpoint.best.is_some());
        match &checkpoint.algorithm_state {
            AlgorithmState::Nsga2 {
                pareto_front_indices,
            } => assert_eq!(*pareto_front_indices, vec![0, 1, 2]),
            other => panic!("Expected Nsga2 state, got {:?}", other),
        }
        assert!(checkpoint.hyperparameter_state.is_some());
        assert_eq!(checkpoint.statistics.len(), 1);
        assert_eq!(
            checkpoint.metadata.get("version"),
            Some(&"1.0".to_string())
        );
        assert_eq!(checkpoint.rng_state, Some(vec![0, 1, 2, 3]));
    }

    #[test]
    fn test_algorithm_state_variants() {
        let simple_ga = AlgorithmState::SimpleGA;
        let steady_state = AlgorithmState::SteadyState {
            replacement_count: 5,
        };
        let nsga2 = AlgorithmState::Nsga2 {
            pareto_front_indices: vec![0, 1],
        };
        let hbga = AlgorithmState::Hbga {
            population_params: vec![1.0, 2.0],
            temperature: 1.5,
        };
        let island = AlgorithmState::Island {
            island_populations: vec![vec![0, 1], vec![2, 3]],
            migration_count: 3,
        };
        let interactive = AlgorithmState::Interactive {
            aggregator_state: "{}".to_string(),
            pending_evaluations: 10,
            evaluation_mode: "pairwise".to_string(),
        };
        let custom = AlgorithmState::Custom("custom_state".to_string());

        assert!(matches!(simple_ga, AlgorithmState::SimpleGA));
        assert!(matches!(steady_state, AlgorithmState::SteadyState { .. }));
        assert!(matches!(nsga2, AlgorithmState::Nsga2 { .. }));
        assert!(matches!(hbga, AlgorithmState::Hbga { .. }));
        assert!(matches!(island, AlgorithmState::Island { .. }));
        assert!(matches!(interactive, AlgorithmState::Interactive { .. }));
        assert!(matches!(custom, AlgorithmState::Custom(_)));
    }

    #[test]
    fn test_hyperparameter_state_default() {
        let hp = HyperparameterState::default();

        assert!(hp.mutation_rate_posterior.is_none());
        assert!(hp.crossover_prob_posterior.is_none());
        assert!(hp.temperature_posterior.is_none());
        assert!(hp.step_size_posteriors.is_empty());
        assert!(hp.operator_weights.is_empty());
        assert_eq!(hp.history_size, 100);
    }
}