1use serde::{de::DeserializeOwned, Serialize};
6
7#[cfg(feature = "checkpoint")]
8use std::fs::File;
9#[cfg(feature = "checkpoint")]
10use std::io::{BufReader, BufWriter, Read, Write};
11#[cfg(feature = "checkpoint")]
12use std::path::Path;
13
14use super::state::{Checkpoint, CHECKPOINT_VERSION};
15use crate::error::CheckpointError;
16
/// On-disk encodings understood by the checkpoint save/load routines.
///
/// The default is [`CheckpointFormat::Binary`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum CheckpointFormat {
    /// Pretty-printed JSON produced by `serde_json`; written without any header.
    Json,
    /// `bincode` payload preceded by the little-endian checkpoint version and
    /// the `FEVO` magic bytes.
    #[default]
    Binary,
    /// `bincode` payload that is run-length encoded before being written,
    /// preceded by the little-endian checkpoint version, the `FEVC` magic
    /// bytes and the compressed length.
    CompressedBinary,
}
33
/// Writes `checkpoint` to the file at `path`, encoded as `format`.
///
/// The file is created (or truncated) and written through a buffered writer
/// that is flushed explicitly before returning.
///
/// Layout per format:
/// * `Json` — pretty-printed JSON, no header.
/// * `Binary` — little-endian `CHECKPOINT_VERSION`, the magic `FEVO`, then the
///   `bincode` payload streamed straight into the file.
/// * `CompressedBinary` — little-endian `CHECKPOINT_VERSION`, the magic `FEVC`,
///   the compressed length as a little-endian `u64`, then the run-length
///   encoded `bincode` payload.
///
/// # Errors
///
/// I/O failures are propagated through `?`; any serializer failure is reported
/// as `CheckpointError::Serialization` carrying the serializer's message.
#[cfg(feature = "checkpoint")]
pub fn save_checkpoint<G>(
    checkpoint: &Checkpoint<G>,
    path: impl AsRef<Path>,
    format: CheckpointFormat,
) -> Result<(), CheckpointError>
where
    G: Clone + Serialize + crate::genome::traits::EvolutionaryGenome,
{
    // Shared adapter: both serde_json and bincode errors are reported by message.
    fn encode_failure<E: std::fmt::Display>(err: E) -> CheckpointError {
        CheckpointError::Serialization(err.to_string())
    }

    let mut out = BufWriter::new(File::create(path.as_ref())?);
    let version_tag = CHECKPOINT_VERSION.to_le_bytes();

    match format {
        CheckpointFormat::Json => {
            serde_json::to_writer_pretty(&mut out, checkpoint).map_err(encode_failure)?;
        }
        CheckpointFormat::Binary => {
            out.write_all(&version_tag)?;
            out.write_all(b"FEVO")?;
            bincode::serialize_into(&mut out, checkpoint).map_err(encode_failure)?;
        }
        CheckpointFormat::CompressedBinary => {
            // The header goes out before the payload is serialized, exactly as
            // the reader expects: version, magic, length, data.
            out.write_all(&version_tag)?;
            out.write_all(b"FEVC")?;
            let encoded = bincode::serialize(checkpoint).map_err(encode_failure)?;
            let packed = compress_data(&encoded);
            let packed_len = packed.len() as u64;
            out.write_all(&packed_len.to_le_bytes())?;
            out.write_all(&packed)?;
        }
    }

    // Flush explicitly: dropping a BufWriter would swallow a late write error.
    out.flush()?;
    Ok(())
}
80
/// Reads a checkpoint from `path`, detecting the on-disk format automatically.
///
/// The first eight bytes are inspected: bytes 0..4 are a little-endian `u32`
/// version and bytes 4..8 a magic tag.
///
/// * `FEVO` — the remainder of the file is a `bincode` payload.
/// * `FEVC` — a little-endian `u64` length follows, then that many bytes of
///   run-length encoded `bincode` payload.
/// * anything else — the file is reopened and parsed as JSON from the start.
///
/// # Errors
///
/// * `CheckpointError::NotFound` when `path` does not exist.
/// * `CheckpointError::VersionMismatch` when a binary header declares a
///   version newer than `CHECKPOINT_VERSION`.
/// * `CheckpointError::Corrupted` when the compressed payload cannot be
///   decoded.
/// * `CheckpointError::Deserialization` when `bincode` or `serde_json`
///   rejects the payload.
/// * An I/O error (converted through `?`) for any read failure, including a
///   file shorter than the eight-byte header or a compressed payload shorter
///   than its declared length.
#[cfg(feature = "checkpoint")]
pub fn load_checkpoint<G>(path: impl AsRef<Path>) -> Result<Checkpoint<G>, CheckpointError>
where
    G: Clone + Serialize + DeserializeOwned + crate::genome::traits::EvolutionaryGenome,
{
    let path = path.as_ref();
    if !path.exists() {
        return Err(CheckpointError::NotFound(path.display().to_string()));
    }

    let mut reader = BufReader::new(File::open(path)?);

    let mut header = [0u8; 8];
    reader.read_exact(&mut header)?;
    let [v0, v1, v2, v3, magic @ ..] = header;
    let version = u32::from_le_bytes([v0, v1, v2, v3]);

    // Only meaningful for the binary formats; JSON files carry no header, so
    // the check is invoked from the binary arms only.
    let ensure_supported = || -> Result<(), CheckpointError> {
        if version > CHECKPOINT_VERSION {
            Err(CheckpointError::VersionMismatch {
                expected: CHECKPOINT_VERSION,
                found: version,
            })
        } else {
            Ok(())
        }
    };

    match &magic {
        b"FEVO" => {
            ensure_supported()?;
            bincode::deserialize_from(&mut reader)
                .map_err(|e| CheckpointError::Deserialization(e.to_string()))
        }
        b"FEVC" => {
            ensure_supported()?;

            let mut len_bytes = [0u8; 8];
            reader.read_exact(&mut len_bytes)?;
            let declared_len = u64::from_le_bytes(len_bytes);

            // The length field comes from the file and cannot be trusted:
            // read through `take` so memory use is bounded by the bytes that
            // are actually present instead of pre-allocating `declared_len`.
            let mut compressed = Vec::new();
            reader
                .by_ref()
                .take(declared_len)
                .read_to_end(&mut compressed)?;
            if compressed.len() as u64 != declared_len {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "compressed checkpoint payload is shorter than its declared length",
                )
                .into());
            }

            let decompressed =
                decompress_data(&compressed).map_err(CheckpointError::Corrupted)?;

            bincode::deserialize(&decompressed)
                .map_err(|e| CheckpointError::Deserialization(e.to_string()))
        }
        _ => {
            // No known magic: treat the file as JSON. The eight header bytes
            // were already consumed, so start over from a fresh handle.
            drop(reader);
            let reader = BufReader::new(File::open(path)?);

            serde_json::from_reader(reader)
                .map_err(|e| CheckpointError::Deserialization(e.to_string()))
        }
    }
}
146
147#[cfg(feature = "checkpoint")]
/// Run-length encodes `data`, using `0xFF` as the escape marker.
///
/// The output is a sequence of tokens:
/// * `0xFF, count, byte` — `byte` repeated `count` times (`1..=255`);
/// * any other single byte — a literal.
///
/// Runs of four or more equal bytes are collapsed into an escape token.
/// Every `0xFF` input byte is emitted as an escape token regardless of its
/// run length, so a literal can never be mistaken for the marker. Runs longer
/// than 255 bytes are split into several tokens. Empty input yields an empty
/// vector.
fn compress_data(data: &[u8]) -> Vec<u8> {
    const ESCAPE: u8 = 0xFF;
    const MIN_RUN: usize = 4;
    const MAX_RUN: usize = u8::MAX as usize;

    let mut compressed = Vec::with_capacity(data.len());
    let mut rest = data;

    while let Some(&byte) = rest.first() {
        // Length of the run starting at `rest[0]`, capped so it fits in a u8.
        let run = rest
            .iter()
            .take(MAX_RUN)
            .take_while(|&&b| b == byte)
            .count();

        if run >= MIN_RUN || byte == ESCAPE {
            // `run` is in 1..=255 here, so the cast cannot truncate.
            compressed.extend_from_slice(&[ESCAPE, run as u8, byte]);
        } else {
            // Short run of a non-marker byte: cheaper to copy verbatim.
            compressed.extend_from_slice(&rest[..run]);
        }

        rest = &rest[run..];
    }

    compressed
}
190
191#[cfg(feature = "checkpoint")]
/// Expands a buffer produced by the run-length encoder.
///
/// A `0xFF` byte introduces a three-byte token `0xFF, count, byte` that
/// expands to `byte` repeated `count` times; every other byte is copied
/// through unchanged.
///
/// # Errors
///
/// Returns `"Truncated RLE sequence"` when a `0xFF` marker is not followed by
/// both a count and a value byte.
fn decompress_data(data: &[u8]) -> Result<Vec<u8>, String> {
    let mut output = Vec::new();
    let mut cursor = data;

    loop {
        match cursor {
            [] => break,
            [0xFF, count, byte, tail @ ..] => {
                // Append `count` copies of `byte` in one step.
                let grown = output.len() + usize::from(*count);
                output.resize(grown, *byte);
                cursor = tail;
            }
            // A marker with fewer than two bytes after it.
            [0xFF, ..] => return Err("Truncated RLE sequence".to_string()),
            [literal, tail @ ..] => {
                output.push(*literal);
                cursor = tail;
            }
        }
    }

    Ok(output)
}
216
/// Writes a rolling series of numbered checkpoint files into one directory.
///
/// Files are named `{base_name}_{index:04}.{ext}`, where `ext` is `json` for
/// the JSON format and `ckpt` for both binary formats.
#[cfg(feature = "checkpoint")]
#[derive(Debug)]
pub struct CheckpointManager {
    /// Directory the checkpoint files are written to; created on `save`.
    pub directory: std::path::PathBuf,
    /// Prefix of every checkpoint file name.
    pub base_name: String,
    /// Encoding used when saving.
    pub format: CheckpointFormat,
    /// How far behind the newest file `save` reaches when pruning: after each
    /// save, the file `keep_n` positions older than the one just written is
    /// removed.
    pub keep_n: usize,
    /// Generation period consulted by `should_save`.
    pub interval: usize,
    /// Sequence number used for the next file name; starts at zero.
    current_index: usize,
}
233
#[cfg(feature = "checkpoint")]
impl CheckpointManager {
    /// Creates a manager writing `{base_name}_NNNN.*` files into `directory`.
    ///
    /// Defaults: binary format, prune files three positions behind the newest
    /// (`keep_n = 3`), save every 100 generations.
    ///
    /// The sequence counter always starts at zero.
    /// NOTE(review): a manager created over a directory that already holds
    /// checkpoints will overwrite the low indices while `load_latest` prefers
    /// the highest index on disk — confirm whether resuming should seed the
    /// counter from existing files.
    pub fn new(directory: impl Into<std::path::PathBuf>, base_name: impl Into<String>) -> Self {
        Self {
            directory: directory.into(),
            base_name: base_name.into(),
            format: CheckpointFormat::Binary,
            keep_n: 3,
            interval: 100,
            current_index: 0,
        }
    }

    /// Sets the encoding used by subsequent saves.
    pub fn with_format(mut self, format: CheckpointFormat) -> Self {
        self.format = format;
        self
    }

    /// Sets the pruning distance `keep_n` (see [`CheckpointManager::save`]).
    pub fn keep(mut self, n: usize) -> Self {
        self.keep_n = n;
        self
    }

    /// Sets the number of generations between saves.
    pub fn every(mut self, generations: usize) -> Self {
        self.interval = generations;
        self
    }

    /// Returns `true` when `generation` is a positive multiple of `interval`.
    ///
    /// Generation zero never triggers a save, and an interval of zero never
    /// matches any positive generation.
    pub fn should_save(&self, generation: usize) -> bool {
        generation > 0 && generation.is_multiple_of(self.interval)
    }

    /// Path the next call to [`CheckpointManager::save`] will write to.
    pub fn current_path(&self) -> std::path::PathBuf {
        self.path_for(self.current_index)
    }

    /// Saves `checkpoint` under the next sequence number and prunes one old
    /// file.
    ///
    /// The target directory is created if needed. After a successful write the
    /// counter advances; once it exceeds `keep_n`, the file written
    /// `keep_n + 1` saves ago is deleted. Deletion is best-effort: a failure
    /// to remove the old file is ignored.
    ///
    /// # Errors
    ///
    /// Propagates directory-creation failures and any error from
    /// `save_checkpoint`.
    pub fn save<G>(&mut self, checkpoint: &Checkpoint<G>) -> Result<(), CheckpointError>
    where
        G: Clone + Serialize + crate::genome::traits::EvolutionaryGenome,
    {
        std::fs::create_dir_all(&self.directory)?;

        save_checkpoint(checkpoint, self.current_path(), self.format)?;

        self.current_index += 1;
        if self.current_index > self.keep_n {
            let stale_index = self.current_index - self.keep_n - 1;
            // Best effort: the file may already be gone.
            let _ = std::fs::remove_file(self.path_for(stale_index));
        }

        Ok(())
    }

    /// Loads the checkpoint with the highest sequence number in `directory`.
    ///
    /// Only files named `{base_name}_{digits}.json` or
    /// `{base_name}_{digits}.ckpt` are considered (the loader detects the
    /// encoding from the file contents, so both extensions are accepted).
    /// Candidates are tried from the highest index downwards; a file that
    /// fails to load is skipped in favour of the next older one.
    ///
    /// Returns `Ok(None)` when no candidate exists or none can be loaded.
    ///
    /// # Errors
    ///
    /// Returns an I/O error when `directory` cannot be read (for example when
    /// it does not exist).
    pub fn load_latest<G>(&self) -> Result<Option<Checkpoint<G>>, CheckpointError>
    where
        G: Clone + Serialize + DeserializeOwned + crate::genome::traits::EvolutionaryGenome,
    {
        let prefix = format!("{}_", self.base_name);

        let mut candidates: Vec<(usize, std::path::PathBuf)> =
            std::fs::read_dir(&self.directory)?
                .filter_map(Result::ok)
                .filter_map(|entry| {
                    let file_name = entry.file_name();
                    let index = Self::parse_index(&file_name.to_string_lossy(), &prefix)?;
                    Some((index, entry.path()))
                })
                .collect();

        // Newest first, ordered numerically so that index 10000 ranks above
        // 9999 (a plain string comparison would not).
        candidates.sort_unstable_by(|a, b| b.cmp(a));

        Ok(candidates
            .into_iter()
            .find_map(|(_, path)| load_checkpoint(path).ok()))
    }

    /// File extension associated with the configured format.
    fn extension(&self) -> &'static str {
        match self.format {
            CheckpointFormat::Json => "json",
            CheckpointFormat::Binary | CheckpointFormat::CompressedBinary => "ckpt",
        }
    }

    /// Full path of the checkpoint file carrying sequence number `index`.
    fn path_for(&self, index: usize) -> std::path::PathBuf {
        self.directory.join(format!(
            "{}_{:04}.{}",
            self.base_name,
            index,
            self.extension()
        ))
    }

    /// Extracts the sequence number from `{prefix}{digits}.{json|ckpt}`, or
    /// returns `None` when `file_name` does not have that exact shape.
    fn parse_index(file_name: &str, prefix: &str) -> Option<usize> {
        let rest = file_name.strip_prefix(prefix)?;
        let (digits, extension) = rest.rsplit_once('.')?;
        if !matches!(extension, "json" | "ckpt") {
            return None;
        }
        // Reject signs and other characters that `parse` would tolerate.
        if digits.is_empty() || !digits.bytes().all(|b| b.is_ascii_digit()) {
            return None;
        }
        digits.parse().ok()
    }
}
353
// Unit tests for checkpoint persistence. Tests that touch items gated behind
// the `checkpoint` feature are gated the same way, so the suite still builds
// when the feature is disabled.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::genome::real_vector::RealVector;
    use crate::population::individual::Individual;
    #[cfg(feature = "checkpoint")]
    use tempfile::tempdir;

    /// JSON files round-trip through save and the auto-detecting loader.
    #[cfg(feature = "checkpoint")]
    #[test]
    fn test_save_load_json() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.json");

        let population: Vec<Individual<RealVector>> = vec![
            Individual::new(RealVector::new(vec![1.0, 2.0])),
            Individual::new(RealVector::new(vec![3.0, 4.0])),
        ];
        let checkpoint = Checkpoint::new(10, population);

        save_checkpoint(&checkpoint, &path, CheckpointFormat::Json).unwrap();
        let loaded: Checkpoint<RealVector> = load_checkpoint(&path).unwrap();

        assert_eq!(loaded.generation, 10);
        assert_eq!(loaded.population.len(), 2);
    }

    /// Binary files keep generation, evaluation count and metadata.
    #[cfg(feature = "checkpoint")]
    #[test]
    fn test_save_load_binary() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.ckpt");

        let population: Vec<Individual<RealVector>> =
            vec![Individual::new(RealVector::new(vec![1.0, 2.0, 3.0]))];
        let checkpoint = Checkpoint::new(5, population)
            .with_evaluations(500)
            .with_metadata("test", "value");

        save_checkpoint(&checkpoint, &path, CheckpointFormat::Binary).unwrap();
        let loaded: Checkpoint<RealVector> = load_checkpoint(&path).unwrap();

        assert_eq!(loaded.generation, 5);
        assert_eq!(loaded.evaluations, 500);
        assert_eq!(loaded.metadata.get("test"), Some(&"value".to_string()));
    }

    /// Compressed binary files round-trip for a larger population.
    #[cfg(feature = "checkpoint")]
    #[test]
    fn test_save_load_compressed() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_compressed.ckpt");

        let population: Vec<Individual<RealVector>> = (0..100)
            .map(|i| Individual::new(RealVector::new(vec![i as f64; 10])))
            .collect();
        let checkpoint = Checkpoint::new(100, population);

        save_checkpoint(&checkpoint, &path, CheckpointFormat::CompressedBinary).unwrap();
        let loaded: Checkpoint<RealVector> = load_checkpoint(&path).unwrap();

        assert_eq!(loaded.generation, 100);
        assert_eq!(loaded.population.len(), 100);
    }

    /// The run-length codec is lossless on mixed runs and literals.
    #[cfg(feature = "checkpoint")]
    #[test]
    fn test_compression_decompression() {
        let original = vec![0u8, 0, 0, 0, 0, 1, 2, 3, 3, 3, 3, 3, 3, 4, 5];
        let compressed = compress_data(&original);
        let decompressed = decompress_data(&compressed).unwrap();
        assert_eq!(original, decompressed);
    }

    /// Bytes equal to the escape marker survive a round trip, alone or in runs.
    #[cfg(feature = "checkpoint")]
    #[test]
    fn test_compression_escape_byte_roundtrip() {
        let original = vec![1u8, 0xFF, 2, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 3];
        let compressed = compress_data(&original);
        let decompressed = decompress_data(&compressed).unwrap();
        assert_eq!(original, decompressed);
    }

    /// A marker without its count and value bytes is rejected.
    #[cfg(feature = "checkpoint")]
    #[test]
    fn test_decompression_rejects_truncated_sequence() {
        assert!(decompress_data(&[0xFF]).is_err());
        assert!(decompress_data(&[1, 2, 0xFF, 4]).is_err());
    }

    /// The manager prunes old files and reloads the most recent checkpoint.
    #[cfg(feature = "checkpoint")]
    #[test]
    fn test_checkpoint_manager() {
        let dir = tempdir().unwrap();
        let mut manager = CheckpointManager::new(dir.path(), "evolution")
            .with_format(CheckpointFormat::Binary)
            .keep(2)
            .every(10);

        // `gen` is a reserved word in the 2024 edition, hence the longer name.
        for generation in [10, 20, 30, 40] {
            let population: Vec<Individual<RealVector>> =
                vec![Individual::new(RealVector::new(vec![generation as f64]))];
            let checkpoint = Checkpoint::new(generation, population);
            manager.save(&checkpoint).unwrap();
        }

        let loaded: Option<Checkpoint<RealVector>> = manager.load_latest().unwrap();
        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().generation, 40);
    }

    /// A freshly built checkpoint reports itself as compatible.
    #[test]
    fn test_version_check() {
        let population: Vec<Individual<RealVector>> = vec![];
        let checkpoint = Checkpoint::new(0, population);
        assert!(checkpoint.is_compatible());
    }
}