1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9#[cfg(feature = "checkpoint")]
10use std::fs::File;
11#[cfg(feature = "checkpoint")]
12use std::io::{BufReader, BufWriter};
13#[cfg(feature = "checkpoint")]
14use std::path::Path;
15
16use super::aggregation::FitnessAggregator;
17use super::evaluator::{Candidate, CandidateId, EvaluationRequest};
18use super::uncertainty::FitnessEstimate;
19use crate::error::CheckpointError;
20use crate::genome::traits::EvolutionaryGenome;
21
/// Format version stamped into every newly created session.
///
/// Session loaders (`load`, `from_json`) reject data whose `version` is greater than this
/// value with `CheckpointError::VersionTooNew`.
pub const SESSION_VERSION: u32 = 1;
24
/// Summary of how thoroughly the current population has been evaluated.
///
/// Produced by `InteractiveSession::coverage_stats`; an empty population yields
/// `CoverageStats::default()` (all zeros).
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct CoverageStats {
    /// Fraction (0.0..=1.0) of candidates with at least one recorded evaluation.
    pub coverage: f64,
    /// Mean number of recorded evaluations per candidate.
    pub avg_evaluations: f64,
    /// Smallest per-candidate evaluation count.
    pub min_evaluations: usize,
    /// Largest per-candidate evaluation count.
    pub max_evaluations: usize,
    /// Number of candidates with zero recorded evaluations.
    pub unevaluated_count: usize,
    /// Number of candidates in the population when the stats were taken.
    pub population_size: usize,
}
41
42impl CoverageStats {
43 pub fn meets_threshold(&self, min_coverage: f64) -> bool {
45 self.coverage >= min_coverage
46 }
47}
48
/// Complete, serializable state of an interactive evolution session.
///
/// Holds the population, request/response counters, the fitness aggregator, a bounded
/// history of evaluation requests and free-form metadata. It can be checkpointed with
/// `to_json`/`from_json` (and `save`/`load` when the `checkpoint` feature is enabled).
///
/// The explicit `serde(bound)` sets the where-clause of the generated serde impls to the
/// given bound on `G` instead of serde's inferred bounds.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound = "G: Serialize + for<'a> Deserialize<'a>")]
pub struct InteractiveSession<G>
where
    G: EvolutionaryGenome,
{
    /// Session format version; set to `SESSION_VERSION` on creation and checked on load.
    pub version: u32,
    /// Current candidates.
    pub population: Vec<Candidate<G>>,
    /// Current generation number; candidates added via `add_candidate` are stamped with it.
    pub generation: usize,
    /// Number of requests passed to `record_request`.
    pub evaluations_requested: usize,
    /// Number of non-skipped responses passed to `record_response`.
    pub responses_received: usize,
    /// Number of skipped responses passed to `record_response`.
    pub skipped: usize,
    /// Source of per-candidate fitness estimates, queried by candidate ID.
    pub aggregator: FitnessAggregator,
    /// Recent requests, oldest first; `record_request` caps it at 1000 entries.
    pub request_history: Vec<SerializedRequest>,
    /// Free-form string key/value annotations.
    pub metadata: HashMap<String, String>,
    /// ID that `next_id` will hand out next.
    pub next_candidate_id: usize,
}
81
/// Compact, serializable record of one evaluation request kept in the session history.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SerializedRequest {
    /// Request kind as written by `record_request`: `"rating"`, `"pairwise"` or `"batch"`.
    pub request_type: String,
    /// IDs of the candidates the request covered.
    pub candidate_ids: Vec<CandidateId>,
    /// Session generation at the time the request was recorded.
    pub generation: usize,
    /// Set to `true` when `record_response(true)` was called while this was the newest entry.
    pub was_skipped: bool,
}
94
95impl<G> InteractiveSession<G>
96where
97 G: EvolutionaryGenome,
98{
99 pub fn new(aggregator: FitnessAggregator) -> Self {
101 Self {
102 version: SESSION_VERSION,
103 population: Vec::new(),
104 generation: 0,
105 evaluations_requested: 0,
106 responses_received: 0,
107 skipped: 0,
108 aggregator,
109 request_history: Vec::new(),
110 metadata: HashMap::new(),
111 next_candidate_id: 0,
112 }
113 }
114
115 pub fn with_population(population: Vec<Candidate<G>>, aggregator: FitnessAggregator) -> Self {
117 let next_id = population.iter().map(|c| c.id.0).max().unwrap_or(0) + 1;
118 Self {
119 version: SESSION_VERSION,
120 population,
121 generation: 0,
122 evaluations_requested: 0,
123 responses_received: 0,
124 skipped: 0,
125 aggregator,
126 request_history: Vec::new(),
127 metadata: HashMap::new(),
128 next_candidate_id: next_id,
129 }
130 }
131
132 pub fn next_id(&mut self) -> CandidateId {
134 let id = CandidateId(self.next_candidate_id);
135 self.next_candidate_id += 1;
136 id
137 }
138
139 pub fn add_candidate(&mut self, genome: G) -> CandidateId {
141 let id = self.next_id();
142 let candidate = Candidate::with_generation(id, genome, self.generation);
143 self.population.push(candidate);
144 id
145 }
146
147 pub fn get_candidate(&self, id: CandidateId) -> Option<&Candidate<G>> {
149 self.population.iter().find(|c| c.id == id)
150 }
151
152 pub fn get_candidate_mut(&mut self, id: CandidateId) -> Option<&mut Candidate<G>> {
154 self.population.iter_mut().find(|c| c.id == id)
155 }
156
157 pub fn unevaluated_candidates(&self) -> Vec<&Candidate<G>> {
159 self.population
160 .iter()
161 .filter(|c| !c.is_evaluated())
162 .collect()
163 }
164
165 pub fn ranked_candidates(&self) -> Vec<&Candidate<G>> {
167 let mut candidates: Vec<_> = self.population.iter().collect();
168 candidates.sort_by(|a, b| {
169 let fa = a.fitness_estimate.unwrap_or(f64::NEG_INFINITY);
170 let fb = b.fitness_estimate.unwrap_or(f64::NEG_INFINITY);
171 fb.partial_cmp(&fa).unwrap_or(std::cmp::Ordering::Equal)
172 });
173 candidates
174 }
175
176 pub fn best_candidate(&self) -> Option<&Candidate<G>> {
178 self.population
179 .iter()
180 .filter(|c| c.fitness_estimate.is_some())
181 .max_by(|a, b| {
182 let fa = a.fitness_estimate.unwrap();
183 let fb = b.fitness_estimate.unwrap();
184 fa.partial_cmp(&fb).unwrap_or(std::cmp::Ordering::Equal)
185 })
186 }
187
188 pub fn coverage_stats(&self) -> CoverageStats {
190 if self.population.is_empty() {
191 return CoverageStats::default();
192 }
193
194 let eval_counts: Vec<usize> = self.population.iter().map(|c| c.evaluation_count).collect();
195
196 let evaluated = eval_counts.iter().filter(|&&c| c > 0).count();
197 let total_evals: usize = eval_counts.iter().sum();
198
199 CoverageStats {
200 coverage: evaluated as f64 / self.population.len() as f64,
201 avg_evaluations: total_evals as f64 / self.population.len() as f64,
202 min_evaluations: eval_counts.iter().copied().min().unwrap_or(0),
203 max_evaluations: eval_counts.iter().copied().max().unwrap_or(0),
204 unevaluated_count: self.population.len() - evaluated,
205 population_size: self.population.len(),
206 }
207 }
208
209 pub fn record_request<GG: EvolutionaryGenome>(&mut self, request: &EvaluationRequest<GG>) {
211 self.evaluations_requested += 1;
212
213 let serialized = SerializedRequest {
214 request_type: match request {
215 EvaluationRequest::RateCandidates { .. } => "rating".to_string(),
216 EvaluationRequest::PairwiseComparison { .. } => "pairwise".to_string(),
217 EvaluationRequest::BatchSelection { .. } => "batch".to_string(),
218 },
219 candidate_ids: request.candidate_ids(),
220 generation: self.generation,
221 was_skipped: false,
222 };
223
224 const MAX_HISTORY: usize = 1000;
226 if self.request_history.len() >= MAX_HISTORY {
227 self.request_history.remove(0);
228 }
229 self.request_history.push(serialized);
230 }
231
232 pub fn record_response(&mut self, was_skipped: bool) {
234 if was_skipped {
235 self.skipped += 1;
236 if let Some(last) = self.request_history.last_mut() {
237 last.was_skipped = true;
238 }
239 } else {
240 self.responses_received += 1;
241 }
242 }
243
244 pub fn advance_generation(&mut self) {
246 self.generation += 1;
247 self.aggregator.set_generation(self.generation);
248 }
249
250 pub fn update_fitness(&mut self, id: CandidateId, fitness: f64) {
252 if let Some(candidate) = self.get_candidate_mut(id) {
253 candidate.set_fitness(fitness);
254 candidate.record_evaluation();
255 }
256 }
257
258 pub fn update_fitness_with_uncertainty(&mut self, id: CandidateId, estimate: FitnessEstimate) {
260 if let Some(candidate) = self.get_candidate_mut(id) {
261 candidate.set_fitness_with_uncertainty(estimate);
262 candidate.record_evaluation();
263 }
264 }
265
266 pub fn sync_fitness_estimates(&mut self) {
271 for candidate in &mut self.population {
272 if let Some(estimate) = self.aggregator.get_fitness_estimate(&candidate.id) {
273 candidate.fitness_estimate = Some(estimate.mean);
274 candidate.fitness_with_uncertainty = Some(estimate);
275 }
276 }
277 }
278
279 pub fn all_fitness_estimates(&self) -> Vec<(CandidateId, FitnessEstimate)> {
283 self.population
284 .iter()
285 .filter_map(|c| {
286 self.aggregator
287 .get_fitness_estimate(&c.id)
288 .map(|e| (c.id, e))
289 })
290 .collect()
291 }
292
293 pub fn candidates_by_uncertainty(&self) -> Vec<&Candidate<G>> {
297 let mut candidates: Vec<_> = self.population.iter().collect();
298 candidates.sort_by(|a, b| {
299 let var_a = self
300 .aggregator
301 .get_fitness_estimate(&a.id)
302 .map(|e| e.variance)
303 .unwrap_or(f64::INFINITY);
304 let var_b = self
305 .aggregator
306 .get_fitness_estimate(&b.id)
307 .map(|e| e.variance)
308 .unwrap_or(f64::INFINITY);
309 var_b
311 .partial_cmp(&var_a)
312 .unwrap_or(std::cmp::Ordering::Equal)
313 });
314 candidates
315 }
316
317 pub fn average_uncertainty(&self) -> f64 {
319 let estimates: Vec<_> = self
320 .population
321 .iter()
322 .filter_map(|c| self.aggregator.get_fitness_estimate(&c.id))
323 .collect();
324
325 if estimates.is_empty() {
326 return f64::INFINITY;
327 }
328
329 let total_variance: f64 = estimates
330 .iter()
331 .map(|e| {
332 if e.variance.is_finite() {
333 e.variance
334 } else {
335 1e6 }
337 })
338 .sum();
339
340 total_variance / estimates.len() as f64
341 }
342
343 pub fn replace_population(&mut self, new_population: Vec<Candidate<G>>) {
345 let max_id = new_population.iter().map(|c| c.id.0).max().unwrap_or(0);
346 self.next_candidate_id = max_id + 1;
347 self.population = new_population;
348 }
349
350 pub fn set_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
352 self.metadata.insert(key.into(), value.into());
353 }
354
355 pub fn get_metadata(&self, key: &str) -> Option<&String> {
357 self.metadata.get(key)
358 }
359
360 pub fn response_rate(&self) -> f64 {
362 if self.evaluations_requested > 0 {
363 self.responses_received as f64 / self.evaluations_requested as f64
364 } else {
365 0.0
366 }
367 }
368
369 pub fn skip_rate(&self) -> f64 {
371 if self.evaluations_requested > 0 {
372 self.skipped as f64 / self.evaluations_requested as f64
373 } else {
374 0.0
375 }
376 }
377}
378
#[cfg(feature = "checkpoint")]
impl<G> InteractiveSession<G>
where
    G: EvolutionaryGenome + Serialize + for<'de> Deserialize<'de>,
{
    /// Writes the session to `path` as pretty-printed JSON, creating or truncating the file.
    ///
    /// # Errors
    ///
    /// Returns the converted I/O error if the file cannot be created or the buffered data
    /// cannot be flushed to it. Returns `CheckpointError::Serialization` if serde_json fails.
    pub fn save(&self, path: &Path) -> Result<(), CheckpointError> {
        let file = File::create(path)?;
        let mut writer = BufWriter::new(file);
        serde_json::to_writer_pretty(&mut writer, self).map_err(|e| {
            CheckpointError::Serialization(format!("Failed to serialize session: {}", e))
        })?;
        // Flush explicitly. Dropping a BufWriter also flushes, but it silently discards any
        // error, which would report success for a truncated checkpoint file.
        std::io::Write::flush(&mut writer)?;
        Ok(())
    }

    /// Reads a session previously written by [`Self::save`].
    ///
    /// # Errors
    ///
    /// Returns the converted I/O error if the file cannot be opened, and
    /// `CheckpointError::Deserialization` if its contents do not parse as a session.
    /// Returns `CheckpointError::VersionTooNew` if the stored `version` is greater than
    /// [`SESSION_VERSION`].
    pub fn load(path: &Path) -> Result<Self, CheckpointError> {
        let file = File::open(path)?;
        let reader = BufReader::new(file);
        let session: Self = serde_json::from_reader(reader).map_err(|e| {
            CheckpointError::Deserialization(format!("Failed to deserialize session: {}", e))
        })?;

        // Refuse data written by a newer format we may not understand.
        if session.version > SESSION_VERSION {
            return Err(CheckpointError::VersionTooNew(session.version));
        }

        Ok(session)
    }
}
411
412impl<G> InteractiveSession<G>
413where
414 G: EvolutionaryGenome + Serialize + for<'de> Deserialize<'de>,
415{
416 pub fn to_json(&self) -> Result<String, CheckpointError> {
418 serde_json::to_string_pretty(self).map_err(|e| {
419 CheckpointError::Serialization(format!("Failed to serialize session: {}", e))
420 })
421 }
422
423 pub fn from_json(json: &str) -> Result<Self, CheckpointError> {
425 let session: Self = serde_json::from_str(json).map_err(|e| {
426 CheckpointError::Deserialization(format!("Failed to deserialize session: {}", e))
427 })?;
428
429 if session.version > SESSION_VERSION {
431 return Err(CheckpointError::VersionTooNew(session.version));
432 }
433
434 Ok(session)
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;
    use crate::genome::real_vector::RealVector;
    use crate::interactive::aggregation::AggregationModel;

    /// Empty session with the default aggregation model.
    fn empty_session() -> InteractiveSession<RealVector> {
        InteractiveSession::new(FitnessAggregator::new(AggregationModel::default()))
    }

    /// Empty session whose aggregator uses a serializable direct-rating model.
    fn direct_rating_session() -> InteractiveSession<RealVector> {
        InteractiveSession::new(FitnessAggregator::new(AggregationModel::DirectRating {
            default_rating: 5.0,
        }))
    }

    #[test]
    fn test_session_creation() {
        let session = empty_session();

        assert_eq!(session.generation, 0);
        assert!(session.population.is_empty());
        assert_eq!(session.evaluations_requested, 0);
    }

    #[test]
    fn test_add_candidate() {
        let mut session = empty_session();

        let genome = RealVector::new(vec![1.0, 2.0, 3.0]);
        let id = session.add_candidate(genome);

        assert_eq!(id, CandidateId(0));
        assert_eq!(session.population.len(), 1);
        assert_eq!(session.get_candidate(id).unwrap().birth_generation, 0);
    }

    #[test]
    fn test_coverage_stats() {
        let mut session = empty_session();

        for i in 0..4 {
            session.add_candidate(RealVector::new(vec![i as f64]));
        }

        // Candidate 0 evaluated once, candidate 1 twice, candidates 2 and 3 never.
        session.population[0].record_evaluation();
        session.population[1].record_evaluation();
        session.population[1].record_evaluation();

        let stats = session.coverage_stats();

        assert_eq!(stats.population_size, 4);
        assert_eq!(stats.coverage, 0.5);
        assert_eq!(stats.unevaluated_count, 2);
        assert_eq!(stats.min_evaluations, 0);
        assert_eq!(stats.max_evaluations, 2);
    }

    #[test]
    fn test_ranked_candidates() {
        let mut session = empty_session();

        for i in 0..3 {
            let id = session.add_candidate(RealVector::new(vec![i as f64]));
            session.update_fitness(id, i as f64 * 10.0);
        }

        let ranked = session.ranked_candidates();
        assert_eq!(ranked[0].fitness_estimate, Some(20.0));
        assert_eq!(ranked[1].fitness_estimate, Some(10.0));
        assert_eq!(ranked[2].fitness_estimate, Some(0.0));
    }

    #[test]
    fn test_best_candidate() {
        let mut session = empty_session();
        assert!(session.best_candidate().is_none());

        let mut last_id = CandidateId(0);
        for i in 0..3 {
            let id = session.add_candidate(RealVector::new(vec![i as f64]));
            session.update_fitness(id, i as f64 * 10.0);
            last_id = id;
        }

        let best = session.best_candidate().expect("rated candidates exist");
        assert_eq!(best.id, last_id);
        assert_eq!(best.fitness_estimate, Some(20.0));
    }

    #[test]
    fn test_advance_generation() {
        let mut session = empty_session();

        session.advance_generation();
        assert_eq!(session.generation, 1);

        let id = session.add_candidate(RealVector::new(vec![1.0]));
        assert_eq!(session.get_candidate(id).unwrap().birth_generation, 1);
    }

    #[test]
    fn test_response_tracking() {
        let mut session = empty_session();

        let c1: Candidate<RealVector> = Candidate::new(CandidateId(0), RealVector::new(vec![1.0]));
        let request = EvaluationRequest::rate(vec![c1]);

        session.record_request(&request);
        session.record_response(false);

        session.record_request(&request);
        session.record_response(true);

        assert_eq!(session.evaluations_requested, 2);
        assert_eq!(session.responses_received, 1);
        assert_eq!(session.skipped, 1);
        assert_eq!(session.response_rate(), 0.5);
        assert_eq!(session.skip_rate(), 0.5);

        // Only the request that was skipped is flagged in the history.
        assert_eq!(session.request_history.len(), 2);
        assert_eq!(session.request_history[0].request_type, "rating");
        assert!(!session.request_history[0].was_skipped);
        assert!(session.request_history[1].was_skipped);
    }

    #[test]
    fn test_metadata() {
        let mut session = empty_session();

        session.set_metadata("experiment", "test_run");
        session.set_metadata("user", "alice");

        assert_eq!(
            session.get_metadata("experiment"),
            Some(&"test_run".to_string())
        );
        assert_eq!(session.get_metadata("user"), Some(&"alice".to_string()));
        assert_eq!(session.get_metadata("missing"), None);
    }

    #[test]
    fn test_session_serialization() {
        let mut session = direct_rating_session();

        session.add_candidate(RealVector::new(vec![1.0, 2.0]));
        session.add_candidate(RealVector::new(vec![3.0, 4.0]));
        session.set_metadata("test", "value");

        let json = serde_json::to_string(&session).expect("Failed to serialize");

        let loaded: InteractiveSession<RealVector> =
            serde_json::from_str(&json).expect("Failed to deserialize");

        assert_eq!(loaded.population.len(), 2);
        assert_eq!(loaded.get_metadata("test"), Some(&"value".to_string()));
    }

    #[test]
    fn test_json_round_trip_and_version_guard() {
        let mut session = direct_rating_session();
        session.add_candidate(RealVector::new(vec![1.0]));

        let json = match session.to_json() {
            Ok(json) => json,
            Err(_) => panic!("to_json failed"),
        };
        match InteractiveSession::<RealVector>::from_json(&json) {
            Ok(loaded) => assert_eq!(loaded.population.len(), 1),
            Err(_) => panic!("from_json failed on a freshly serialized session"),
        }

        // A session claiming a newer format must be rejected.
        session.version = SESSION_VERSION + 1;
        let json = serde_json::to_string(&session).expect("Failed to serialize");
        match InteractiveSession::<RealVector>::from_json(&json) {
            Err(CheckpointError::VersionTooNew(version)) => {
                assert_eq!(version, SESSION_VERSION + 1)
            }
            _ => panic!("expected CheckpointError::VersionTooNew"),
        }
    }
}