Skip to main content

fugue_evo/interactive/
aggregation.rs

//! Fitness aggregation models for interactive evaluation
//!
//! This module provides statistical models that convert user feedback
//! (ratings, pairwise comparisons, batch selections) into fitness values
//! suitable for evolution.
//!
//! # Available Models
//!
//! - **DirectRating**: simple average of user ratings
//! - **Elo**: classic Elo rating system driven by pairwise comparisons
//! - **BradleyTerry**: maximum-likelihood estimation from pairwise data
//!   (plus the legacy iterative **BradleyTerrySimple** variant)
//! - **ImplicitRanking**: bonus/penalty scoring from batch selections
//!
//! # Uncertainty Quantification
//!
//! All models support uncertainty estimation via `get_fitness_estimate()`,
//! which returns a `FitnessEstimate` with a variance and confidence intervals.

18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20
21use super::bradley_terry::{BradleyTerryModel, BradleyTerryOptimizer};
22use super::evaluator::{CandidateId, EvaluationResponse};
23use super::uncertainty::FitnessEstimate;
24
25/// Aggregation model for converting user feedback to fitness
26#[derive(Clone, Debug, Serialize, Deserialize)]
27pub enum AggregationModel {
28    /// Direct rating average
29    ///
30    /// Simply averages all ratings received for each candidate.
31    /// Uses default_rating for candidates with no ratings.
32    DirectRating {
33        /// Default rating for unevaluated candidates
34        default_rating: f64,
35    },
36
37    /// Elo rating system
38    ///
39    /// Classic chess-style rating from pairwise comparisons.
40    /// Good for transitive preference modeling.
41    Elo {
42        /// Initial rating for new candidates
43        initial_rating: f64,
44        /// K-factor controlling rating volatility
45        k_factor: f64,
46    },
47
48    /// Bradley-Terry model
49    ///
50    /// Maximum likelihood estimation for pairwise comparison data.
51    /// Provides more statistically principled estimates than Elo.
52    /// Now supports proper MLE with Newton-Raphson or MM algorithms.
53    BradleyTerry {
54        /// Initial strength parameter
55        initial_strength: f64,
56        /// Optimizer configuration (Newton-Raphson or MM)
57        #[serde(default)]
58        optimizer: BradleyTerryOptimizer,
59    },
60
61    /// Legacy Bradley-Terry model (for backward compatibility)
62    ///
63    /// Uses the simplified iterative MM approach from earlier versions.
64    #[serde(alias = "BradleyTerryLegacy")]
65    BradleyTerrySimple {
66        /// Initial strength parameter
67        initial_strength: f64,
68        /// Learning rate for iterative updates
69        learning_rate: f64,
70        /// Number of iterations
71        iterations: usize,
72    },
73
74    /// Implicit ranking from batch selections
75    ///
76    /// Assigns bonuses to selected candidates and penalties to
77    /// non-selected candidates in each batch.
78    ImplicitRanking {
79        /// Fitness bonus for being selected
80        selected_bonus: f64,
81        /// Fitness penalty for not being selected
82        not_selected_penalty: f64,
83        /// Base fitness for all candidates
84        base_fitness: f64,
85    },
86}
87
88impl Default for AggregationModel {
89    fn default() -> Self {
90        Self::DirectRating {
91            default_rating: 5.0,
92        }
93    }
94}
95
96/// Statistics tracked for each candidate
97#[derive(Clone, Debug, Default, Serialize, Deserialize)]
98pub struct CandidateStats {
99    /// Sum of all ratings received
100    pub rating_sum: f64,
101    /// Sum of squared ratings (for variance calculation)
102    #[serde(default)]
103    pub rating_sum_squares: f64,
104    /// Count of ratings received
105    pub rating_count: usize,
106    /// Current model-based score (Elo, Bradley-Terry strength, etc.)
107    pub model_score: f64,
108    /// Variance of the model score (for uncertainty quantification)
109    #[serde(default = "default_variance")]
110    pub model_variance: f64,
111    /// Number of wins in pairwise comparisons
112    pub wins: usize,
113    /// Number of losses in pairwise comparisons
114    pub losses: usize,
115    /// Number of ties in pairwise comparisons
116    pub ties: usize,
117    /// Times selected in batch selection
118    pub times_selected: usize,
119    /// Times presented but not selected
120    pub times_passed: usize,
121}
122
/// Serde fallback for `CandidateStats::model_variance`.
///
/// Positive infinity marks a score that has not been estimated from data yet.
fn default_variance() -> f64 {
    let unknown = f64::INFINITY;
    unknown
}
126
127impl CandidateStats {
128    /// Create new stats with the given initial model score
129    pub fn new(initial_score: f64) -> Self {
130        Self {
131            model_score: initial_score,
132            model_variance: f64::INFINITY,
133            ..Default::default()
134        }
135    }
136
137    /// Get the average rating (or None if no ratings)
138    pub fn average_rating(&self) -> Option<f64> {
139        if self.rating_count > 0 {
140            Some(self.rating_sum / self.rating_count as f64)
141        } else {
142            None
143        }
144    }
145
146    /// Get the sample variance of ratings
147    pub fn rating_variance(&self) -> Option<f64> {
148        if self.rating_count < 2 {
149            return None;
150        }
151        let n = self.rating_count as f64;
152        let mean = self.rating_sum / n;
153        // Var = E[X²] - E[X]²
154        let var = (self.rating_sum_squares / n) - (mean * mean);
155        // Convert to sample variance (Bessel's correction)
156        Some(var * n / (n - 1.0))
157    }
158
159    /// Get the variance of the mean (standard error squared)
160    pub fn rating_variance_of_mean(&self) -> Option<f64> {
161        self.rating_variance()
162            .map(|var| var / self.rating_count as f64)
163    }
164
165    /// Get total number of comparisons
166    pub fn total_comparisons(&self) -> usize {
167        self.wins + self.losses + self.ties
168    }
169
170    /// Get win rate (0.0 to 1.0)
171    pub fn win_rate(&self) -> Option<f64> {
172        let total = self.total_comparisons();
173        if total > 0 {
174            Some(self.wins as f64 / total as f64)
175        } else {
176            None
177        }
178    }
179
180    /// Get selection rate (0.0 to 1.0)
181    pub fn selection_rate(&self) -> Option<f64> {
182        let total = self.times_selected + self.times_passed;
183        if total > 0 {
184            Some(self.times_selected as f64 / total as f64)
185        } else {
186            None
187        }
188    }
189}
190
191/// Record of a pairwise comparison
192#[derive(Clone, Debug, Serialize, Deserialize)]
193pub struct ComparisonRecord {
194    /// Winner's ID
195    pub winner: CandidateId,
196    /// Loser's ID
197    pub loser: CandidateId,
198    /// Generation when comparison occurred
199    pub generation: usize,
200}
201
202/// Aggregates partial/incremental feedback into fitness estimates
203#[derive(Clone, Debug, Serialize, Deserialize)]
204pub struct FitnessAggregator {
205    /// The aggregation model to use
206    model: AggregationModel,
207    /// Per-candidate statistics
208    candidate_stats: HashMap<CandidateId, CandidateStats>,
209    /// History of pairwise comparisons (for Bradley-Terry updates)
210    comparisons: Vec<ComparisonRecord>,
211    /// Current generation
212    current_generation: usize,
213}
214
215impl FitnessAggregator {
216    /// Create a new aggregator with the given model
217    pub fn new(model: AggregationModel) -> Self {
218        Self {
219            model,
220            candidate_stats: HashMap::new(),
221            comparisons: Vec::new(),
222            current_generation: 0,
223        }
224    }
225
226    /// Get the aggregation model
227    pub fn model(&self) -> &AggregationModel {
228        &self.model
229    }
230
231    /// Set the current generation
232    pub fn set_generation(&mut self, generation: usize) {
233        self.current_generation = generation;
234    }
235
236    /// Ensure a candidate has stats initialized
237    fn ensure_stats(&mut self, id: CandidateId) {
238        if !self.candidate_stats.contains_key(&id) {
239            let initial_score = match &self.model {
240                AggregationModel::DirectRating { default_rating } => *default_rating,
241                AggregationModel::Elo { initial_rating, .. } => *initial_rating,
242                AggregationModel::BradleyTerry {
243                    initial_strength, ..
244                } => *initial_strength,
245                AggregationModel::BradleyTerrySimple {
246                    initial_strength, ..
247                } => *initial_strength,
248                AggregationModel::ImplicitRanking { base_fitness, .. } => *base_fitness,
249            };
250            self.candidate_stats
251                .insert(id, CandidateStats::new(initial_score));
252        }
253    }
254
255    /// Get stats for a candidate
256    pub fn get_stats(&self, id: &CandidateId) -> Option<&CandidateStats> {
257        self.candidate_stats.get(id)
258    }
259
260    /// Get current fitness estimate for a candidate (point estimate only)
261    ///
262    /// For uncertainty information, use `get_fitness_estimate()` instead.
263    pub fn get_fitness(&self, id: &CandidateId) -> Option<f64> {
264        let stats = self.candidate_stats.get(id)?;
265
266        Some(match &self.model {
267            AggregationModel::DirectRating { default_rating } => {
268                stats.average_rating().unwrap_or(*default_rating)
269            }
270            AggregationModel::Elo { .. } => stats.model_score,
271            AggregationModel::BradleyTerry { .. } => stats.model_score,
272            AggregationModel::BradleyTerrySimple { .. } => stats.model_score,
273            AggregationModel::ImplicitRanking { .. } => {
274                // Score is base + cumulative bonuses/penalties
275                stats.model_score
276            }
277        })
278    }
279
280    /// Get fitness estimate with uncertainty quantification
281    ///
282    /// Returns a `FitnessEstimate` containing the point estimate, variance,
283    /// and confidence intervals.
284    pub fn get_fitness_estimate(&self, id: &CandidateId) -> Option<FitnessEstimate> {
285        let stats = self.candidate_stats.get(id)?;
286
287        Some(match &self.model {
288            AggregationModel::DirectRating { default_rating } => {
289                if stats.rating_count == 0 {
290                    FitnessEstimate::uninformative(*default_rating)
291                } else {
292                    let mean = stats.rating_sum / stats.rating_count as f64;
293                    let variance = stats.rating_variance_of_mean().unwrap_or(f64::INFINITY);
294                    FitnessEstimate::new(mean, variance, stats.rating_count)
295                }
296            }
297            AggregationModel::Elo { k_factor, .. } => {
298                // Elo variance approximation based on K-factor and game count
299                let n_games = stats.total_comparisons();
300                let variance = if n_games == 0 {
301                    f64::INFINITY
302                } else {
303                    // Approximate variance: decreases with games, proportional to K²
304                    let base_var = k_factor * k_factor * 0.25; // Bernoulli variance factor
305                    base_var / n_games as f64
306                };
307                FitnessEstimate::new(stats.model_score, variance, n_games)
308            }
309            AggregationModel::BradleyTerry { .. } | AggregationModel::BradleyTerrySimple { .. } => {
310                // Use stored variance from MLE computation
311                let n_comparisons = stats.total_comparisons();
312                let variance = if stats.model_variance.is_finite() {
313                    stats.model_variance
314                } else if n_comparisons == 0 {
315                    f64::INFINITY
316                } else {
317                    // Fallback: approximate variance
318                    1.0 / n_comparisons as f64
319                };
320                FitnessEstimate::new(stats.model_score, variance, n_comparisons)
321            }
322            AggregationModel::ImplicitRanking { .. } => {
323                // Binomial variance on selection rate
324                let n = stats.times_selected + stats.times_passed;
325                if n == 0 {
326                    FitnessEstimate::uninformative(stats.model_score)
327                } else {
328                    let p = stats.times_selected as f64 / n as f64;
329                    let variance = p * (1.0 - p) / n as f64;
330                    FitnessEstimate::new(stats.model_score, variance, n)
331                }
332            }
333        })
334    }
335
336    /// Get access to comparison records (for Bradley-Terry MLE)
337    pub fn comparisons(&self) -> &[ComparisonRecord] {
338        &self.comparisons
339    }
340
341    /// Record a rating for a candidate
342    pub fn record_rating(&mut self, id: CandidateId, rating: f64) {
343        self.ensure_stats(id);
344        if let Some(stats) = self.candidate_stats.get_mut(&id) {
345            stats.rating_sum += rating;
346            stats.rating_sum_squares += rating * rating;
347            stats.rating_count += 1;
348        }
349    }
350
351    /// Record a pairwise comparison result
352    pub fn record_comparison(&mut self, winner: CandidateId, loser: CandidateId) {
353        self.ensure_stats(winner);
354        self.ensure_stats(loser);
355
356        // Update stats
357        if let Some(winner_stats) = self.candidate_stats.get_mut(&winner) {
358            winner_stats.wins += 1;
359        }
360        if let Some(loser_stats) = self.candidate_stats.get_mut(&loser) {
361            loser_stats.losses += 1;
362        }
363
364        // Record comparison for history
365        self.comparisons.push(ComparisonRecord {
366            winner,
367            loser,
368            generation: self.current_generation,
369        });
370
371        // Update model scores
372        match &self.model {
373            AggregationModel::Elo { k_factor, .. } => {
374                self.update_elo(winner, loser, *k_factor);
375            }
376            AggregationModel::BradleyTerry { .. } => {
377                // Bradley-Terry updates are batched via recompute_all()
378            }
379            _ => {}
380        }
381    }
382
383    /// Record a tie in pairwise comparison
384    pub fn record_tie(&mut self, id_a: CandidateId, id_b: CandidateId) {
385        self.ensure_stats(id_a);
386        self.ensure_stats(id_b);
387
388        if let Some(stats) = self.candidate_stats.get_mut(&id_a) {
389            stats.ties += 1;
390        }
391        if let Some(stats) = self.candidate_stats.get_mut(&id_b) {
392            stats.ties += 1;
393        }
394
395        // For Elo, treat tie as half-win each
396        if let AggregationModel::Elo { k_factor, .. } = &self.model {
397            self.update_elo_draw(id_a, id_b, *k_factor);
398        }
399    }
400
401    /// Record batch selection results
402    pub fn record_batch_selection(
403        &mut self,
404        selected: &[CandidateId],
405        not_selected: &[CandidateId],
406    ) {
407        if let AggregationModel::ImplicitRanking {
408            selected_bonus,
409            not_selected_penalty,
410            ..
411        } = &self.model
412        {
413            let bonus = *selected_bonus;
414            let penalty = *not_selected_penalty;
415
416            for &id in selected {
417                self.ensure_stats(id);
418                if let Some(stats) = self.candidate_stats.get_mut(&id) {
419                    stats.times_selected += 1;
420                    stats.model_score += bonus;
421                }
422            }
423
424            for &id in not_selected {
425                self.ensure_stats(id);
426                if let Some(stats) = self.candidate_stats.get_mut(&id) {
427                    stats.times_passed += 1;
428                    stats.model_score -= penalty;
429                }
430            }
431        } else {
432            // For other models, just track selection counts
433            for &id in selected {
434                self.ensure_stats(id);
435                if let Some(stats) = self.candidate_stats.get_mut(&id) {
436                    stats.times_selected += 1;
437                }
438            }
439            for &id in not_selected {
440                self.ensure_stats(id);
441                if let Some(stats) = self.candidate_stats.get_mut(&id) {
442                    stats.times_passed += 1;
443                }
444            }
445        }
446    }
447
448    /// Update Elo ratings after a comparison
449    fn update_elo(&mut self, winner: CandidateId, loser: CandidateId, k: f64) {
450        let winner_rating = self
451            .candidate_stats
452            .get(&winner)
453            .map(|s| s.model_score)
454            .unwrap_or(1500.0);
455        let loser_rating = self
456            .candidate_stats
457            .get(&loser)
458            .map(|s| s.model_score)
459            .unwrap_or(1500.0);
460
461        // Expected scores
462        let exp_winner = 1.0 / (1.0 + 10.0_f64.powf((loser_rating - winner_rating) / 400.0));
463        let exp_loser = 1.0 - exp_winner;
464
465        // Update ratings
466        if let Some(stats) = self.candidate_stats.get_mut(&winner) {
467            stats.model_score += k * (1.0 - exp_winner);
468        }
469        if let Some(stats) = self.candidate_stats.get_mut(&loser) {
470            stats.model_score += k * (0.0 - exp_loser);
471        }
472    }
473
474    /// Update Elo ratings after a draw
475    fn update_elo_draw(&mut self, id_a: CandidateId, id_b: CandidateId, k: f64) {
476        let rating_a = self
477            .candidate_stats
478            .get(&id_a)
479            .map(|s| s.model_score)
480            .unwrap_or(1500.0);
481        let rating_b = self
482            .candidate_stats
483            .get(&id_b)
484            .map(|s| s.model_score)
485            .unwrap_or(1500.0);
486
487        // Expected scores
488        let exp_a = 1.0 / (1.0 + 10.0_f64.powf((rating_b - rating_a) / 400.0));
489        let exp_b = 1.0 - exp_a;
490
491        // Update ratings (actual = 0.5 for draw)
492        if let Some(stats) = self.candidate_stats.get_mut(&id_a) {
493            stats.model_score += k * (0.5 - exp_a);
494        }
495        if let Some(stats) = self.candidate_stats.get_mut(&id_b) {
496            stats.model_score += k * (0.5 - exp_b);
497        }
498    }
499
500    /// Recompute all fitness estimates from comparison history
501    ///
502    /// This is useful for Bradley-Terry model which uses batch MLE,
503    /// or after loading a session from checkpoint.
504    pub fn recompute_all(&mut self) -> HashMap<CandidateId, f64> {
505        match &self.model {
506            AggregationModel::BradleyTerry { optimizer, .. } => {
507                self.recompute_bradley_terry_mle(optimizer.clone());
508            }
509            AggregationModel::BradleyTerrySimple {
510                initial_strength,
511                learning_rate,
512                iterations,
513            } => {
514                self.recompute_bradley_terry_simple(*initial_strength, *learning_rate, *iterations);
515            }
516            _ => {}
517        }
518
519        // Return current fitness estimates
520        self.candidate_stats
521            .keys()
522            .filter_map(|id| self.get_fitness(id).map(|f| (*id, f)))
523            .collect()
524    }
525
526    /// Recompute Bradley-Terry using proper MLE (Newton-Raphson or MM)
527    fn recompute_bradley_terry_mle(&mut self, optimizer: BradleyTerryOptimizer) {
528        let ids: Vec<CandidateId> = self.candidate_stats.keys().copied().collect();
529        if ids.is_empty() || self.comparisons.is_empty() {
530            return;
531        }
532
533        let model = BradleyTerryModel::new(optimizer);
534        let result = model.fit(&self.comparisons, &ids);
535
536        // Update stats with MLE results
537        for (&id, &strength) in &result.strengths {
538            if let Some(stats) = self.candidate_stats.get_mut(&id) {
539                stats.model_score = strength;
540
541                // Update variance from covariance matrix
542                if let Some(&idx) = result.id_to_index.get(&id) {
543                    if idx < result.covariance.nrows() {
544                        stats.model_variance = result.covariance[(idx, idx)];
545                    }
546                }
547            }
548        }
549    }
550
551    /// Recompute Bradley-Terry using simplified iterative MM (legacy)
552    fn recompute_bradley_terry_simple(
553        &mut self,
554        initial_strength: f64,
555        learning_rate: f64,
556        iterations: usize,
557    ) {
558        // Initialize strengths
559        let ids: Vec<CandidateId> = self.candidate_stats.keys().copied().collect();
560        for &id in &ids {
561            if let Some(stats) = self.candidate_stats.get_mut(&id) {
562                stats.model_score = initial_strength;
563            }
564        }
565
566        // Iterative MM algorithm for Bradley-Terry
567        for _ in 0..iterations {
568            let mut new_scores: HashMap<CandidateId, f64> = HashMap::new();
569
570            for &id in &ids {
571                let stats = match self.candidate_stats.get(&id) {
572                    Some(s) => s,
573                    None => continue,
574                };
575
576                let wins = stats.wins as f64;
577                if wins == 0.0 {
578                    new_scores.insert(id, stats.model_score);
579                    continue;
580                }
581
582                // Compute denominator: sum of 1/(p_i + p_j) over all comparisons
583                let mut denom = 0.0;
584                for comparison in &self.comparisons {
585                    if comparison.winner == id {
586                        let other_score = self
587                            .candidate_stats
588                            .get(&comparison.loser)
589                            .map(|s| s.model_score)
590                            .unwrap_or(initial_strength);
591                        denom += 1.0 / (stats.model_score + other_score);
592                    } else if comparison.loser == id {
593                        let other_score = self
594                            .candidate_stats
595                            .get(&comparison.winner)
596                            .map(|s| s.model_score)
597                            .unwrap_or(initial_strength);
598                        denom += 1.0 / (stats.model_score + other_score);
599                    }
600                }
601
602                let new_score = if denom > 0.0 {
603                    let raw = wins / denom;
604                    // Smooth update with learning rate
605                    stats.model_score + learning_rate * (raw - stats.model_score)
606                } else {
607                    stats.model_score
608                };
609
610                new_scores.insert(id, new_score.max(0.001)); // Avoid zero strength
611            }
612
613            // Apply new scores
614            for (id, score) in new_scores {
615                if let Some(stats) = self.candidate_stats.get_mut(&id) {
616                    stats.model_score = score;
617                }
618            }
619        }
620    }
621
622    /// Process an evaluation response and return updated fitness values
623    pub fn process_response(&mut self, response: &EvaluationResponse) -> Vec<(CandidateId, f64)> {
624        match response {
625            EvaluationResponse::Ratings(ratings) => {
626                for (id, rating) in ratings {
627                    self.record_rating(*id, *rating);
628                }
629                ratings
630                    .iter()
631                    .filter_map(|(id, _)| self.get_fitness(id).map(|f| (*id, f)))
632                    .collect()
633            }
634            EvaluationResponse::PairwiseWinner(Some(winner)) => {
635                // We need both IDs to record a comparison
636                // For now, just return the winner's fitness
637                self.ensure_stats(*winner);
638                if let Some(f) = self.get_fitness(winner) {
639                    vec![(*winner, f)]
640                } else {
641                    vec![]
642                }
643            }
644            EvaluationResponse::PairwiseWinner(None) => {
645                // Tie - nothing to update without both IDs
646                vec![]
647            }
648            EvaluationResponse::BatchSelected(selected) => {
649                // Update selection counts
650                for id in selected {
651                    self.ensure_stats(*id);
652                    if let Some(stats) = self.candidate_stats.get_mut(id) {
653                        stats.times_selected += 1;
654                        if let AggregationModel::ImplicitRanking { selected_bonus, .. } =
655                            &self.model
656                        {
657                            stats.model_score += *selected_bonus;
658                        }
659                    }
660                }
661                selected
662                    .iter()
663                    .filter_map(|id| self.get_fitness(id).map(|f| (*id, f)))
664                    .collect()
665            }
666            EvaluationResponse::Skip => vec![],
667        }
668    }
669
670    /// Process a pairwise comparison with both candidate IDs
671    pub fn process_pairwise(
672        &mut self,
673        id_a: CandidateId,
674        id_b: CandidateId,
675        winner: Option<CandidateId>,
676    ) -> Vec<(CandidateId, f64)> {
677        match winner {
678            Some(w) if w == id_a => {
679                self.record_comparison(id_a, id_b);
680            }
681            Some(w) if w == id_b => {
682                self.record_comparison(id_b, id_a);
683            }
684            Some(_) => {
685                // Winner ID doesn't match either candidate
686            }
687            None => {
688                self.record_tie(id_a, id_b);
689            }
690        }
691
692        vec![id_a, id_b]
693            .into_iter()
694            .filter_map(|id| self.get_fitness(&id).map(|f| (id, f)))
695            .collect()
696    }
697
698    /// Process batch selection with full context
699    pub fn process_batch_selection(
700        &mut self,
701        all_candidates: &[CandidateId],
702        selected: &[CandidateId],
703    ) -> Vec<(CandidateId, f64)> {
704        let selected_set: std::collections::HashSet<_> = selected.iter().copied().collect();
705        let not_selected: Vec<_> = all_candidates
706            .iter()
707            .copied()
708            .filter(|id| !selected_set.contains(id))
709            .collect();
710
711        self.record_batch_selection(selected, &not_selected);
712
713        all_candidates
714            .iter()
715            .filter_map(|id| self.get_fitness(id).map(|f| (*id, f)))
716            .collect()
717    }
718
719    /// Get all candidate IDs with fitness estimates
720    pub fn all_candidates(&self) -> Vec<CandidateId> {
721        self.candidate_stats.keys().copied().collect()
722    }
723
724    /// Get the number of comparisons recorded
725    pub fn comparison_count(&self) -> usize {
726        self.comparisons.len()
727    }
728
729    /// Clear all recorded data
730    pub fn clear(&mut self) {
731        self.candidate_stats.clear();
732        self.comparisons.clear();
733    }
734}
735
736#[cfg(test)]
737mod tests {
738    use super::*;
739
740    #[test]
741    fn test_direct_rating_aggregation() {
742        let mut agg = FitnessAggregator::new(AggregationModel::DirectRating {
743            default_rating: 5.0,
744        });
745
746        let id = CandidateId(0);
747
748        // Initially should return default
749        agg.ensure_stats(id);
750        assert_eq!(agg.get_fitness(&id), Some(5.0));
751
752        // After rating
753        agg.record_rating(id, 8.0);
754        assert_eq!(agg.get_fitness(&id), Some(8.0));
755
756        // After second rating, should average
757        agg.record_rating(id, 6.0);
758        assert_eq!(agg.get_fitness(&id), Some(7.0));
759    }
760
761    #[test]
762    fn test_elo_rating() {
763        let mut agg = FitnessAggregator::new(AggregationModel::Elo {
764            initial_rating: 1500.0,
765            k_factor: 32.0,
766        });
767
768        let id_a = CandidateId(0);
769        let id_b = CandidateId(1);
770
771        agg.ensure_stats(id_a);
772        agg.ensure_stats(id_b);
773
774        // Initial ratings should be equal
775        assert_eq!(agg.get_fitness(&id_a), Some(1500.0));
776        assert_eq!(agg.get_fitness(&id_b), Some(1500.0));
777
778        // After A beats B
779        agg.record_comparison(id_a, id_b);
780
781        let fitness_a = agg.get_fitness(&id_a).unwrap();
782        let fitness_b = agg.get_fitness(&id_b).unwrap();
783
784        // Winner should gain rating
785        assert!(fitness_a > 1500.0);
786        // Loser should lose rating
787        assert!(fitness_b < 1500.0);
788        // Total rating should be conserved
789        assert!((fitness_a + fitness_b - 3000.0).abs() < 0.01);
790    }
791
792    #[test]
793    fn test_elo_draw() {
794        let mut agg = FitnessAggregator::new(AggregationModel::Elo {
795            initial_rating: 1500.0,
796            k_factor: 32.0,
797        });
798
799        let id_a = CandidateId(0);
800        let id_b = CandidateId(1);
801
802        agg.ensure_stats(id_a);
803        agg.ensure_stats(id_b);
804
805        // After tie between equal players, ratings should stay the same
806        agg.record_tie(id_a, id_b);
807
808        let fitness_a = agg.get_fitness(&id_a).unwrap();
809        let fitness_b = agg.get_fitness(&id_b).unwrap();
810
811        assert!((fitness_a - 1500.0).abs() < 0.01);
812        assert!((fitness_b - 1500.0).abs() < 0.01);
813    }
814
815    #[test]
816    fn test_implicit_ranking() {
817        let mut agg = FitnessAggregator::new(AggregationModel::ImplicitRanking {
818            selected_bonus: 1.0,
819            not_selected_penalty: 0.5,
820            base_fitness: 5.0,
821        });
822
823        let selected = vec![CandidateId(0), CandidateId(1)];
824        let not_selected = vec![CandidateId(2), CandidateId(3)];
825
826        agg.record_batch_selection(&selected, &not_selected);
827
828        // Selected candidates should have bonus
829        assert_eq!(agg.get_fitness(&CandidateId(0)), Some(6.0));
830        assert_eq!(agg.get_fitness(&CandidateId(1)), Some(6.0));
831
832        // Not selected should have penalty
833        assert_eq!(agg.get_fitness(&CandidateId(2)), Some(4.5));
834        assert_eq!(agg.get_fitness(&CandidateId(3)), Some(4.5));
835    }
836
837    #[test]
838    fn test_bradley_terry_simple_recompute() {
839        let mut agg = FitnessAggregator::new(AggregationModel::BradleyTerrySimple {
840            initial_strength: 1.0,
841            learning_rate: 0.5,
842            iterations: 10,
843        });
844
845        // A beats B multiple times, B beats C
846        agg.ensure_stats(CandidateId(0));
847        agg.ensure_stats(CandidateId(1));
848        agg.ensure_stats(CandidateId(2));
849
850        agg.record_comparison(CandidateId(0), CandidateId(1));
851        agg.record_comparison(CandidateId(0), CandidateId(1));
852        agg.record_comparison(CandidateId(1), CandidateId(2));
853
854        let fitness = agg.recompute_all();
855
856        // A should have highest strength
857        assert!(fitness[&CandidateId(0)] > fitness[&CandidateId(1)]);
858        // B should beat C
859        assert!(fitness[&CandidateId(1)] > fitness[&CandidateId(2)]);
860    }
861
862    #[test]
863    fn test_bradley_terry_mle_recompute() {
864        use crate::interactive::bradley_terry::BradleyTerryOptimizer;
865
866        let mut agg = FitnessAggregator::new(AggregationModel::BradleyTerry {
867            initial_strength: 1.0,
868            optimizer: BradleyTerryOptimizer::default(),
869        });
870
871        // A beats B multiple times, B beats C
872        agg.ensure_stats(CandidateId(0));
873        agg.ensure_stats(CandidateId(1));
874        agg.ensure_stats(CandidateId(2));
875
876        agg.record_comparison(CandidateId(0), CandidateId(1));
877        agg.record_comparison(CandidateId(0), CandidateId(1));
878        agg.record_comparison(CandidateId(1), CandidateId(2));
879
880        let fitness = agg.recompute_all();
881
882        // A should have highest strength
883        assert!(fitness[&CandidateId(0)] > fitness[&CandidateId(1)]);
884        // B should beat C
885        assert!(fitness[&CandidateId(1)] > fitness[&CandidateId(2)]);
886
887        // MLE should also provide variance estimates
888        let estimate_a = agg.get_fitness_estimate(&CandidateId(0)).unwrap();
889        assert!(estimate_a.variance.is_finite());
890        assert!(estimate_a.observation_count > 0);
891    }
892
893    #[test]
894    fn test_fitness_estimate_direct_rating() {
895        let mut agg = FitnessAggregator::new(AggregationModel::DirectRating {
896            default_rating: 5.0,
897        });
898
899        let id = CandidateId(0);
900        agg.ensure_stats(id);
901
902        // Initially should be uninformative
903        let estimate = agg.get_fitness_estimate(&id).unwrap();
904        assert_eq!(estimate.mean, 5.0);
905        assert!(estimate.variance.is_infinite());
906
907        // After ratings, should have finite variance
908        agg.record_rating(id, 8.0);
909        agg.record_rating(id, 6.0);
910        agg.record_rating(id, 7.0);
911
912        let estimate = agg.get_fitness_estimate(&id).unwrap();
913        assert_eq!(estimate.mean, 7.0);
914        assert!(estimate.variance.is_finite());
915        assert_eq!(estimate.observation_count, 3);
916    }
917
918    #[test]
919    fn test_candidate_stats() {
920        let mut stats = CandidateStats::new(1500.0);
921
922        // Test rating tracking
923        stats.rating_sum = 24.0;
924        stats.rating_count = 3;
925        assert_eq!(stats.average_rating(), Some(8.0));
926
927        // Test win rate
928        stats.wins = 3;
929        stats.losses = 1;
930        assert_eq!(stats.total_comparisons(), 4);
931        assert_eq!(stats.win_rate(), Some(0.75));
932
933        // Test selection rate
934        stats.times_selected = 2;
935        stats.times_passed = 3;
936        assert_eq!(stats.selection_rate(), Some(0.4));
937    }
938
939    #[test]
940    fn test_process_response_ratings() {
941        let mut agg = FitnessAggregator::new(AggregationModel::DirectRating {
942            default_rating: 5.0,
943        });
944
945        let response =
946            EvaluationResponse::ratings(vec![(CandidateId(0), 8.0), (CandidateId(1), 6.0)]);
947
948        let updated = agg.process_response(&response);
949
950        assert_eq!(updated.len(), 2);
951        assert!(updated
952            .iter()
953            .any(|(id, f)| *id == CandidateId(0) && *f == 8.0));
954        assert!(updated
955            .iter()
956            .any(|(id, f)| *id == CandidateId(1) && *f == 6.0));
957    }
958
959    #[test]
960    fn test_process_batch_selection() {
961        let mut agg = FitnessAggregator::new(AggregationModel::ImplicitRanking {
962            selected_bonus: 1.0,
963            not_selected_penalty: 0.5,
964            base_fitness: 5.0,
965        });
966
967        let all = vec![
968            CandidateId(0),
969            CandidateId(1),
970            CandidateId(2),
971            CandidateId(3),
972        ];
973        let selected = vec![CandidateId(0), CandidateId(2)];
974
975        let updated = agg.process_batch_selection(&all, &selected);
976
977        assert_eq!(updated.len(), 4);
978
979        // Check selected got bonus
980        let fitness_0 = updated
981            .iter()
982            .find(|(id, _)| *id == CandidateId(0))
983            .unwrap()
984            .1;
985        assert_eq!(fitness_0, 6.0);
986
987        // Check not selected got penalty
988        let fitness_1 = updated
989            .iter()
990            .find(|(id, _)| *id == CandidateId(1))
991            .unwrap()
992            .1;
993        assert_eq!(fitness_1, 4.5);
994    }
995}