1use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20
21use super::bradley_terry::{BradleyTerryModel, BradleyTerryOptimizer};
22use super::evaluator::{CandidateId, EvaluationResponse};
23use super::uncertainty::FitnessEstimate;
24
/// Strategy used to turn recorded evaluation feedback into a per-candidate fitness score.
///
/// The selected variant decides the score a newly tracked candidate starts with and
/// which recorded statistics feed the reported fitness.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum AggregationModel {
    /// Fitness is the arithmetic mean of the ratings recorded for a candidate.
    DirectRating {
        /// Fitness reported for a candidate that has no ratings yet.
        default_rating: f64,
    },

    /// Elo-style rating, updated incrementally after every recorded comparison or tie.
    Elo {
        /// Rating assigned to a candidate when it is first tracked.
        initial_rating: f64,
        /// Multiplier applied to the gap between actual and expected outcome on each update.
        k_factor: f64,
    },

    /// Bradley-Terry strengths fitted with `BradleyTerryModel` when the aggregator recomputes.
    BradleyTerry {
        /// Strength assigned to a candidate when it is first tracked.
        initial_strength: f64,
        /// Optimizer configuration handed to `BradleyTerryModel::new`; takes its `Default`
        /// value when the field is missing from serialized data.
        #[serde(default)]
        optimizer: BradleyTerryOptimizer,
    },

    /// Bradley-Terry strengths estimated with a fixed number of damped iterative updates.
    ///
    /// Serialized data may also name this variant `BradleyTerryLegacy`.
    #[serde(alias = "BradleyTerryLegacy")]
    BradleyTerrySimple {
        /// Strength every candidate is reset to before each recompute.
        initial_strength: f64,
        /// Fraction of the distance toward the newly computed strength taken per iteration.
        learning_rate: f64,
        /// Number of update passes performed per recompute.
        iterations: usize,
    },

    /// Score that moves up or down each time a candidate is shown in a batch selection.
    ImplicitRanking {
        /// Amount added to the score each time the candidate is selected.
        selected_bonus: f64,
        /// Amount subtracted from the score each time the candidate is passed over.
        not_selected_penalty: f64,
        /// Score assigned to a candidate when it is first tracked.
        base_fitness: f64,
    },
}
87
88impl Default for AggregationModel {
89 fn default() -> Self {
90 Self::DirectRating {
91 default_rating: 5.0,
92 }
93 }
94}
95
/// Running statistics kept for one candidate across all feedback kinds.
///
/// NOTE(review): the derived `Default` leaves `model_variance` at `0.0`, whereas
/// `CandidateStats::new` and deserialization of a missing field both use infinity.
/// Prefer `CandidateStats::new` when constructing a fresh entry.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct CandidateStats {
    /// Sum of all ratings recorded for the candidate.
    pub rating_sum: f64,
    /// Sum of the squares of all recorded ratings; `0.0` when missing from serialized data.
    #[serde(default)]
    pub rating_sum_squares: f64,
    /// Number of ratings recorded.
    pub rating_count: usize,
    /// Score maintained by the active aggregation model (Elo rating, Bradley-Terry
    /// strength, or implicit-ranking score).
    pub model_score: f64,
    /// Variance associated with `model_score`; infinity when missing from serialized data.
    #[serde(default = "default_variance")]
    pub model_variance: f64,
    /// Comparisons this candidate won.
    pub wins: usize,
    /// Comparisons this candidate lost.
    pub losses: usize,
    /// Comparisons that ended in a tie.
    pub ties: usize,
    /// Times the candidate was picked in a batch selection.
    pub times_selected: usize,
    /// Times the candidate was shown in a batch selection but not picked.
    pub times_passed: usize,
}
122
/// Serde fallback for `CandidateStats::model_variance`: an infinite variance.
fn default_variance() -> f64 {
    const UNKNOWN_VARIANCE: f64 = f64::INFINITY;
    UNKNOWN_VARIANCE
}
126
127impl CandidateStats {
128 pub fn new(initial_score: f64) -> Self {
130 Self {
131 model_score: initial_score,
132 model_variance: f64::INFINITY,
133 ..Default::default()
134 }
135 }
136
137 pub fn average_rating(&self) -> Option<f64> {
139 if self.rating_count > 0 {
140 Some(self.rating_sum / self.rating_count as f64)
141 } else {
142 None
143 }
144 }
145
146 pub fn rating_variance(&self) -> Option<f64> {
148 if self.rating_count < 2 {
149 return None;
150 }
151 let n = self.rating_count as f64;
152 let mean = self.rating_sum / n;
153 let var = (self.rating_sum_squares / n) - (mean * mean);
155 Some(var * n / (n - 1.0))
157 }
158
159 pub fn rating_variance_of_mean(&self) -> Option<f64> {
161 self.rating_variance()
162 .map(|var| var / self.rating_count as f64)
163 }
164
165 pub fn total_comparisons(&self) -> usize {
167 self.wins + self.losses + self.ties
168 }
169
170 pub fn win_rate(&self) -> Option<f64> {
172 let total = self.total_comparisons();
173 if total > 0 {
174 Some(self.wins as f64 / total as f64)
175 } else {
176 None
177 }
178 }
179
180 pub fn selection_rate(&self) -> Option<f64> {
182 let total = self.times_selected + self.times_passed;
183 if total > 0 {
184 Some(self.times_selected as f64 / total as f64)
185 } else {
186 None
187 }
188 }
189}
190
/// One decisive pairwise outcome, kept so comparison-based models can be refitted later.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ComparisonRecord {
    /// Candidate that won the comparison.
    pub winner: CandidateId,
    /// Candidate that lost the comparison.
    pub loser: CandidateId,
    /// Value of the aggregator's current generation when the comparison was recorded.
    pub generation: usize,
}
201
/// Collects evaluation feedback per candidate and converts it into fitness values
/// according to the configured `AggregationModel`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FitnessAggregator {
    /// Model that decides how feedback becomes a fitness value.
    model: AggregationModel,
    /// Running statistics for every candidate seen so far.
    candidate_stats: HashMap<CandidateId, CandidateStats>,
    /// Every decisive comparison recorded so far, in insertion order (ties are not stored).
    comparisons: Vec<ComparisonRecord>,
    /// Generation number stamped onto newly recorded comparisons.
    current_generation: usize,
}
214
215impl FitnessAggregator {
216 pub fn new(model: AggregationModel) -> Self {
218 Self {
219 model,
220 candidate_stats: HashMap::new(),
221 comparisons: Vec::new(),
222 current_generation: 0,
223 }
224 }
225
226 pub fn model(&self) -> &AggregationModel {
228 &self.model
229 }
230
231 pub fn set_generation(&mut self, generation: usize) {
233 self.current_generation = generation;
234 }
235
236 fn ensure_stats(&mut self, id: CandidateId) {
238 if !self.candidate_stats.contains_key(&id) {
239 let initial_score = match &self.model {
240 AggregationModel::DirectRating { default_rating } => *default_rating,
241 AggregationModel::Elo { initial_rating, .. } => *initial_rating,
242 AggregationModel::BradleyTerry {
243 initial_strength, ..
244 } => *initial_strength,
245 AggregationModel::BradleyTerrySimple {
246 initial_strength, ..
247 } => *initial_strength,
248 AggregationModel::ImplicitRanking { base_fitness, .. } => *base_fitness,
249 };
250 self.candidate_stats
251 .insert(id, CandidateStats::new(initial_score));
252 }
253 }
254
255 pub fn get_stats(&self, id: &CandidateId) -> Option<&CandidateStats> {
257 self.candidate_stats.get(id)
258 }
259
260 pub fn get_fitness(&self, id: &CandidateId) -> Option<f64> {
264 let stats = self.candidate_stats.get(id)?;
265
266 Some(match &self.model {
267 AggregationModel::DirectRating { default_rating } => {
268 stats.average_rating().unwrap_or(*default_rating)
269 }
270 AggregationModel::Elo { .. } => stats.model_score,
271 AggregationModel::BradleyTerry { .. } => stats.model_score,
272 AggregationModel::BradleyTerrySimple { .. } => stats.model_score,
273 AggregationModel::ImplicitRanking { .. } => {
274 stats.model_score
276 }
277 })
278 }
279
280 pub fn get_fitness_estimate(&self, id: &CandidateId) -> Option<FitnessEstimate> {
285 let stats = self.candidate_stats.get(id)?;
286
287 Some(match &self.model {
288 AggregationModel::DirectRating { default_rating } => {
289 if stats.rating_count == 0 {
290 FitnessEstimate::uninformative(*default_rating)
291 } else {
292 let mean = stats.rating_sum / stats.rating_count as f64;
293 let variance = stats.rating_variance_of_mean().unwrap_or(f64::INFINITY);
294 FitnessEstimate::new(mean, variance, stats.rating_count)
295 }
296 }
297 AggregationModel::Elo { k_factor, .. } => {
298 let n_games = stats.total_comparisons();
300 let variance = if n_games == 0 {
301 f64::INFINITY
302 } else {
303 let base_var = k_factor * k_factor * 0.25; base_var / n_games as f64
306 };
307 FitnessEstimate::new(stats.model_score, variance, n_games)
308 }
309 AggregationModel::BradleyTerry { .. } | AggregationModel::BradleyTerrySimple { .. } => {
310 let n_comparisons = stats.total_comparisons();
312 let variance = if stats.model_variance.is_finite() {
313 stats.model_variance
314 } else if n_comparisons == 0 {
315 f64::INFINITY
316 } else {
317 1.0 / n_comparisons as f64
319 };
320 FitnessEstimate::new(stats.model_score, variance, n_comparisons)
321 }
322 AggregationModel::ImplicitRanking { .. } => {
323 let n = stats.times_selected + stats.times_passed;
325 if n == 0 {
326 FitnessEstimate::uninformative(stats.model_score)
327 } else {
328 let p = stats.times_selected as f64 / n as f64;
329 let variance = p * (1.0 - p) / n as f64;
330 FitnessEstimate::new(stats.model_score, variance, n)
331 }
332 }
333 })
334 }
335
336 pub fn comparisons(&self) -> &[ComparisonRecord] {
338 &self.comparisons
339 }
340
341 pub fn record_rating(&mut self, id: CandidateId, rating: f64) {
343 self.ensure_stats(id);
344 if let Some(stats) = self.candidate_stats.get_mut(&id) {
345 stats.rating_sum += rating;
346 stats.rating_sum_squares += rating * rating;
347 stats.rating_count += 1;
348 }
349 }
350
351 pub fn record_comparison(&mut self, winner: CandidateId, loser: CandidateId) {
353 self.ensure_stats(winner);
354 self.ensure_stats(loser);
355
356 if let Some(winner_stats) = self.candidate_stats.get_mut(&winner) {
358 winner_stats.wins += 1;
359 }
360 if let Some(loser_stats) = self.candidate_stats.get_mut(&loser) {
361 loser_stats.losses += 1;
362 }
363
364 self.comparisons.push(ComparisonRecord {
366 winner,
367 loser,
368 generation: self.current_generation,
369 });
370
371 match &self.model {
373 AggregationModel::Elo { k_factor, .. } => {
374 self.update_elo(winner, loser, *k_factor);
375 }
376 AggregationModel::BradleyTerry { .. } => {
377 }
379 _ => {}
380 }
381 }
382
383 pub fn record_tie(&mut self, id_a: CandidateId, id_b: CandidateId) {
385 self.ensure_stats(id_a);
386 self.ensure_stats(id_b);
387
388 if let Some(stats) = self.candidate_stats.get_mut(&id_a) {
389 stats.ties += 1;
390 }
391 if let Some(stats) = self.candidate_stats.get_mut(&id_b) {
392 stats.ties += 1;
393 }
394
395 if let AggregationModel::Elo { k_factor, .. } = &self.model {
397 self.update_elo_draw(id_a, id_b, *k_factor);
398 }
399 }
400
401 pub fn record_batch_selection(
403 &mut self,
404 selected: &[CandidateId],
405 not_selected: &[CandidateId],
406 ) {
407 if let AggregationModel::ImplicitRanking {
408 selected_bonus,
409 not_selected_penalty,
410 ..
411 } = &self.model
412 {
413 let bonus = *selected_bonus;
414 let penalty = *not_selected_penalty;
415
416 for &id in selected {
417 self.ensure_stats(id);
418 if let Some(stats) = self.candidate_stats.get_mut(&id) {
419 stats.times_selected += 1;
420 stats.model_score += bonus;
421 }
422 }
423
424 for &id in not_selected {
425 self.ensure_stats(id);
426 if let Some(stats) = self.candidate_stats.get_mut(&id) {
427 stats.times_passed += 1;
428 stats.model_score -= penalty;
429 }
430 }
431 } else {
432 for &id in selected {
434 self.ensure_stats(id);
435 if let Some(stats) = self.candidate_stats.get_mut(&id) {
436 stats.times_selected += 1;
437 }
438 }
439 for &id in not_selected {
440 self.ensure_stats(id);
441 if let Some(stats) = self.candidate_stats.get_mut(&id) {
442 stats.times_passed += 1;
443 }
444 }
445 }
446 }
447
448 fn update_elo(&mut self, winner: CandidateId, loser: CandidateId, k: f64) {
450 let winner_rating = self
451 .candidate_stats
452 .get(&winner)
453 .map(|s| s.model_score)
454 .unwrap_or(1500.0);
455 let loser_rating = self
456 .candidate_stats
457 .get(&loser)
458 .map(|s| s.model_score)
459 .unwrap_or(1500.0);
460
461 let exp_winner = 1.0 / (1.0 + 10.0_f64.powf((loser_rating - winner_rating) / 400.0));
463 let exp_loser = 1.0 - exp_winner;
464
465 if let Some(stats) = self.candidate_stats.get_mut(&winner) {
467 stats.model_score += k * (1.0 - exp_winner);
468 }
469 if let Some(stats) = self.candidate_stats.get_mut(&loser) {
470 stats.model_score += k * (0.0 - exp_loser);
471 }
472 }
473
474 fn update_elo_draw(&mut self, id_a: CandidateId, id_b: CandidateId, k: f64) {
476 let rating_a = self
477 .candidate_stats
478 .get(&id_a)
479 .map(|s| s.model_score)
480 .unwrap_or(1500.0);
481 let rating_b = self
482 .candidate_stats
483 .get(&id_b)
484 .map(|s| s.model_score)
485 .unwrap_or(1500.0);
486
487 let exp_a = 1.0 / (1.0 + 10.0_f64.powf((rating_b - rating_a) / 400.0));
489 let exp_b = 1.0 - exp_a;
490
491 if let Some(stats) = self.candidate_stats.get_mut(&id_a) {
493 stats.model_score += k * (0.5 - exp_a);
494 }
495 if let Some(stats) = self.candidate_stats.get_mut(&id_b) {
496 stats.model_score += k * (0.5 - exp_b);
497 }
498 }
499
500 pub fn recompute_all(&mut self) -> HashMap<CandidateId, f64> {
505 match &self.model {
506 AggregationModel::BradleyTerry { optimizer, .. } => {
507 self.recompute_bradley_terry_mle(optimizer.clone());
508 }
509 AggregationModel::BradleyTerrySimple {
510 initial_strength,
511 learning_rate,
512 iterations,
513 } => {
514 self.recompute_bradley_terry_simple(*initial_strength, *learning_rate, *iterations);
515 }
516 _ => {}
517 }
518
519 self.candidate_stats
521 .keys()
522 .filter_map(|id| self.get_fitness(id).map(|f| (*id, f)))
523 .collect()
524 }
525
526 fn recompute_bradley_terry_mle(&mut self, optimizer: BradleyTerryOptimizer) {
528 let ids: Vec<CandidateId> = self.candidate_stats.keys().copied().collect();
529 if ids.is_empty() || self.comparisons.is_empty() {
530 return;
531 }
532
533 let model = BradleyTerryModel::new(optimizer);
534 let result = model.fit(&self.comparisons, &ids);
535
536 for (&id, &strength) in &result.strengths {
538 if let Some(stats) = self.candidate_stats.get_mut(&id) {
539 stats.model_score = strength;
540
541 if let Some(&idx) = result.id_to_index.get(&id) {
543 if idx < result.covariance.nrows() {
544 stats.model_variance = result.covariance[(idx, idx)];
545 }
546 }
547 }
548 }
549 }
550
551 fn recompute_bradley_terry_simple(
553 &mut self,
554 initial_strength: f64,
555 learning_rate: f64,
556 iterations: usize,
557 ) {
558 let ids: Vec<CandidateId> = self.candidate_stats.keys().copied().collect();
560 for &id in &ids {
561 if let Some(stats) = self.candidate_stats.get_mut(&id) {
562 stats.model_score = initial_strength;
563 }
564 }
565
566 for _ in 0..iterations {
568 let mut new_scores: HashMap<CandidateId, f64> = HashMap::new();
569
570 for &id in &ids {
571 let stats = match self.candidate_stats.get(&id) {
572 Some(s) => s,
573 None => continue,
574 };
575
576 let wins = stats.wins as f64;
577 if wins == 0.0 {
578 new_scores.insert(id, stats.model_score);
579 continue;
580 }
581
582 let mut denom = 0.0;
584 for comparison in &self.comparisons {
585 if comparison.winner == id {
586 let other_score = self
587 .candidate_stats
588 .get(&comparison.loser)
589 .map(|s| s.model_score)
590 .unwrap_or(initial_strength);
591 denom += 1.0 / (stats.model_score + other_score);
592 } else if comparison.loser == id {
593 let other_score = self
594 .candidate_stats
595 .get(&comparison.winner)
596 .map(|s| s.model_score)
597 .unwrap_or(initial_strength);
598 denom += 1.0 / (stats.model_score + other_score);
599 }
600 }
601
602 let new_score = if denom > 0.0 {
603 let raw = wins / denom;
604 stats.model_score + learning_rate * (raw - stats.model_score)
606 } else {
607 stats.model_score
608 };
609
610 new_scores.insert(id, new_score.max(0.001)); }
612
613 for (id, score) in new_scores {
615 if let Some(stats) = self.candidate_stats.get_mut(&id) {
616 stats.model_score = score;
617 }
618 }
619 }
620 }
621
622 pub fn process_response(&mut self, response: &EvaluationResponse) -> Vec<(CandidateId, f64)> {
624 match response {
625 EvaluationResponse::Ratings(ratings) => {
626 for (id, rating) in ratings {
627 self.record_rating(*id, *rating);
628 }
629 ratings
630 .iter()
631 .filter_map(|(id, _)| self.get_fitness(id).map(|f| (*id, f)))
632 .collect()
633 }
634 EvaluationResponse::PairwiseWinner(Some(winner)) => {
635 self.ensure_stats(*winner);
638 if let Some(f) = self.get_fitness(winner) {
639 vec![(*winner, f)]
640 } else {
641 vec![]
642 }
643 }
644 EvaluationResponse::PairwiseWinner(None) => {
645 vec![]
647 }
648 EvaluationResponse::BatchSelected(selected) => {
649 for id in selected {
651 self.ensure_stats(*id);
652 if let Some(stats) = self.candidate_stats.get_mut(id) {
653 stats.times_selected += 1;
654 if let AggregationModel::ImplicitRanking { selected_bonus, .. } =
655 &self.model
656 {
657 stats.model_score += *selected_bonus;
658 }
659 }
660 }
661 selected
662 .iter()
663 .filter_map(|id| self.get_fitness(id).map(|f| (*id, f)))
664 .collect()
665 }
666 EvaluationResponse::Skip => vec![],
667 }
668 }
669
670 pub fn process_pairwise(
672 &mut self,
673 id_a: CandidateId,
674 id_b: CandidateId,
675 winner: Option<CandidateId>,
676 ) -> Vec<(CandidateId, f64)> {
677 match winner {
678 Some(w) if w == id_a => {
679 self.record_comparison(id_a, id_b);
680 }
681 Some(w) if w == id_b => {
682 self.record_comparison(id_b, id_a);
683 }
684 Some(_) => {
685 }
687 None => {
688 self.record_tie(id_a, id_b);
689 }
690 }
691
692 vec![id_a, id_b]
693 .into_iter()
694 .filter_map(|id| self.get_fitness(&id).map(|f| (id, f)))
695 .collect()
696 }
697
698 pub fn process_batch_selection(
700 &mut self,
701 all_candidates: &[CandidateId],
702 selected: &[CandidateId],
703 ) -> Vec<(CandidateId, f64)> {
704 let selected_set: std::collections::HashSet<_> = selected.iter().copied().collect();
705 let not_selected: Vec<_> = all_candidates
706 .iter()
707 .copied()
708 .filter(|id| !selected_set.contains(id))
709 .collect();
710
711 self.record_batch_selection(selected, ¬_selected);
712
713 all_candidates
714 .iter()
715 .filter_map(|id| self.get_fitness(id).map(|f| (*id, f)))
716 .collect()
717 }
718
719 pub fn all_candidates(&self) -> Vec<CandidateId> {
721 self.candidate_stats.keys().copied().collect()
722 }
723
724 pub fn comparison_count(&self) -> usize {
726 self.comparisons.len()
727 }
728
729 pub fn clear(&mut self) {
731 self.candidate_stats.clear();
732 self.comparisons.clear();
733 }
734}
735
736#[cfg(test)]
737mod tests {
738 use super::*;
739
740 #[test]
741 fn test_direct_rating_aggregation() {
742 let mut agg = FitnessAggregator::new(AggregationModel::DirectRating {
743 default_rating: 5.0,
744 });
745
746 let id = CandidateId(0);
747
748 agg.ensure_stats(id);
750 assert_eq!(agg.get_fitness(&id), Some(5.0));
751
752 agg.record_rating(id, 8.0);
754 assert_eq!(agg.get_fitness(&id), Some(8.0));
755
756 agg.record_rating(id, 6.0);
758 assert_eq!(agg.get_fitness(&id), Some(7.0));
759 }
760
761 #[test]
762 fn test_elo_rating() {
763 let mut agg = FitnessAggregator::new(AggregationModel::Elo {
764 initial_rating: 1500.0,
765 k_factor: 32.0,
766 });
767
768 let id_a = CandidateId(0);
769 let id_b = CandidateId(1);
770
771 agg.ensure_stats(id_a);
772 agg.ensure_stats(id_b);
773
774 assert_eq!(agg.get_fitness(&id_a), Some(1500.0));
776 assert_eq!(agg.get_fitness(&id_b), Some(1500.0));
777
778 agg.record_comparison(id_a, id_b);
780
781 let fitness_a = agg.get_fitness(&id_a).unwrap();
782 let fitness_b = agg.get_fitness(&id_b).unwrap();
783
784 assert!(fitness_a > 1500.0);
786 assert!(fitness_b < 1500.0);
788 assert!((fitness_a + fitness_b - 3000.0).abs() < 0.01);
790 }
791
792 #[test]
793 fn test_elo_draw() {
794 let mut agg = FitnessAggregator::new(AggregationModel::Elo {
795 initial_rating: 1500.0,
796 k_factor: 32.0,
797 });
798
799 let id_a = CandidateId(0);
800 let id_b = CandidateId(1);
801
802 agg.ensure_stats(id_a);
803 agg.ensure_stats(id_b);
804
805 agg.record_tie(id_a, id_b);
807
808 let fitness_a = agg.get_fitness(&id_a).unwrap();
809 let fitness_b = agg.get_fitness(&id_b).unwrap();
810
811 assert!((fitness_a - 1500.0).abs() < 0.01);
812 assert!((fitness_b - 1500.0).abs() < 0.01);
813 }
814
815 #[test]
816 fn test_implicit_ranking() {
817 let mut agg = FitnessAggregator::new(AggregationModel::ImplicitRanking {
818 selected_bonus: 1.0,
819 not_selected_penalty: 0.5,
820 base_fitness: 5.0,
821 });
822
823 let selected = vec![CandidateId(0), CandidateId(1)];
824 let not_selected = vec![CandidateId(2), CandidateId(3)];
825
826 agg.record_batch_selection(&selected, ¬_selected);
827
828 assert_eq!(agg.get_fitness(&CandidateId(0)), Some(6.0));
830 assert_eq!(agg.get_fitness(&CandidateId(1)), Some(6.0));
831
832 assert_eq!(agg.get_fitness(&CandidateId(2)), Some(4.5));
834 assert_eq!(agg.get_fitness(&CandidateId(3)), Some(4.5));
835 }
836
837 #[test]
838 fn test_bradley_terry_simple_recompute() {
839 let mut agg = FitnessAggregator::new(AggregationModel::BradleyTerrySimple {
840 initial_strength: 1.0,
841 learning_rate: 0.5,
842 iterations: 10,
843 });
844
845 agg.ensure_stats(CandidateId(0));
847 agg.ensure_stats(CandidateId(1));
848 agg.ensure_stats(CandidateId(2));
849
850 agg.record_comparison(CandidateId(0), CandidateId(1));
851 agg.record_comparison(CandidateId(0), CandidateId(1));
852 agg.record_comparison(CandidateId(1), CandidateId(2));
853
854 let fitness = agg.recompute_all();
855
856 assert!(fitness[&CandidateId(0)] > fitness[&CandidateId(1)]);
858 assert!(fitness[&CandidateId(1)] > fitness[&CandidateId(2)]);
860 }
861
862 #[test]
863 fn test_bradley_terry_mle_recompute() {
864 use crate::interactive::bradley_terry::BradleyTerryOptimizer;
865
866 let mut agg = FitnessAggregator::new(AggregationModel::BradleyTerry {
867 initial_strength: 1.0,
868 optimizer: BradleyTerryOptimizer::default(),
869 });
870
871 agg.ensure_stats(CandidateId(0));
873 agg.ensure_stats(CandidateId(1));
874 agg.ensure_stats(CandidateId(2));
875
876 agg.record_comparison(CandidateId(0), CandidateId(1));
877 agg.record_comparison(CandidateId(0), CandidateId(1));
878 agg.record_comparison(CandidateId(1), CandidateId(2));
879
880 let fitness = agg.recompute_all();
881
882 assert!(fitness[&CandidateId(0)] > fitness[&CandidateId(1)]);
884 assert!(fitness[&CandidateId(1)] > fitness[&CandidateId(2)]);
886
887 let estimate_a = agg.get_fitness_estimate(&CandidateId(0)).unwrap();
889 assert!(estimate_a.variance.is_finite());
890 assert!(estimate_a.observation_count > 0);
891 }
892
893 #[test]
894 fn test_fitness_estimate_direct_rating() {
895 let mut agg = FitnessAggregator::new(AggregationModel::DirectRating {
896 default_rating: 5.0,
897 });
898
899 let id = CandidateId(0);
900 agg.ensure_stats(id);
901
902 let estimate = agg.get_fitness_estimate(&id).unwrap();
904 assert_eq!(estimate.mean, 5.0);
905 assert!(estimate.variance.is_infinite());
906
907 agg.record_rating(id, 8.0);
909 agg.record_rating(id, 6.0);
910 agg.record_rating(id, 7.0);
911
912 let estimate = agg.get_fitness_estimate(&id).unwrap();
913 assert_eq!(estimate.mean, 7.0);
914 assert!(estimate.variance.is_finite());
915 assert_eq!(estimate.observation_count, 3);
916 }
917
918 #[test]
919 fn test_candidate_stats() {
920 let mut stats = CandidateStats::new(1500.0);
921
922 stats.rating_sum = 24.0;
924 stats.rating_count = 3;
925 assert_eq!(stats.average_rating(), Some(8.0));
926
927 stats.wins = 3;
929 stats.losses = 1;
930 assert_eq!(stats.total_comparisons(), 4);
931 assert_eq!(stats.win_rate(), Some(0.75));
932
933 stats.times_selected = 2;
935 stats.times_passed = 3;
936 assert_eq!(stats.selection_rate(), Some(0.4));
937 }
938
939 #[test]
940 fn test_process_response_ratings() {
941 let mut agg = FitnessAggregator::new(AggregationModel::DirectRating {
942 default_rating: 5.0,
943 });
944
945 let response =
946 EvaluationResponse::ratings(vec![(CandidateId(0), 8.0), (CandidateId(1), 6.0)]);
947
948 let updated = agg.process_response(&response);
949
950 assert_eq!(updated.len(), 2);
951 assert!(updated
952 .iter()
953 .any(|(id, f)| *id == CandidateId(0) && *f == 8.0));
954 assert!(updated
955 .iter()
956 .any(|(id, f)| *id == CandidateId(1) && *f == 6.0));
957 }
958
959 #[test]
960 fn test_process_batch_selection() {
961 let mut agg = FitnessAggregator::new(AggregationModel::ImplicitRanking {
962 selected_bonus: 1.0,
963 not_selected_penalty: 0.5,
964 base_fitness: 5.0,
965 });
966
967 let all = vec![
968 CandidateId(0),
969 CandidateId(1),
970 CandidateId(2),
971 CandidateId(3),
972 ];
973 let selected = vec![CandidateId(0), CandidateId(2)];
974
975 let updated = agg.process_batch_selection(&all, &selected);
976
977 assert_eq!(updated.len(), 4);
978
979 let fitness_0 = updated
981 .iter()
982 .find(|(id, _)| *id == CandidateId(0))
983 .unwrap()
984 .1;
985 assert_eq!(fitness_0, 6.0);
986
987 let fitness_1 = updated
989 .iter()
990 .find(|(id, _)| *id == CandidateId(1))
991 .unwrap()
992 .1;
993 assert_eq!(fitness_1, 4.5);
994 }
995}