1use rand::prelude::*;
29use serde::{Deserialize, Serialize};
30
31use super::aggregation::FitnessAggregator;
32use super::evaluator::Candidate;
33use super::uncertainty::FitnessEstimate;
34use crate::genome::traits::EvolutionaryGenome;
35
/// Policy used to decide which candidates are shown for evaluation next.
///
/// Every variant drives both `select_batch` (pick several candidate indices)
/// and `select_pair` (pick two candidate indices for a comparison).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum SelectionStrategy {
    /// Walk the population in index order.
    ///
    /// Batches take candidates whose `evaluation_count` is zero first and are
    /// then topped up with the remaining candidates; pairs are always `(0, 1)`.
    Sequential,

    /// Prefer candidates whose fitness estimate has the largest variance.
    ///
    /// Candidates without an estimate, or with an infinite variance, are
    /// scored as `f64::MAX`.
    UncertaintySampling {
        /// Multiplier applied to the variance term of the batch score
        /// (`uncertainty_weight * variance + 1 / (evaluation_count + 1)`).
        /// Pair selection ranks by variance alone and ignores this value.
        uncertainty_weight: f64,
    },

    /// Prefer candidates (or pairs) whose comparison outcome is least
    /// predictable, measured by the entropy of the modelled win probability.
    ExpectedInformationGain {
        /// Softmax temperature. A value greater than zero samples
        /// proportionally to `exp(score / temperature)`; any other value
        /// selects the top-scoring entries deterministically.
        temperature: f64,
    },

    /// Bring every candidate up to a minimum number of evaluations before
    /// ranking by uncertainty.
    CoverageAware {
        /// Candidates with fewer evaluations than this are given the highest
        /// possible priority.
        min_evaluations: usize,
        /// Numerator of the bonus `exploration_bonus / (evaluation_count + 1)`
        /// added to the variance of candidates that already meet
        /// `min_evaluations` during batch selection.
        exploration_bonus: f64,
    },
}
79
80impl Default for SelectionStrategy {
81 fn default() -> Self {
82 Self::Sequential
83 }
84}
85
86impl SelectionStrategy {
87 pub fn uncertainty_sampling(uncertainty_weight: f64) -> Self {
89 Self::UncertaintySampling { uncertainty_weight }
90 }
91
92 pub fn information_gain(temperature: f64) -> Self {
94 Self::ExpectedInformationGain { temperature }
95 }
96
97 pub fn coverage_aware(min_evaluations: usize, exploration_bonus: f64) -> Self {
99 Self::CoverageAware {
100 min_evaluations,
101 exploration_bonus,
102 }
103 }
104
105 pub fn select_batch<G, R>(
118 &self,
119 candidates: &[Candidate<G>],
120 aggregator: &FitnessAggregator,
121 batch_size: usize,
122 rng: &mut R,
123 ) -> Vec<usize>
124 where
125 G: EvolutionaryGenome,
126 R: Rng,
127 {
128 if candidates.is_empty() || batch_size == 0 {
129 return vec![];
130 }
131
132 let batch_size = batch_size.min(candidates.len());
133
134 match self {
135 Self::Sequential => self.select_sequential(candidates, batch_size),
136 Self::UncertaintySampling { uncertainty_weight } => {
137 self.select_by_uncertainty(candidates, aggregator, batch_size, *uncertainty_weight)
138 }
139 Self::ExpectedInformationGain { temperature } => self.select_by_information_gain(
140 candidates,
141 aggregator,
142 batch_size,
143 *temperature,
144 rng,
145 ),
146 Self::CoverageAware {
147 min_evaluations,
148 exploration_bonus,
149 } => self.select_coverage_aware(
150 candidates,
151 aggregator,
152 batch_size,
153 *min_evaluations,
154 *exploration_bonus,
155 ),
156 }
157 }
158
159 pub fn select_pair<G, R>(
171 &self,
172 candidates: &[Candidate<G>],
173 aggregator: &FitnessAggregator,
174 rng: &mut R,
175 ) -> Option<(usize, usize)>
176 where
177 G: EvolutionaryGenome,
178 R: Rng,
179 {
180 if candidates.len() < 2 {
181 return None;
182 }
183
184 match self {
185 Self::Sequential => {
186 Some((0, 1))
188 }
189 Self::UncertaintySampling { .. } => {
190 let scores = self.compute_uncertainty_scores(candidates, aggregator);
192 let mut indices: Vec<usize> = (0..candidates.len()).collect();
193 indices.sort_by(|&a, &b| {
194 scores[b]
195 .partial_cmp(&scores[a])
196 .unwrap_or(std::cmp::Ordering::Equal)
197 });
198 Some((indices[0], indices[1]))
199 }
200 Self::ExpectedInformationGain { temperature } => {
201 self.select_pair_by_information_gain(candidates, aggregator, *temperature, rng)
202 }
203 Self::CoverageAware {
204 min_evaluations, ..
205 } => {
206 let mut indices: Vec<(usize, usize)> = candidates
208 .iter()
209 .enumerate()
210 .map(|(i, c)| (i, c.evaluation_count))
211 .collect();
212 indices.sort_by_key(|&(_, count)| count);
213
214 let a = indices[0].0;
215 let b = if indices.len() > 1 {
216 let a_eval = candidates[a].evaluation_count;
218 if a_eval < *min_evaluations {
219 indices[1].0
221 } else {
222 self.find_informative_pair(candidates, aggregator, rng)
224 }
225 } else {
226 return None;
227 };
228 Some((a, b))
229 }
230 }
231 }
232
233 fn select_sequential<G>(&self, candidates: &[Candidate<G>], batch_size: usize) -> Vec<usize>
235 where
236 G: EvolutionaryGenome,
237 {
238 let mut selected: Vec<usize> = candidates
240 .iter()
241 .enumerate()
242 .filter(|(_, c)| c.evaluation_count == 0)
243 .take(batch_size)
244 .map(|(i, _)| i)
245 .collect();
246
247 if selected.len() < batch_size {
249 for i in 0..candidates.len() {
250 if selected.len() >= batch_size {
251 break;
252 }
253 if !selected.contains(&i) {
254 selected.push(i);
255 }
256 }
257 }
258
259 selected
260 }
261
262 fn compute_uncertainty_scores<G>(
264 &self,
265 candidates: &[Candidate<G>],
266 aggregator: &FitnessAggregator,
267 ) -> Vec<f64>
268 where
269 G: EvolutionaryGenome,
270 {
271 candidates
272 .iter()
273 .map(|c| {
274 aggregator
275 .get_fitness_estimate(&c.id)
276 .map(|e| {
277 if e.variance.is_infinite() {
278 f64::MAX } else {
280 e.variance
281 }
282 })
283 .unwrap_or(f64::MAX)
284 })
285 .collect()
286 }
287
288 fn select_by_uncertainty<G>(
290 &self,
291 candidates: &[Candidate<G>],
292 aggregator: &FitnessAggregator,
293 batch_size: usize,
294 uncertainty_weight: f64,
295 ) -> Vec<usize>
296 where
297 G: EvolutionaryGenome,
298 {
299 let mut scores: Vec<(usize, f64)> = candidates
300 .iter()
301 .enumerate()
302 .map(|(i, c)| {
303 let uncertainty = aggregator
304 .get_fitness_estimate(&c.id)
305 .map(|e| {
306 if e.variance.is_infinite() {
307 f64::MAX
308 } else {
309 e.variance
310 }
311 })
312 .unwrap_or(f64::MAX);
313
314 let coverage_bonus = 1.0 / (c.evaluation_count as f64 + 1.0);
316
317 let score = uncertainty_weight * uncertainty + coverage_bonus;
318 (i, score)
319 })
320 .collect();
321
322 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
324
325 scores
326 .into_iter()
327 .take(batch_size)
328 .map(|(i, _)| i)
329 .collect()
330 }
331
332 fn select_by_information_gain<G, R>(
334 &self,
335 candidates: &[Candidate<G>],
336 aggregator: &FitnessAggregator,
337 batch_size: usize,
338 temperature: f64,
339 rng: &mut R,
340 ) -> Vec<usize>
341 where
342 G: EvolutionaryGenome,
343 R: Rng,
344 {
345 let estimates: Vec<Option<FitnessEstimate>> = candidates
348 .iter()
349 .map(|c| aggregator.get_fitness_estimate(&c.id))
350 .collect();
351
352 let mut scores: Vec<(usize, f64)> = candidates
354 .iter()
355 .enumerate()
356 .map(|(i, _)| {
357 let my_est = &estimates[i];
358 let score = estimates
359 .iter()
360 .enumerate()
361 .filter(|(j, _)| *j != i)
362 .map(|(_, other_est)| pairwise_entropy(my_est.as_ref(), other_est.as_ref()))
363 .sum::<f64>();
364 (i, score)
365 })
366 .collect();
367
368 if temperature > 0.0 {
369 let max_score = scores
371 .iter()
372 .map(|(_, s)| *s)
373 .fold(f64::NEG_INFINITY, f64::max);
374 let weights: Vec<f64> = scores
375 .iter()
376 .map(|(_, s)| ((s - max_score) / temperature).exp())
377 .collect();
378 let total: f64 = weights.iter().sum();
379
380 let mut selected = Vec::with_capacity(batch_size);
381 let mut remaining: Vec<(usize, f64)> = scores
382 .iter()
383 .zip(weights.iter())
384 .map(|((i, _), w)| (*i, *w / total))
385 .collect();
386
387 for _ in 0..batch_size {
388 if remaining.is_empty() {
389 break;
390 }
391
392 let r: f64 = rng.gen();
393 let mut cumsum = 0.0;
394 let mut chosen_idx = 0;
395
396 for (idx, (_, w)) in remaining.iter().enumerate() {
397 cumsum += w;
398 if r < cumsum {
399 chosen_idx = idx;
400 break;
401 }
402 }
403
404 let (i, _) = remaining.remove(chosen_idx);
405 selected.push(i);
406
407 let new_total: f64 = remaining.iter().map(|(_, w)| w).sum();
409 if new_total > 0.0 {
410 for (_, w) in &mut remaining {
411 *w /= new_total;
412 }
413 }
414 }
415
416 selected
417 } else {
418 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
420 scores
421 .into_iter()
422 .take(batch_size)
423 .map(|(i, _)| i)
424 .collect()
425 }
426 }
427
428 fn select_pair_by_information_gain<G, R>(
430 &self,
431 candidates: &[Candidate<G>],
432 aggregator: &FitnessAggregator,
433 temperature: f64,
434 rng: &mut R,
435 ) -> Option<(usize, usize)>
436 where
437 G: EvolutionaryGenome,
438 R: Rng,
439 {
440 let n = candidates.len();
441 if n < 2 {
442 return None;
443 }
444
445 let estimates: Vec<Option<FitnessEstimate>> = candidates
446 .iter()
447 .map(|c| aggregator.get_fitness_estimate(&c.id))
448 .collect();
449
450 let mut pair_scores: Vec<((usize, usize), f64)> = Vec::new();
452
453 for i in 0..n {
454 for j in (i + 1)..n {
455 let entropy = pairwise_entropy(estimates[i].as_ref(), estimates[j].as_ref());
456 pair_scores.push(((i, j), entropy));
457 }
458 }
459
460 if pair_scores.is_empty() {
461 return Some((0, 1));
462 }
463
464 if temperature > 0.0 {
465 let max_score = pair_scores
467 .iter()
468 .map(|(_, s)| *s)
469 .fold(f64::NEG_INFINITY, f64::max);
470 let weights: Vec<f64> = pair_scores
471 .iter()
472 .map(|(_, s)| ((s - max_score) / temperature).exp())
473 .collect();
474 let total: f64 = weights.iter().sum();
475
476 let r: f64 = rng.gen();
477 let mut cumsum = 0.0;
478
479 for ((pair, _), w) in pair_scores.iter().zip(weights.iter()) {
480 cumsum += w / total;
481 if r < cumsum {
482 return Some(*pair);
483 }
484 }
485 }
486
487 pair_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
489 Some(pair_scores[0].0)
490 }
491
492 fn select_coverage_aware<G>(
494 &self,
495 candidates: &[Candidate<G>],
496 aggregator: &FitnessAggregator,
497 batch_size: usize,
498 min_evaluations: usize,
499 exploration_bonus: f64,
500 ) -> Vec<usize>
501 where
502 G: EvolutionaryGenome,
503 {
504 let mut scores: Vec<(usize, f64)> = candidates
505 .iter()
506 .enumerate()
507 .map(|(i, c)| {
508 let score = if c.evaluation_count < min_evaluations {
509 f64::MAX
511 } else {
512 let uncertainty = aggregator
514 .get_fitness_estimate(&c.id)
515 .map(|e| {
516 if e.variance.is_infinite() {
517 1e6
518 } else {
519 e.variance
520 }
521 })
522 .unwrap_or(1e6);
523
524 let bonus = exploration_bonus / (c.evaluation_count as f64 + 1.0);
526
527 uncertainty + bonus
528 };
529 (i, score)
530 })
531 .collect();
532
533 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
534
535 scores
536 .into_iter()
537 .take(batch_size)
538 .map(|(i, _)| i)
539 .collect()
540 }
541
542 fn find_informative_pair<G, R>(
544 &self,
545 candidates: &[Candidate<G>],
546 aggregator: &FitnessAggregator,
547 rng: &mut R,
548 ) -> usize
549 where
550 G: EvolutionaryGenome,
551 R: Rng,
552 {
553 let estimates: Vec<Option<FitnessEstimate>> = candidates
555 .iter()
556 .map(|c| aggregator.get_fitness_estimate(&c.id))
557 .collect();
558
559 let mut scores: Vec<(usize, f64)> = candidates
560 .iter()
561 .enumerate()
562 .map(|(i, _)| {
563 let score = estimates
564 .iter()
565 .enumerate()
566 .filter(|(j, _)| *j != i)
567 .map(|(_, other)| pairwise_entropy(estimates[i].as_ref(), other.as_ref()))
568 .sum::<f64>();
569 (i, score)
570 })
571 .collect();
572
573 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
574
575 let top_k = 3.min(scores.len());
577 let chosen = rng.gen_range(0..top_k);
578 scores[chosen].0
579 }
580}
581
582fn pairwise_entropy(a: Option<&FitnessEstimate>, b: Option<&FitnessEstimate>) -> f64 {
586 match (a, b) {
587 (Some(est_a), Some(est_b)) => {
588 let mean_diff = est_a.mean - est_b.mean;
590 let var_diff = est_a.variance + est_b.variance;
591
592 if var_diff.is_infinite() || var_diff <= 0.0 {
593 return 1.0;
595 }
596
597 let z = mean_diff / var_diff.sqrt();
599 let p = normal_cdf(z);
600
601 binary_entropy(p)
603 }
604 _ => 1.0, }
606}
607
/// Shannon entropy, in nats, of a Bernoulli variable with success
/// probability `p`.
///
/// `p` is clamped to `[1e-10, 1 - 1e-10]` first so neither logarithm is
/// evaluated at zero. The maximum, `ln 2`, is reached at `p = 0.5`.
fn binary_entropy(p: f64) -> f64 {
    let success = p.clamp(1e-10, 1.0 - 1e-10);
    let failure = 1.0 - success;
    -(success * success.ln() + failure * failure.ln())
}
613
/// Approximates the standard normal cumulative distribution function.
///
/// Uses `Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))`, with `erf` evaluated by a
/// five-term polynomial in `t = 1 / (1 + p * |x| / sqrt(2))` multiplied by
/// `exp(-x^2 / 2)`; the sign of `x` is re-applied afterwards. The
/// coefficients match the Abramowitz & Stegun 7.1.26 rational approximation.
fn normal_cdf(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    /// Highest-order coefficient (a5); seeds the Horner evaluation.
    const LEADING: f64 = 1.061405429;
    /// Remaining coefficients in descending order: a4, a3, a2, a1.
    const LOWER: [f64; 4] = [-1.453152027, 1.421413741, -0.284496736, 0.254829592];

    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let scaled = x.abs() / std::f64::consts::SQRT_2;

    let t = 1.0 / (1.0 + P * scaled);
    // Horner's scheme: (((a5*t + a4)*t + a3)*t + a2)*t + a1
    let poly = LOWER.iter().fold(LEADING, |acc, &coeff| acc * t + coeff);
    let erf_abs = 1.0 - poly * t * (-scaled * scaled).exp();

    0.5 * (1.0 + sign * erf_abs)
}
632
// Unit tests for the selection strategies and their numeric helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::genome::real_vector::RealVector;
    use crate::interactive::aggregation::{AggregationModel, FitnessAggregator};
    use crate::interactive::evaluator::CandidateId;

    // Builds `n` candidates with ids 0..n, a one-element genome holding the
    // index, and an evaluation count of zero.
    fn make_candidates(n: usize) -> Vec<Candidate<RealVector>> {
        (0..n)
            .map(|i| {
                let mut c = Candidate::new(CandidateId(i), RealVector::new(vec![i as f64]));
                c.evaluation_count = 0;
                c
            })
            .collect()
    }

    #[test]
    fn test_sequential_selection() {
        let candidates = make_candidates(10);
        let aggregator = FitnessAggregator::new(AggregationModel::default());
        let mut rng = rand::thread_rng();

        let strategy = SelectionStrategy::Sequential;
        let selected = strategy.select_batch(&candidates, &aggregator, 3, &mut rng);

        assert_eq!(selected.len(), 3);
        // All candidates are unevaluated, so the first three indices are taken.
        assert!(selected.contains(&0));
        assert!(selected.contains(&1));
        assert!(selected.contains(&2));
    }

    #[test]
    fn test_uncertainty_sampling() {
        let mut candidates = make_candidates(5);
        let mut aggregator = FitnessAggregator::new(AggregationModel::DirectRating {
            default_rating: 5.0,
        });
        let mut rng = rand::thread_rng();

        // Candidate 0: three identical ratings (presumably a low-variance
        // estimate; depends on the aggregator implementation).
        aggregator.record_rating(CandidateId(0), 7.0);
        aggregator.record_rating(CandidateId(0), 7.0);
        aggregator.record_rating(CandidateId(0), 7.0);
        candidates[0].evaluation_count = 3;

        // Candidate 1: two ratings that disagree.
        aggregator.record_rating(CandidateId(1), 4.0);
        aggregator.record_rating(CandidateId(1), 8.0);
        candidates[1].evaluation_count = 2;

        // Candidates 2..4 receive no ratings at all.
        let strategy = SelectionStrategy::UncertaintySampling {
            uncertainty_weight: 1.0,
        };
        let selected = strategy.select_batch(&candidates, &aggregator, 2, &mut rng);

        assert_eq!(selected.len(), 2);
        // The consistently rated candidate must not make it into the batch.
        for &idx in &selected {
            assert!(
                idx != 0,
                "Should not select the well-evaluated candidate with low variance"
            );
        }
    }

    #[test]
    fn test_coverage_aware() {
        let mut candidates = make_candidates(5);
        candidates[0].evaluation_count = 3;
        candidates[1].evaluation_count = 2;
        // Candidates 2, 3 and 4 are below the `min_evaluations` of 2 used below.
        candidates[2].evaluation_count = 0;
        candidates[3].evaluation_count = 0;
        candidates[4].evaluation_count = 1;

        let aggregator = FitnessAggregator::new(AggregationModel::default());
        let mut rng = rand::thread_rng();

        let strategy = SelectionStrategy::CoverageAware {
            min_evaluations: 2,
            exploration_bonus: 1.0,
        };
        let selected = strategy.select_batch(&candidates, &aggregator, 2, &mut rng);

        // At least one never-evaluated candidate must be chosen.
        assert!(selected.contains(&2) || selected.contains(&3));
    }

    #[test]
    fn test_select_pair_sequential() {
        let candidates = make_candidates(5);
        let aggregator = FitnessAggregator::new(AggregationModel::default());
        let mut rng = rand::thread_rng();

        let strategy = SelectionStrategy::Sequential;
        let pair = strategy.select_pair(&candidates, &aggregator, &mut rng);

        // A pair must exist and consist of two different candidates.
        assert!(pair.is_some());
        let (a, b) = pair.unwrap();
        assert_ne!(a, b);
    }

    #[test]
    fn test_select_pair_info_gain() {
        let candidates = make_candidates(5);
        let aggregator = FitnessAggregator::new(AggregationModel::default());
        let mut rng = rand::thread_rng();

        let strategy = SelectionStrategy::ExpectedInformationGain { temperature: 1.0 };
        let pair = strategy.select_pair(&candidates, &aggregator, &mut rng);

        // A pair must exist and consist of two different candidates.
        assert!(pair.is_some());
        let (a, b) = pair.unwrap();
        assert_ne!(a, b);
    }

    #[test]
    fn test_binary_entropy() {
        // Entropy peaks at p = 0.5 with a value of ln 2 (nats).
        let max_entropy = binary_entropy(0.5);
        assert!((max_entropy - std::f64::consts::LN_2).abs() < 1e-6);

        // Near-certain outcomes carry almost no entropy.
        assert!(binary_entropy(0.001) < 0.1);
        assert!(binary_entropy(0.999) < 0.1);
    }

    #[test]
    fn test_normal_cdf() {
        // Median of the standard normal.
        assert!((normal_cdf(0.0) - 0.5).abs() < 1e-6);

        // Far tails.
        assert!(normal_cdf(-10.0) < 0.001);
        assert!(normal_cdf(10.0) > 0.999);

        // Symmetry: Phi(x) + Phi(-x) = 1.
        assert!((normal_cdf(1.0) + normal_cdf(-1.0) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_empty_candidates() {
        let candidates: Vec<Candidate<RealVector>> = vec![];
        let aggregator = FitnessAggregator::new(AggregationModel::default());
        let mut rng = rand::thread_rng();

        // With nothing to choose from, both entry points return "nothing".
        let strategy = SelectionStrategy::default();
        let selected = strategy.select_batch(&candidates, &aggregator, 5, &mut rng);
        assert!(selected.is_empty());

        let pair = strategy.select_pair(&candidates, &aggregator, &mut rng);
        assert!(pair.is_none());
    }

    #[test]
    fn test_single_candidate() {
        let candidates = make_candidates(1);
        let aggregator = FitnessAggregator::new(AggregationModel::default());
        let mut rng = rand::thread_rng();

        // The batch is clamped to the population size.
        let strategy = SelectionStrategy::default();
        let selected = strategy.select_batch(&candidates, &aggregator, 5, &mut rng);
        assert_eq!(selected.len(), 1);

        // A pair needs two candidates.
        let pair = strategy.select_pair(&candidates, &aggregator, &mut rng);
        assert!(pair.is_none());
    }
}