Skip to main content

fugue_evo/checkpoint/
recovery.rs

1//! Checkpoint recovery and persistence
2//!
3//! Provides serialization to/from files with compression and versioning.
4
5use serde::{de::DeserializeOwned, Serialize};
6
7#[cfg(feature = "checkpoint")]
8use std::fs::File;
9#[cfg(feature = "checkpoint")]
10use std::io::{BufReader, BufWriter, Read, Write};
11#[cfg(feature = "checkpoint")]
12use std::path::Path;
13
14use super::state::{Checkpoint, CHECKPOINT_VERSION};
15use crate::error::CheckpointError;
16
/// Format for checkpoint serialization
///
/// The format only needs to be chosen when saving; `load_checkpoint` detects
/// it from the file contents. The default is [`CheckpointFormat::Binary`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum CheckpointFormat {
    /// JSON format (human-readable, larger)
    ///
    /// Written as pretty-printed JSON with no version header or magic bytes.
    Json,
    /// Binary format (compact, fast)
    ///
    /// Layout: little-endian checkpoint version, the magic bytes `FEVO`,
    /// then the bincode-encoded checkpoint.
    #[default]
    Binary,
    /// Compressed binary (smallest, slower)
    ///
    /// Layout: little-endian checkpoint version, the magic bytes `FEVC`,
    /// the compressed length as a little-endian `u64`, then the
    /// run-length-encoded bincode payload.
    CompressedBinary,
}
33
/// Write `checkpoint` to the file at `path` using the requested `format`.
///
/// The file is created, or truncated if it already exists. The two binary
/// formats start with the little-endian checkpoint version followed by four
/// magic bytes (`FEVO` for plain, `FEVC` for compressed); JSON output has no
/// header.
///
/// # Errors
///
/// I/O failures are propagated through `CheckpointError`'s conversion from
/// `std::io::Error`; encoder failures are reported as
/// `CheckpointError::Serialization`.
#[cfg(feature = "checkpoint")]
pub fn save_checkpoint<G>(
    checkpoint: &Checkpoint<G>,
    path: impl AsRef<Path>,
    format: CheckpointFormat,
) -> Result<(), CheckpointError>
where
    G: Clone + Serialize + crate::genome::traits::EvolutionaryGenome,
{
    // Turns any encoder error into the crate's serialization error.
    fn ser_err<E: std::fmt::Display>(err: E) -> CheckpointError {
        CheckpointError::Serialization(err.to_string())
    }

    // Emits the common binary preamble: version first, then the magic bytes.
    fn write_header<W: Write>(out: &mut W, magic: &[u8; 4]) -> std::io::Result<()> {
        out.write_all(&CHECKPOINT_VERSION.to_le_bytes())?;
        out.write_all(magic)
    }

    let mut out = BufWriter::new(File::create(path.as_ref())?);

    match format {
        CheckpointFormat::Json => {
            serde_json::to_writer_pretty(&mut out, checkpoint).map_err(ser_err)?;
        }
        CheckpointFormat::Binary => {
            write_header(&mut out, b"FEVO")?;
            // Stream the bincode encoding straight into the file.
            bincode::serialize_into(&mut out, checkpoint).map_err(ser_err)?;
        }
        CheckpointFormat::CompressedBinary => {
            write_header(&mut out, b"FEVC")?; // C for compressed
            // The compressor works on a complete buffer, so encode in memory first.
            let raw = bincode::serialize(checkpoint).map_err(ser_err)?;
            let packed = compress_data(&raw);
            // Length prefix lets the loader know how many payload bytes follow.
            let packed_len = packed.len() as u64;
            out.write_all(&packed_len.to_le_bytes())?;
            out.write_all(&packed)?;
        }
    }

    // Flush explicitly so buffered write errors are reported, not lost on drop.
    out.flush()?;
    Ok(())
}
80
/// Load a checkpoint from a file
///
/// The on-disk format is detected from the first eight bytes of the file:
/// a little-endian `u32` version followed by the magic bytes `FEVO` (plain
/// binary) or `FEVC` (compressed binary). Anything else is parsed as JSON.
///
/// # Errors
///
/// * `CheckpointError::NotFound` if `path` does not exist.
/// * `CheckpointError::VersionMismatch` if a binary file declares a version
///   newer than `CHECKPOINT_VERSION`.
/// * `CheckpointError::Corrupted` if the compressed payload is not valid
///   run-length-encoded data.
/// * `CheckpointError::Deserialization` if the decoded bytes are not a valid
///   checkpoint.
/// * I/O failures (including a file shorter than its header or shorter than
///   its declared compressed length) are propagated through
///   `CheckpointError`'s conversion from `std::io::Error`.
#[cfg(feature = "checkpoint")]
pub fn load_checkpoint<G>(path: impl AsRef<Path>) -> Result<Checkpoint<G>, CheckpointError>
where
    G: Clone + Serialize + DeserializeOwned + crate::genome::traits::EvolutionaryGenome,
{
    // Rejects files written by a newer checkpoint version than this build knows.
    fn check_version(header: &[u8; 8]) -> Result<(), CheckpointError> {
        let version = u32::from_le_bytes([header[0], header[1], header[2], header[3]]);
        if version > CHECKPOINT_VERSION {
            return Err(CheckpointError::VersionMismatch {
                expected: CHECKPOINT_VERSION,
                found: version,
            });
        }
        Ok(())
    }

    let path = path.as_ref();
    if !path.exists() {
        return Err(CheckpointError::NotFound(path.display().to_string()));
    }

    let mut reader = BufReader::new(File::open(path)?);

    // Version (4 bytes) + magic (4 bytes) for the binary formats.
    let mut header = [0u8; 8];
    reader.read_exact(&mut header)?;

    match &header[4..8] {
        b"FEVO" => {
            check_version(&header)?;
            bincode::deserialize_from(&mut reader)
                .map_err(|e| CheckpointError::Deserialization(e.to_string()))
        }
        b"FEVC" => {
            check_version(&header)?;

            let mut len_bytes = [0u8; 8];
            reader.read_exact(&mut len_bytes)?;
            let declared_len = u64::from_le_bytes(len_bytes);

            // The length comes from the file and may be garbage. Reading
            // through `take` grows the buffer only as far as real data exists,
            // instead of pre-allocating whatever size a corrupt header claims.
            let mut compressed = Vec::new();
            reader
                .by_ref()
                .take(declared_len)
                .read_to_end(&mut compressed)?;
            if compressed.len() as u64 != declared_len {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "compressed checkpoint payload is shorter than its declared length",
                )
                .into());
            }

            let decompressed = decompress_data(&compressed).map_err(CheckpointError::Corrupted)?;

            bincode::deserialize(&decompressed)
                .map_err(|e| CheckpointError::Deserialization(e.to_string()))
        }
        _ => {
            // No magic bytes: treat the file as JSON. The eight bytes already
            // consumed are part of the document, so replay them ahead of the
            // rest of the stream rather than opening the file a second time.
            serde_json::from_reader(Read::chain(&header[..], reader))
                .map_err(|e| CheckpointError::Deserialization(e.to_string()))
        }
    }
}
146
147/// Simple compression using run-length encoding for repeated bytes
148#[cfg(feature = "checkpoint")]
fn compress_data(data: &[u8]) -> Vec<u8> {
    // Output format:
    //   * a run is written as the triple `ESCAPE, run_length, byte`;
    //   * any other byte is written literally.
    // Because `ESCAPE` introduces a triple, a literal `ESCAPE` byte must always
    // be written as a run, even when it occurs only once.
    const ESCAPE: u8 = 0xFF;
    // Shorter runs are cheaper to store literally than as a 3-byte triple.
    const MIN_RUN: usize = 4;
    // The run length is stored in a single byte.
    const MAX_RUN: usize = u8::MAX as usize;

    let mut compressed = Vec::with_capacity(data.len());
    let mut rest = data;

    while let Some(&byte) = rest.first() {
        // Length of the run of `byte` at the front of `rest`, capped at MAX_RUN.
        // Always at least 1 because the first element matches itself.
        let run = rest
            .iter()
            .take(MAX_RUN)
            .take_while(|&&candidate| candidate == byte)
            .count();

        if run >= MIN_RUN || byte == ESCAPE {
            // `run <= MAX_RUN`, so the cast cannot truncate.
            compressed.extend_from_slice(&[ESCAPE, run as u8, byte]);
        } else {
            // Short run of a non-escape byte: copy it through unchanged.
            compressed.extend_from_slice(&rest[..run]);
        }

        rest = &rest[run..];
    }

    compressed
}
190
191/// Decompress RLE-encoded data
192#[cfg(feature = "checkpoint")]
fn decompress_data(data: &[u8]) -> Result<Vec<u8>, String> {
    // Input is a mix of literal bytes and `0xFF, run_length, byte` triples.
    // Walk it front to back, peeling one token off `remaining` per iteration.
    let mut output = Vec::new();
    let mut remaining = data;

    loop {
        match remaining {
            // Everything consumed: decoding succeeded.
            [] => return Ok(output),
            // A complete run triple: expand it.
            [0xFF, run_len, value, tail @ ..] => {
                output.extend(std::iter::repeat(*value).take(usize::from(*run_len)));
                remaining = tail;
            }
            // An escape marker without both of its operands.
            [0xFF, ..] => return Err("Truncated RLE sequence".to_string()),
            // Any other byte stands for itself.
            [literal, tail @ ..] => {
                output.push(*literal);
                remaining = tail;
            }
        }
    }
}
216
/// Checkpoint manager for automatic saving
///
/// Writes numbered checkpoint files named `{base_name}_{index:04}.{ext}` into
/// `directory` and removes the oldest one once more than `keep_n` files have
/// been written by this manager.
#[cfg(feature = "checkpoint")]
#[derive(Debug)]
pub struct CheckpointManager {
    /// Directory for checkpoint files (created on first save if missing)
    pub directory: std::path::PathBuf,
    /// Base filename for checkpoints; the index and extension are appended
    pub base_name: String,
    /// Serialization format (`new` sets this to `CheckpointFormat::Binary`)
    pub format: CheckpointFormat,
    /// How many checkpoints to keep (`new` sets this to 3)
    pub keep_n: usize,
    /// Save interval in generations (`new` sets this to 100)
    pub interval: usize,
    /// Index used for the next checkpoint file; starts at 0 and is
    /// incremented after every successful save
    current_index: usize,
}
233
#[cfg(feature = "checkpoint")]
impl CheckpointManager {
    /// Create a new checkpoint manager
    ///
    /// Starts with the binary format, keeps the 3 most recent checkpoints,
    /// saves every 100 generations and numbers files from index 0.
    pub fn new(directory: impl Into<std::path::PathBuf>, base_name: impl Into<String>) -> Self {
        Self {
            directory: directory.into(),
            base_name: base_name.into(),
            format: CheckpointFormat::Binary,
            keep_n: 3,
            interval: 100,
            current_index: 0,
        }
    }

    /// Set the serialization format
    pub fn with_format(mut self, format: CheckpointFormat) -> Self {
        self.format = format;
        self
    }

    /// Set how many checkpoints to keep
    pub fn keep(mut self, n: usize) -> Self {
        self.keep_n = n;
        self
    }

    /// Set the save interval
    pub fn every(mut self, generations: usize) -> Self {
        self.interval = generations;
        self
    }

    /// Check if a checkpoint should be saved at this generation
    ///
    /// Returns `true` for every positive generation that is a multiple of the
    /// configured interval. Generation 0 never triggers a save, and neither
    /// does any generation when the interval is 0.
    pub fn should_save(&self, generation: usize) -> bool {
        generation > 0 && generation.is_multiple_of(self.interval)
    }

    /// Get the path for the current checkpoint
    ///
    /// This is the file the next call to [`CheckpointManager::save`] writes.
    pub fn current_path(&self) -> std::path::PathBuf {
        self.path_for_index(self.current_index)
    }

    /// File extension used for files written in the configured format.
    fn file_extension(&self) -> &'static str {
        match self.format {
            CheckpointFormat::Json => "json",
            CheckpointFormat::Binary | CheckpointFormat::CompressedBinary => "ckpt",
        }
    }

    /// Path of the checkpoint file with the given index in the configured format.
    fn path_for_index(&self, index: usize) -> std::path::PathBuf {
        self.directory.join(format!(
            "{}_{:04}.{}",
            self.base_name,
            index,
            self.file_extension()
        ))
    }

    /// Save a checkpoint and rotate old ones
    ///
    /// Creates the checkpoint directory if needed, writes the checkpoint to
    /// [`CheckpointManager::current_path`], advances the index, and then
    /// removes the file that has fallen out of the `keep_n` window.
    ///
    /// # Errors
    ///
    /// Fails if the directory cannot be created or the checkpoint cannot be
    /// written. Failure to delete the rotated-out file is ignored.
    pub fn save<G>(&mut self, checkpoint: &Checkpoint<G>) -> Result<(), CheckpointError>
    where
        G: Clone + Serialize + crate::genome::traits::EvolutionaryGenome,
    {
        // Ensure directory exists
        std::fs::create_dir_all(&self.directory)?;

        // Save current checkpoint
        save_checkpoint(checkpoint, self.current_path(), self.format)?;

        // Rotate: once more than `keep_n` files exist, drop the oldest one.
        // `current_index - keep_n - 1` is only defined when
        // `current_index > keep_n`; the checked subtraction encodes that guard.
        self.current_index += 1;
        let expired_index = self
            .current_index
            .checked_sub(self.keep_n)
            .and_then(|n| n.checked_sub(1));
        if let Some(expired_index) = expired_index {
            // Best effort: the file may already be gone.
            let _ = std::fs::remove_file(self.path_for_index(expired_index));
        }

        Ok(())
    }

    /// Find and load the latest checkpoint
    ///
    /// Only directory entries named `{base_name}_{digits}.json` or
    /// `{base_name}_{digits}.ckpt` are considered; both extensions are
    /// accepted because `load_checkpoint` detects the on-disk format itself.
    /// Candidates are tried from the highest index downwards (preferring the
    /// configured format's extension when two files share an index), and the
    /// first one that loads successfully is returned. Files that fail to load
    /// are skipped. Returns `Ok(None)` when no candidate can be loaded.
    ///
    /// NOTE(review): "latest" means highest file index. A freshly constructed
    /// manager numbers from 0 again, so files left behind by an earlier run
    /// with higher indices are still preferred here.
    ///
    /// # Errors
    ///
    /// Fails if the checkpoint directory cannot be read (for example because
    /// it does not exist).
    pub fn load_latest<G>(&self) -> Result<Option<Checkpoint<G>>, CheckpointError>
    where
        G: Clone + Serialize + DeserializeOwned + crate::genome::traits::EvolutionaryGenome,
    {
        let prefix = format!("{}_", self.base_name);
        let preferred_extension = self.file_extension();

        // Collect (index, uses preferred extension, path) for every file that
        // follows the naming scheme used by `save`.
        let mut candidates: Vec<(usize, bool, std::path::PathBuf)> =
            std::fs::read_dir(&self.directory)?
                .filter_map(|entry| entry.ok())
                .filter_map(|entry| {
                    let file_name = entry.file_name();
                    let (digits, extension) = file_name
                        .to_str()?
                        .strip_prefix(prefix.as_str())?
                        .split_once('.')?;
                    if !matches!(extension, "json" | "ckpt") {
                        return None;
                    }
                    if digits.is_empty() || !digits.bytes().all(|b| b.is_ascii_digit()) {
                        return None;
                    }
                    let index = digits.parse::<usize>().ok()?;
                    Some((index, extension == preferred_extension, entry.path()))
                })
                .collect();

        // Newest first. The index is compared numerically so ordering stays
        // correct once it no longer fits in the four zero-padded digits.
        candidates.sort_unstable_by(|a, b| (b.0, b.1).cmp(&(a.0, a.1)));

        // Fall back to older checkpoints when a newer one is unreadable.
        Ok(candidates
            .into_iter()
            .find_map(|(_, _, path)| load_checkpoint(path).ok()))
    }
}
353
// Unit tests for checkpoint persistence.
//
// Everything that touches the file helpers or `CheckpointManager` is gated on
// the `checkpoint` feature, because those items only exist when it is enabled;
// without the gate, `cargo test` fails to compile when the feature is off.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::genome::real_vector::RealVector;
    use crate::population::individual::Individual;
    #[cfg(feature = "checkpoint")]
    use tempfile::tempdir;

    /// JSON checkpoints round-trip through save and auto-detecting load.
    #[test]
    #[cfg(feature = "checkpoint")]
    fn test_save_load_json() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.json");

        let population: Vec<Individual<RealVector>> = vec![
            Individual::new(RealVector::new(vec![1.0, 2.0])),
            Individual::new(RealVector::new(vec![3.0, 4.0])),
        ];
        let checkpoint = Checkpoint::new(10, population);

        save_checkpoint(&checkpoint, &path, CheckpointFormat::Json).unwrap();
        let loaded: Checkpoint<RealVector> = load_checkpoint(&path).unwrap();

        assert_eq!(loaded.generation, 10);
        assert_eq!(loaded.population.len(), 2);
    }

    /// Binary checkpoints preserve generation, evaluation count and metadata.
    #[test]
    #[cfg(feature = "checkpoint")]
    fn test_save_load_binary() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.ckpt");

        let population: Vec<Individual<RealVector>> =
            vec![Individual::new(RealVector::new(vec![1.0, 2.0, 3.0]))];
        let checkpoint = Checkpoint::new(5, population)
            .with_evaluations(500)
            .with_metadata("test", "value");

        save_checkpoint(&checkpoint, &path, CheckpointFormat::Binary).unwrap();
        let loaded: Checkpoint<RealVector> = load_checkpoint(&path).unwrap();

        assert_eq!(loaded.generation, 5);
        assert_eq!(loaded.evaluations, 500);
        assert_eq!(loaded.metadata.get("test"), Some(&"value".to_string()));
    }

    /// Compressed checkpoints round-trip for a larger, repetitive population.
    #[test]
    #[cfg(feature = "checkpoint")]
    fn test_save_load_compressed() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_compressed.ckpt");

        // Create a larger checkpoint with repetitive data
        let population: Vec<Individual<RealVector>> = (0..100)
            .map(|i| Individual::new(RealVector::new(vec![i as f64; 10])))
            .collect();
        let checkpoint = Checkpoint::new(100, population);

        save_checkpoint(&checkpoint, &path, CheckpointFormat::CompressedBinary).unwrap();
        let loaded: Checkpoint<RealVector> = load_checkpoint(&path).unwrap();

        assert_eq!(loaded.generation, 100);
        assert_eq!(loaded.population.len(), 100);
    }

    /// Mixed runs and literals survive a compress/decompress round trip.
    #[test]
    #[cfg(feature = "checkpoint")]
    fn test_compression_decompression() {
        let original = vec![0u8, 0, 0, 0, 0, 1, 2, 3, 3, 3, 3, 3, 3, 4, 5];
        let compressed = compress_data(&original);
        let decompressed = decompress_data(&compressed).unwrap();
        assert_eq!(original, decompressed);
    }

    /// The escape byte (0xFF) and runs longer than 255 bytes round-trip too.
    #[test]
    #[cfg(feature = "checkpoint")]
    fn test_compression_escape_byte_and_long_runs() {
        let mut original = vec![0xFFu8, 1, 0xFF, 0xFF, 2];
        original.extend(std::iter::repeat(7u8).take(600));
        original.push(0xFF);

        let compressed = compress_data(&original);
        let decompressed = decompress_data(&compressed).unwrap();
        assert_eq!(original, decompressed);
    }

    /// An escape marker without its two operands is reported as an error.
    #[test]
    #[cfg(feature = "checkpoint")]
    fn test_decompression_rejects_truncated_sequence() {
        assert!(decompress_data(&[0xFF]).is_err());
        assert!(decompress_data(&[1, 2, 0xFF, 3]).is_err());
    }

    /// The manager rotates old files and `load_latest` returns the newest one.
    #[test]
    #[cfg(feature = "checkpoint")]
    fn test_checkpoint_manager() {
        let dir = tempdir().unwrap();
        let mut manager = CheckpointManager::new(dir.path(), "evolution")
            .with_format(CheckpointFormat::Binary)
            .keep(2)
            .every(10);

        // Save multiple checkpoints
        // (`generation` rather than `gen`, which is reserved in edition 2024)
        for generation in [10, 20, 30, 40] {
            let population: Vec<Individual<RealVector>> =
                vec![Individual::new(RealVector::new(vec![generation as f64]))];
            let checkpoint = Checkpoint::new(generation, population);
            manager.save(&checkpoint).unwrap();
        }

        // Load latest
        let loaded: Option<Checkpoint<RealVector>> = manager.load_latest().unwrap();
        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().generation, 40);
    }

    /// A freshly created checkpoint is compatible with the current version.
    #[test]
    fn test_version_check() {
        let population: Vec<Individual<RealVector>> = vec![];
        let checkpoint = Checkpoint::new(0, population);
        assert!(checkpoint.is_compatible());
    }
}