Skip to main content

fugue_evo/interactive/
session.rs

1//! Session state management for interactive evolution
2//!
3//! This module provides serializable session state that allows pausing
4//! and resuming interactive evolution sessions.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9#[cfg(feature = "checkpoint")]
10use std::fs::File;
11#[cfg(feature = "checkpoint")]
12use std::io::{BufReader, BufWriter};
13#[cfg(feature = "checkpoint")]
14use std::path::Path;
15
16use super::aggregation::FitnessAggregator;
17use super::evaluator::{Candidate, CandidateId, EvaluationRequest};
18use super::uncertainty::FitnessEstimate;
19use crate::error::CheckpointError;
20use crate::genome::traits::EvolutionaryGenome;
21
/// Version number of the serialized session format written by this build.
///
/// Stored in every session's `version` field; sessions whose recorded
/// version is greater than this value are rejected when loaded.
pub const SESSION_VERSION: u32 = 1;
24
25/// Statistics about evaluation coverage in a session
26#[derive(Clone, Debug, Default, Serialize, Deserialize)]
27pub struct CoverageStats {
28    /// Fraction of population with at least one evaluation (0.0 to 1.0)
29    pub coverage: f64,
30    /// Average evaluations per candidate
31    pub avg_evaluations: f64,
32    /// Minimum evaluations for any candidate
33    pub min_evaluations: usize,
34    /// Maximum evaluations for any candidate
35    pub max_evaluations: usize,
36    /// Number of candidates with zero evaluations
37    pub unevaluated_count: usize,
38    /// Total population size
39    pub population_size: usize,
40}
41
42impl CoverageStats {
43    /// Check if coverage meets minimum threshold
44    pub fn meets_threshold(&self, min_coverage: f64) -> bool {
45        self.coverage >= min_coverage
46    }
47}
48
49/// Complete state of an interactive evolution session
50///
51/// This struct captures all state needed to pause and resume an
52/// interactive evolution session, including population, fitness
53/// aggregator state, and session metadata.
54#[derive(Clone, Debug, Serialize, Deserialize)]
55#[serde(bound = "G: Serialize + for<'a> Deserialize<'a>")]
56pub struct InteractiveSession<G>
57where
58    G: EvolutionaryGenome,
59{
60    /// Schema version for forward compatibility
61    pub version: u32,
62    /// Current population with fitness estimates
63    pub population: Vec<Candidate<G>>,
64    /// Current generation number
65    pub generation: usize,
66    /// Total evaluation requests made
67    pub evaluations_requested: usize,
68    /// Total responses received (excluding skips)
69    pub responses_received: usize,
70    /// Number of skipped evaluations
71    pub skipped: usize,
72    /// Fitness aggregator state
73    pub aggregator: FitnessAggregator,
74    /// History of evaluation requests (limited to recent history)
75    pub request_history: Vec<SerializedRequest>,
76    /// Custom session metadata
77    pub metadata: HashMap<String, String>,
78    /// Next candidate ID to assign
79    pub next_candidate_id: usize,
80}
81
82/// Serialized form of an evaluation request (without genome data)
83#[derive(Clone, Debug, Serialize, Deserialize)]
84pub struct SerializedRequest {
85    /// Type of request
86    pub request_type: String,
87    /// Candidate IDs involved
88    pub candidate_ids: Vec<CandidateId>,
89    /// Generation when request was made
90    pub generation: usize,
91    /// Whether this request was skipped
92    pub was_skipped: bool,
93}
94
95impl<G> InteractiveSession<G>
96where
97    G: EvolutionaryGenome,
98{
99    /// Create a new empty session
100    pub fn new(aggregator: FitnessAggregator) -> Self {
101        Self {
102            version: SESSION_VERSION,
103            population: Vec::new(),
104            generation: 0,
105            evaluations_requested: 0,
106            responses_received: 0,
107            skipped: 0,
108            aggregator,
109            request_history: Vec::new(),
110            metadata: HashMap::new(),
111            next_candidate_id: 0,
112        }
113    }
114
115    /// Create a new session with initial population
116    pub fn with_population(population: Vec<Candidate<G>>, aggregator: FitnessAggregator) -> Self {
117        let next_id = population.iter().map(|c| c.id.0).max().unwrap_or(0) + 1;
118        Self {
119            version: SESSION_VERSION,
120            population,
121            generation: 0,
122            evaluations_requested: 0,
123            responses_received: 0,
124            skipped: 0,
125            aggregator,
126            request_history: Vec::new(),
127            metadata: HashMap::new(),
128            next_candidate_id: next_id,
129        }
130    }
131
132    /// Get the next candidate ID and increment counter
133    pub fn next_id(&mut self) -> CandidateId {
134        let id = CandidateId(self.next_candidate_id);
135        self.next_candidate_id += 1;
136        id
137    }
138
139    /// Add a candidate to the population
140    pub fn add_candidate(&mut self, genome: G) -> CandidateId {
141        let id = self.next_id();
142        let candidate = Candidate::with_generation(id, genome, self.generation);
143        self.population.push(candidate);
144        id
145    }
146
147    /// Get a candidate by ID
148    pub fn get_candidate(&self, id: CandidateId) -> Option<&Candidate<G>> {
149        self.population.iter().find(|c| c.id == id)
150    }
151
152    /// Get a mutable reference to a candidate by ID
153    pub fn get_candidate_mut(&mut self, id: CandidateId) -> Option<&mut Candidate<G>> {
154        self.population.iter_mut().find(|c| c.id == id)
155    }
156
157    /// Get all candidates that haven't been evaluated
158    pub fn unevaluated_candidates(&self) -> Vec<&Candidate<G>> {
159        self.population
160            .iter()
161            .filter(|c| !c.is_evaluated())
162            .collect()
163    }
164
165    /// Get candidates sorted by fitness (best first)
166    pub fn ranked_candidates(&self) -> Vec<&Candidate<G>> {
167        let mut candidates: Vec<_> = self.population.iter().collect();
168        candidates.sort_by(|a, b| {
169            let fa = a.fitness_estimate.unwrap_or(f64::NEG_INFINITY);
170            let fb = b.fitness_estimate.unwrap_or(f64::NEG_INFINITY);
171            fb.partial_cmp(&fa).unwrap_or(std::cmp::Ordering::Equal)
172        });
173        candidates
174    }
175
176    /// Get the best candidate
177    pub fn best_candidate(&self) -> Option<&Candidate<G>> {
178        self.population
179            .iter()
180            .filter(|c| c.fitness_estimate.is_some())
181            .max_by(|a, b| {
182                let fa = a.fitness_estimate.unwrap();
183                let fb = b.fitness_estimate.unwrap();
184                fa.partial_cmp(&fb).unwrap_or(std::cmp::Ordering::Equal)
185            })
186    }
187
188    /// Calculate coverage statistics
189    pub fn coverage_stats(&self) -> CoverageStats {
190        if self.population.is_empty() {
191            return CoverageStats::default();
192        }
193
194        let eval_counts: Vec<usize> = self.population.iter().map(|c| c.evaluation_count).collect();
195
196        let evaluated = eval_counts.iter().filter(|&&c| c > 0).count();
197        let total_evals: usize = eval_counts.iter().sum();
198
199        CoverageStats {
200            coverage: evaluated as f64 / self.population.len() as f64,
201            avg_evaluations: total_evals as f64 / self.population.len() as f64,
202            min_evaluations: eval_counts.iter().copied().min().unwrap_or(0),
203            max_evaluations: eval_counts.iter().copied().max().unwrap_or(0),
204            unevaluated_count: self.population.len() - evaluated,
205            population_size: self.population.len(),
206        }
207    }
208
209    /// Record that an evaluation request was made
210    pub fn record_request<GG: EvolutionaryGenome>(&mut self, request: &EvaluationRequest<GG>) {
211        self.evaluations_requested += 1;
212
213        let serialized = SerializedRequest {
214            request_type: match request {
215                EvaluationRequest::RateCandidates { .. } => "rating".to_string(),
216                EvaluationRequest::PairwiseComparison { .. } => "pairwise".to_string(),
217                EvaluationRequest::BatchSelection { .. } => "batch".to_string(),
218            },
219            candidate_ids: request.candidate_ids(),
220            generation: self.generation,
221            was_skipped: false,
222        };
223
224        // Keep limited history
225        const MAX_HISTORY: usize = 1000;
226        if self.request_history.len() >= MAX_HISTORY {
227            self.request_history.remove(0);
228        }
229        self.request_history.push(serialized);
230    }
231
232    /// Record that a response was received
233    pub fn record_response(&mut self, was_skipped: bool) {
234        if was_skipped {
235            self.skipped += 1;
236            if let Some(last) = self.request_history.last_mut() {
237                last.was_skipped = true;
238            }
239        } else {
240            self.responses_received += 1;
241        }
242    }
243
244    /// Advance to the next generation
245    pub fn advance_generation(&mut self) {
246        self.generation += 1;
247        self.aggregator.set_generation(self.generation);
248    }
249
250    /// Update fitness estimate for a candidate
251    pub fn update_fitness(&mut self, id: CandidateId, fitness: f64) {
252        if let Some(candidate) = self.get_candidate_mut(id) {
253            candidate.set_fitness(fitness);
254            candidate.record_evaluation();
255        }
256    }
257
258    /// Update fitness with full uncertainty information
259    pub fn update_fitness_with_uncertainty(&mut self, id: CandidateId, estimate: FitnessEstimate) {
260        if let Some(candidate) = self.get_candidate_mut(id) {
261            candidate.set_fitness_with_uncertainty(estimate);
262            candidate.record_evaluation();
263        }
264    }
265
266    /// Sync candidate fitness estimates from the aggregator
267    ///
268    /// Updates all candidates with their current fitness estimates including uncertainty.
269    /// Call this after processing responses to ensure candidates have up-to-date estimates.
270    pub fn sync_fitness_estimates(&mut self) {
271        for candidate in &mut self.population {
272            if let Some(estimate) = self.aggregator.get_fitness_estimate(&candidate.id) {
273                candidate.fitness_estimate = Some(estimate.mean);
274                candidate.fitness_with_uncertainty = Some(estimate);
275            }
276        }
277    }
278
279    /// Get fitness estimates with uncertainty for all candidates
280    ///
281    /// Returns a vector of (CandidateId, FitnessEstimate) pairs.
282    pub fn all_fitness_estimates(&self) -> Vec<(CandidateId, FitnessEstimate)> {
283        self.population
284            .iter()
285            .filter_map(|c| {
286                self.aggregator
287                    .get_fitness_estimate(&c.id)
288                    .map(|e| (c.id, e))
289            })
290            .collect()
291    }
292
293    /// Get candidates sorted by uncertainty (most uncertain first)
294    ///
295    /// Useful for identifying which candidates need more evaluation.
296    pub fn candidates_by_uncertainty(&self) -> Vec<&Candidate<G>> {
297        let mut candidates: Vec<_> = self.population.iter().collect();
298        candidates.sort_by(|a, b| {
299            let var_a = self
300                .aggregator
301                .get_fitness_estimate(&a.id)
302                .map(|e| e.variance)
303                .unwrap_or(f64::INFINITY);
304            let var_b = self
305                .aggregator
306                .get_fitness_estimate(&b.id)
307                .map(|e| e.variance)
308                .unwrap_or(f64::INFINITY);
309            // Sort descending - most uncertain first
310            var_b
311                .partial_cmp(&var_a)
312                .unwrap_or(std::cmp::Ordering::Equal)
313        });
314        candidates
315    }
316
317    /// Get the average uncertainty across all candidates
318    pub fn average_uncertainty(&self) -> f64 {
319        let estimates: Vec<_> = self
320            .population
321            .iter()
322            .filter_map(|c| self.aggregator.get_fitness_estimate(&c.id))
323            .collect();
324
325        if estimates.is_empty() {
326            return f64::INFINITY;
327        }
328
329        let total_variance: f64 = estimates
330            .iter()
331            .map(|e| {
332                if e.variance.is_finite() {
333                    e.variance
334                } else {
335                    1e6 // Large but finite for averaging
336                }
337            })
338            .sum();
339
340        total_variance / estimates.len() as f64
341    }
342
343    /// Replace the population with new candidates
344    pub fn replace_population(&mut self, new_population: Vec<Candidate<G>>) {
345        let max_id = new_population.iter().map(|c| c.id.0).max().unwrap_or(0);
346        self.next_candidate_id = max_id + 1;
347        self.population = new_population;
348    }
349
350    /// Add metadata to the session
351    pub fn set_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
352        self.metadata.insert(key.into(), value.into());
353    }
354
355    /// Get metadata value
356    pub fn get_metadata(&self, key: &str) -> Option<&String> {
357        self.metadata.get(key)
358    }
359
360    /// Get response rate (responses / requests)
361    pub fn response_rate(&self) -> f64 {
362        if self.evaluations_requested > 0 {
363            self.responses_received as f64 / self.evaluations_requested as f64
364        } else {
365            0.0
366        }
367    }
368
369    /// Get skip rate (skips / requests)
370    pub fn skip_rate(&self) -> f64 {
371        if self.evaluations_requested > 0 {
372            self.skipped as f64 / self.evaluations_requested as f64
373        } else {
374            0.0
375        }
376    }
377}
378
/// File-based session persistence (requires `checkpoint` feature)
#[cfg(feature = "checkpoint")]
impl<G> InteractiveSession<G>
where
    G: EvolutionaryGenome + Serialize + for<'de> Deserialize<'de>,
{
    /// Save the session to `path` as pretty-printed JSON.
    ///
    /// The file is created, or truncated if it already exists. If an error
    /// occurs partway through, the file may be left partially written.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be created, if serialization
    /// fails ([`CheckpointError::Serialization`]), or if flushing the
    /// buffered output fails.
    pub fn save(&self, path: &Path) -> Result<(), CheckpointError> {
        let file = File::create(path)?;
        let mut writer = BufWriter::new(file);
        serde_json::to_writer_pretty(&mut writer, self).map_err(|e| {
            CheckpointError::Serialization(format!("Failed to serialize session: {}", e))
        })?;
        // `BufWriter` flushes on drop but discards any error there; flush
        // explicitly so a failed final write (e.g. disk full) is reported.
        std::io::Write::flush(&mut writer)?;
        Ok(())
    }

    /// Load a session previously written by [`Self::save`].
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened, if its contents do not
    /// deserialize ([`CheckpointError::Deserialization`]), or if the stored
    /// version is newer than [`SESSION_VERSION`]
    /// ([`CheckpointError::VersionTooNew`]).
    pub fn load(path: &Path) -> Result<Self, CheckpointError> {
        let file = File::open(path)?;
        let reader = BufReader::new(file);
        let session: Self = serde_json::from_reader(reader).map_err(|e| {
            CheckpointError::Deserialization(format!("Failed to deserialize session: {}", e))
        })?;

        // Older versions are accepted; only newer formats are rejected.
        if session.version > SESSION_VERSION {
            return Err(CheckpointError::VersionTooNew(session.version));
        }

        Ok(session)
    }
}
411
412impl<G> InteractiveSession<G>
413where
414    G: EvolutionaryGenome + Serialize + for<'de> Deserialize<'de>,
415{
416    /// Serialize session to JSON string (WASM-compatible)
417    pub fn to_json(&self) -> Result<String, CheckpointError> {
418        serde_json::to_string_pretty(self).map_err(|e| {
419            CheckpointError::Serialization(format!("Failed to serialize session: {}", e))
420        })
421    }
422
423    /// Deserialize session from JSON string (WASM-compatible)
424    pub fn from_json(json: &str) -> Result<Self, CheckpointError> {
425        let session: Self = serde_json::from_str(json).map_err(|e| {
426            CheckpointError::Deserialization(format!("Failed to deserialize session: {}", e))
427        })?;
428
429        // Check version compatibility
430        if session.version > SESSION_VERSION {
431            return Err(CheckpointError::VersionTooNew(session.version));
432        }
433
434        Ok(session)
435    }
436}
437
#[cfg(test)]
mod tests {
    use super::*;
    use crate::genome::real_vector::RealVector;
    use crate::interactive::aggregation::AggregationModel;

    /// Fresh, empty `RealVector` session using the default aggregation model.
    fn empty_session() -> InteractiveSession<RealVector> {
        InteractiveSession::new(FitnessAggregator::new(AggregationModel::default()))
    }

    #[test]
    fn test_session_creation() {
        let session = empty_session();

        assert_eq!(session.generation, 0);
        assert!(session.population.is_empty());
        assert_eq!(session.evaluations_requested, 0);
    }

    #[test]
    fn test_add_candidate() {
        let mut session = empty_session();

        let id = session.add_candidate(RealVector::new(vec![1.0, 2.0, 3.0]));

        assert_eq!(id, CandidateId(0));
        assert_eq!(session.population.len(), 1);
        let added = session.get_candidate(id).expect("candidate was just added");
        assert_eq!(added.birth_generation, 0);
    }

    #[test]
    fn test_coverage_stats() {
        let mut session = empty_session();
        (0..4).for_each(|i| {
            session.add_candidate(RealVector::new(vec![i as f64]));
        });

        // Candidate 0 evaluated once, candidate 1 twice, the rest never.
        session.population[0].record_evaluation();
        for _ in 0..2 {
            session.population[1].record_evaluation();
        }

        let stats = session.coverage_stats();

        assert_eq!(stats.population_size, 4);
        assert_eq!(stats.coverage, 0.5);
        assert_eq!(stats.unevaluated_count, 2);
        assert_eq!(stats.min_evaluations, 0);
        assert_eq!(stats.max_evaluations, 2);
    }

    #[test]
    fn test_ranked_candidates() {
        let mut session = empty_session();
        for i in 0..3 {
            let id = session.add_candidate(RealVector::new(vec![i as f64]));
            session.update_fitness(id, 10.0 * i as f64);
        }

        let ranked = session.ranked_candidates();
        let best = ranked.first().expect("three candidates were added");
        let worst = ranked.last().expect("three candidates were added");

        assert_eq!(best.fitness_estimate, Some(20.0));
        assert_eq!(worst.fitness_estimate, Some(0.0));
    }

    #[test]
    fn test_advance_generation() {
        let mut session = empty_session();

        session.advance_generation();
        assert_eq!(session.generation, 1);

        let id = session.add_candidate(RealVector::new(vec![1.0]));
        let added = session.get_candidate(id).expect("candidate was just added");
        assert_eq!(added.birth_generation, 1);
    }

    #[test]
    fn test_response_tracking() {
        let mut session = empty_session();
        let probe: Candidate<RealVector> =
            Candidate::new(CandidateId(0), RealVector::new(vec![1.0]));
        let request = EvaluationRequest::rate(vec![probe]);

        // One answered request, then one skipped request.
        for skipped in [false, true] {
            session.record_request(&request);
            session.record_response(skipped);
        }

        assert_eq!(session.evaluations_requested, 2);
        assert_eq!(session.responses_received, 1);
        assert_eq!(session.skipped, 1);
        assert_eq!(session.response_rate(), 0.5);
        assert_eq!(session.skip_rate(), 0.5);
    }

    #[test]
    fn test_metadata() {
        let mut session = empty_session();
        for (key, value) in [("experiment", "test_run"), ("user", "alice")] {
            session.set_metadata(key, value);
        }

        assert_eq!(
            session.get_metadata("experiment").map(String::as_str),
            Some("test_run")
        );
        assert_eq!(session.get_metadata("user").map(String::as_str), Some("alice"));
        assert!(session.get_metadata("missing").is_none());
    }

    #[test]
    fn test_session_serialization() {
        let aggregator = FitnessAggregator::new(AggregationModel::DirectRating {
            default_rating: 5.0,
        });
        let mut original: InteractiveSession<RealVector> = InteractiveSession::new(aggregator);
        original.add_candidate(RealVector::new(vec![1.0, 2.0]));
        original.add_candidate(RealVector::new(vec![3.0, 4.0]));
        original.set_metadata("test", "value");

        // Round-trip through JSON.
        let json = serde_json::to_string(&original).expect("Failed to serialize");
        let restored: InteractiveSession<RealVector> =
            serde_json::from_str(&json).expect("Failed to deserialize");

        assert_eq!(restored.population.len(), 2);
        assert_eq!(restored.get_metadata("test").map(String::as_str), Some("value"));
    }
}