1use rand::seq::SliceRandom;
6use rand::Rng;
7use rand_distr::{Distribution, WeightedIndex};
8
9use crate::genome::traits::EvolutionaryGenome;
10use crate::operators::traits::SelectionOperator;
11
/// Tournament selection: draws several distinct individuals at random and keeps the fittest.
///
/// Larger tournaments make it more likely that a strong individual wins; a size of 1 reduces
/// to uniform random selection.
#[derive(Clone, Debug)]
pub struct TournamentSelection {
    /// Number of contestants drawn per tournament (capped at the population size when selecting).
    pub tournament_size: usize,
}

impl TournamentSelection {
    /// Creates a tournament operator that draws `tournament_size` contestants per round.
    ///
    /// # Panics
    /// Panics if `tournament_size` is 0.
    pub fn new(tournament_size: usize) -> Self {
        if tournament_size == 0 {
            panic!("Tournament size must be at least 1");
        }
        TournamentSelection { tournament_size }
    }

    /// Binary tournament: two contestants per round.
    pub fn binary() -> Self {
        TournamentSelection::new(2)
    }
}
33
34impl<G: EvolutionaryGenome> SelectionOperator<G> for TournamentSelection {
35 fn select<R: Rng>(&self, population: &[(G, f64)], rng: &mut R) -> usize {
36 assert!(!population.is_empty(), "Population cannot be empty");
37
38 let tournament_size = self.tournament_size.min(population.len());
39
40 let indices: Vec<usize> = (0..population.len()).collect();
42 let tournament: Vec<usize> = indices
43 .choose_multiple(rng, tournament_size)
44 .copied()
45 .collect();
46
47 tournament
49 .into_iter()
50 .max_by(|&a, &b| {
51 population[a]
52 .1
53 .partial_cmp(&population[b].1)
54 .unwrap_or(std::cmp::Ordering::Equal)
55 })
56 .unwrap()
57 }
58}
59
/// Roulette-wheel (fitness-proportionate) selection.
///
/// Each individual is drawn with probability proportional to its fitness plus an offset.
#[derive(Clone, Debug)]
pub struct RouletteSelection {
    /// Extra amount added to every fitness before it is used as a wheel weight.
    offset: f64,
}

impl RouletteSelection {
    /// Creates a roulette wheel with no extra offset.
    pub fn new() -> Self {
        Self::with_offset(0.0)
    }

    /// Creates a roulette wheel that adds `offset` to every fitness before weighting.
    pub fn with_offset(offset: f64) -> Self {
        RouletteSelection { offset }
    }
}

impl Default for RouletteSelection {
    /// Equivalent to [`RouletteSelection::new`] (zero offset).
    fn default() -> Self {
        RouletteSelection::new()
    }
}
86
87impl<G: EvolutionaryGenome> SelectionOperator<G> for RouletteSelection {
88 fn select<R: Rng>(&self, population: &[(G, f64)], rng: &mut R) -> usize {
89 assert!(!population.is_empty(), "Population cannot be empty");
90
91 let min_fitness = population
93 .iter()
94 .map(|(_, f)| *f)
95 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
96 .unwrap();
97
98 let offset = if min_fitness < 0.0 {
99 -min_fitness + self.offset + 1.0
100 } else {
101 self.offset
102 };
103
104 let weights: Vec<f64> = population.iter().map(|(_, f)| f + offset).collect();
106
107 let total: f64 = weights.iter().sum();
109 if total <= 0.0 {
110 return rng.gen_range(0..population.len());
111 }
112
113 match WeightedIndex::new(&weights) {
115 Ok(dist) => dist.sample(rng),
116 Err(_) => rng.gen_range(0..population.len()),
117 }
118 }
119}
120
/// Truncation selection: picks uniformly among the best fraction of the population.
#[derive(Clone, Debug)]
pub struct TruncationSelection {
    /// Fraction of the population (best first) eligible for selection; `new` requires `(0, 1]`.
    pub truncation_ratio: f64,
}

impl TruncationSelection {
    /// Creates a truncation operator that keeps the best `truncation_ratio` fraction.
    ///
    /// # Panics
    /// Panics unless `0 < truncation_ratio <= 1` (this also rejects `NaN`).
    pub fn new(truncation_ratio: f64) -> Self {
        let in_range = 0.0 < truncation_ratio && truncation_ratio <= 1.0;
        assert!(in_range, "Truncation ratio must be in (0, 1]");
        TruncationSelection { truncation_ratio }
    }
}
140
141impl<G: EvolutionaryGenome> SelectionOperator<G> for TruncationSelection {
142 fn select<R: Rng>(&self, population: &[(G, f64)], rng: &mut R) -> usize {
143 assert!(!population.is_empty(), "Population cannot be empty");
144
145 let mut indices: Vec<usize> = (0..population.len()).collect();
147 indices.sort_by(|&a, &b| {
148 population[b]
149 .1
150 .partial_cmp(&population[a].1)
151 .unwrap_or(std::cmp::Ordering::Equal)
152 });
153
154 let cutoff = ((population.len() as f64) * self.truncation_ratio).ceil() as usize;
156 let cutoff = cutoff.max(1);
157
158 indices[rng.gen_range(0..cutoff)]
159 }
160}
161
/// Linear ranking selection: probability depends on an individual's fitness rank, not on the
/// fitness magnitude.
#[derive(Clone, Debug)]
pub struct RankSelection {
    /// Weight of the best-ranked individual relative to the average (which is 1).
    /// `new` requires `[1.0, 2.0]`; 1.0 gives uniform selection.
    pub selection_pressure: f64,
}

impl RankSelection {
    /// Creates a linear ranking operator with the given selection pressure.
    ///
    /// # Panics
    /// Panics unless `selection_pressure` lies in `[1.0, 2.0]` (`NaN` is rejected).
    pub fn new(selection_pressure: f64) -> Self {
        let allowed = 1.0..=2.0;
        if !allowed.contains(&selection_pressure) {
            panic!("Selection pressure must be in [1.0, 2.0]");
        }
        RankSelection { selection_pressure }
    }
}

impl Default for RankSelection {
    /// Selection pressure of 1.5.
    fn default() -> Self {
        RankSelection::new(1.5)
    }
}
187
188impl<G: EvolutionaryGenome> SelectionOperator<G> for RankSelection {
189 fn select<R: Rng>(&self, population: &[(G, f64)], rng: &mut R) -> usize {
190 assert!(!population.is_empty(), "Population cannot be empty");
191
192 let n = population.len();
193 let sp = self.selection_pressure;
194
195 let mut indices: Vec<usize> = (0..n).collect();
197 indices.sort_by(|&a, &b| {
198 population[a]
199 .1
200 .partial_cmp(&population[b].1)
201 .unwrap_or(std::cmp::Ordering::Equal)
202 });
203
204 let weights: Vec<f64> = (0..n)
207 .map(|rank| {
208 if n == 1 {
209 1.0
210 } else {
211 2.0 - sp + 2.0 * (sp - 1.0) * (rank as f64) / ((n - 1) as f64)
212 }
213 })
214 .collect();
215
216 match WeightedIndex::new(&weights) {
217 Ok(dist) => indices[dist.sample(rng)],
218 Err(_) => indices[rng.gen_range(0..n)],
219 }
220 }
221}
222
/// Boltzmann (softmax) selection: probability proportional to `exp(fitness / temperature)`.
#[derive(Clone, Debug)]
pub struct BoltzmannSelection {
    /// Softmax temperature; lower values concentrate selection on the fittest individuals.
    pub temperature: f64,
}

impl BoltzmannSelection {
    /// Creates a Boltzmann selection operator at the given temperature.
    ///
    /// # Panics
    /// Panics unless `temperature > 0` (`NaN` is rejected).
    pub fn new(temperature: f64) -> Self {
        let positive = temperature > 0.0;
        assert!(positive, "Temperature must be positive");
        BoltzmannSelection { temperature }
    }
}
239
240impl<G: EvolutionaryGenome> SelectionOperator<G> for BoltzmannSelection {
241 fn select<R: Rng>(&self, population: &[(G, f64)], rng: &mut R) -> usize {
242 assert!(!population.is_empty(), "Population cannot be empty");
243
244 let scaled: Vec<f64> = population
246 .iter()
247 .map(|(_, f)| f / self.temperature)
248 .collect();
249 let max_scaled = scaled.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
250
251 let weights: Vec<f64> = scaled.iter().map(|s| (s - max_scaled).exp()).collect();
252
253 match WeightedIndex::new(&weights) {
254 Ok(dist) => dist.sample(rng),
255 Err(_) => rng.gen_range(0..population.len()),
256 }
257 }
258}
259
/// Uniform random selection: ignores fitness entirely (no selection pressure).
#[derive(Clone, Debug, Default)]
pub struct RandomSelection;

impl RandomSelection {
    /// Creates the stateless uniform selection operator.
    pub fn new() -> Self {
        RandomSelection
    }
}
270
271impl<G: EvolutionaryGenome> SelectionOperator<G> for RandomSelection {
272 fn select<R: Rng>(&self, population: &[(G, f64)], rng: &mut R) -> usize {
273 assert!(!population.is_empty(), "Population cannot be empty");
274 rng.gen_range(0..population.len())
275 }
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use crate::genome::real_vector::RealVector;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    /// Fixed-seed RNG so the statistical assertions are reproducible instead of flaky.
    fn seeded_rng() -> StdRng {
        StdRng::seed_from_u64(0x5EED)
    }

    /// Builds `size` one-gene individuals whose fitness equals their index
    /// (so index `size - 1` is the fittest).
    fn create_population(size: usize) -> Vec<(RealVector, f64)> {
        (0..size)
            .map(|i| (RealVector::new(vec![i as f64]), i as f64))
            .collect()
    }

    #[test]
    fn test_tournament_selection_selects_valid_index() {
        let mut rng = seeded_rng();
        let population = create_population(10);
        let selection = TournamentSelection::new(3);

        for _ in 0..100 {
            let idx = selection.select(&population, &mut rng);
            assert!(idx < population.len());
        }
    }

    #[test]
    fn test_tournament_selection_binary() {
        let selection = TournamentSelection::binary();
        assert_eq!(selection.tournament_size, 2);
    }

    #[test]
    fn test_tournament_selection_prefers_fitter() {
        let mut rng = seeded_rng();
        let population: Vec<(RealVector, f64)> = vec![
            (RealVector::new(vec![0.0]), 0.0),
            (RealVector::new(vec![1.0]), 100.0),
            (RealVector::new(vec![2.0]), 0.0),
        ];

        // The tournament covers the whole population, so the unique best must always win.
        let selection = TournamentSelection::new(3);
        let trials = 100;
        let best_count = (0..trials)
            .filter(|_| selection.select(&population, &mut rng) == 1)
            .count();

        assert_eq!(best_count, trials);
    }

    #[test]
    fn test_roulette_selection_selects_valid_index() {
        let mut rng = seeded_rng();
        let population = create_population(10);
        let selection = RouletteSelection::new();

        for _ in 0..100 {
            let idx = selection.select(&population, &mut rng);
            assert!(idx < population.len());
        }
    }

    #[test]
    fn test_roulette_selection_prefers_fitter() {
        // Fitness [0, 1] with no offset: the weaker individual has zero weight.
        let mut rng = seeded_rng();
        let population = create_population(2);
        let selection = RouletteSelection::new();

        for _ in 0..100 {
            assert_eq!(selection.select(&population, &mut rng), 1);
        }
    }

    #[test]
    fn test_roulette_selection_handles_negative_fitness() {
        let mut rng = seeded_rng();
        let population: Vec<(RealVector, f64)> = vec![
            (RealVector::new(vec![0.0]), -10.0),
            (RealVector::new(vec![1.0]), -5.0),
            (RealVector::new(vec![2.0]), -1.0),
        ];

        let selection = RouletteSelection::new();

        for _ in 0..100 {
            let idx = selection.select(&population, &mut rng);
            assert!(idx < population.len());
        }
    }

    #[test]
    fn test_truncation_selection_selects_from_top() {
        let mut rng = seeded_rng();
        let population = create_population(10);
        // Top 20% of 10 individuals with fitness == index: only indices 8 and 9 qualify.
        let selection = TruncationSelection::new(0.2);

        for _ in 0..100 {
            let idx = selection.select(&population, &mut rng);
            assert!(idx >= 8);
        }
    }

    #[test]
    fn test_rank_selection_selects_valid_index() {
        let mut rng = seeded_rng();
        let population = create_population(10);
        let selection = RankSelection::new(1.5);

        for _ in 0..100 {
            let idx = selection.select(&population, &mut rng);
            assert!(idx < population.len());
        }
    }

    #[test]
    fn test_rank_selection_max_pressure_prefers_fitter() {
        // With selection pressure 2.0 the worst rank gets weight 0.
        let mut rng = seeded_rng();
        let population = create_population(2);
        let selection = RankSelection::new(2.0);

        for _ in 0..100 {
            assert_eq!(selection.select(&population, &mut rng), 1);
        }
    }

    #[test]
    fn test_boltzmann_selection_selects_valid_index() {
        let mut rng = seeded_rng();
        let population = create_population(10);
        let selection = BoltzmannSelection::new(1.0);

        for _ in 0..100 {
            let idx = selection.select(&population, &mut rng);
            assert!(idx < population.len());
        }
    }

    #[test]
    fn test_boltzmann_selection_temperature_effect() {
        let mut rng = seeded_rng();
        let population: Vec<(RealVector, f64)> = vec![
            (RealVector::new(vec![0.0]), 0.0),
            (RealVector::new(vec![1.0]), 10.0),
        ];

        // Low temperature is nearly greedy; high temperature is nearly uniform.
        let low_temp = BoltzmannSelection::new(0.1);
        let high_temp = BoltzmannSelection::new(100.0);

        let mut low_best_count = 0;
        let mut high_best_count = 0;
        let trials = 1000;

        for _ in 0..trials {
            if low_temp.select(&population, &mut rng) == 1 {
                low_best_count += 1;
            }
            if high_temp.select(&population, &mut rng) == 1 {
                high_best_count += 1;
            }
        }

        assert!(low_best_count > high_best_count);
    }

    #[test]
    fn test_random_selection_uniform() {
        let mut rng = seeded_rng();
        let population = create_population(2);
        let selection = RandomSelection::new();

        let mut counts = [0usize; 2];
        let trials = 1000;

        for _ in 0..trials {
            counts[selection.select(&population, &mut rng)] += 1;
        }

        let ratio = counts[0] as f64 / counts[1] as f64;
        assert!(ratio > 0.8 && ratio < 1.2, "counts = {:?}", counts);
    }

    #[test]
    fn test_single_individual_is_always_selected() {
        let mut rng = seeded_rng();
        let population = create_population(1);

        assert_eq!(TournamentSelection::new(3).select(&population, &mut rng), 0);
        assert_eq!(RouletteSelection::new().select(&population, &mut rng), 0);
        assert_eq!(TruncationSelection::new(0.5).select(&population, &mut rng), 0);
        assert_eq!(RankSelection::default().select(&population, &mut rng), 0);
        assert_eq!(BoltzmannSelection::new(1.0).select(&population, &mut rng), 0);
        assert_eq!(RandomSelection::new().select(&population, &mut rng), 0);
    }

    #[test]
    fn test_select_many() {
        let mut rng = seeded_rng();
        let population = create_population(10);
        let selection = TournamentSelection::new(3);

        let indices = selection.select_many(&population, 5, &mut rng);
        assert_eq!(indices.len(), 5);
        for idx in indices {
            assert!(idx < population.len());
        }
    }

    #[test]
    #[should_panic(expected = "Population cannot be empty")]
    fn test_empty_population_panics() {
        let mut rng = seeded_rng();
        let population: Vec<(RealVector, f64)> = Vec::new();
        let _ = TournamentSelection::new(2).select(&population, &mut rng);
    }

    #[test]
    #[should_panic(expected = "Tournament size must be at least 1")]
    fn test_tournament_size_zero() {
        TournamentSelection::new(0);
    }

    #[test]
    #[should_panic(expected = "Truncation ratio must be in (0, 1]")]
    fn test_truncation_ratio_zero() {
        TruncationSelection::new(0.0);
    }

    #[test]
    #[should_panic(expected = "Selection pressure must be in [1.0, 2.0]")]
    fn test_rank_selection_pressure_out_of_range() {
        RankSelection::new(2.5);
    }

    #[test]
    #[should_panic(expected = "Temperature must be positive")]
    fn test_boltzmann_temperature_zero() {
        BoltzmannSelection::new(0.0);
    }
}