1use std::collections::{HashMap, HashSet};
6
7use rand::Rng;
8
9use crate::error::{OperatorError, OperatorResult};
10use crate::genome::bit_string::BitString;
11use crate::genome::bounds::MultiBounds;
12use crate::genome::permutation::Permutation;
13use crate::genome::real_vector::RealVector;
14use crate::genome::traits::{
15 BinaryGenome, EvolutionaryGenome, PermutationGenome, RealValuedGenome,
16};
17use crate::operators::traits::{BoundedCrossoverOperator, CrossoverOperator};
18
/// Configuration for simulated binary crossover (SBX) on real-valued genomes.
#[derive(Debug, Clone)]
pub struct SbxCrossover {
    /// Distribution index; the spread factor is raised to `1 / (eta + 1)`.
    pub eta: f64,
    /// Probability, evaluated independently per gene, of recombining that gene.
    pub crossover_probability: f64,
}
34
35impl SbxCrossover {
36 pub fn new(eta: f64) -> Self {
38 assert!(eta >= 0.0, "Distribution index must be non-negative");
39 Self {
40 eta,
41 crossover_probability: 0.9,
42 }
43 }
44
45 pub fn with_probability(mut self, probability: f64) -> Self {
47 assert!(
48 (0.0..=1.0).contains(&probability),
49 "Probability must be in [0, 1]"
50 );
51 self.crossover_probability = probability;
52 self
53 }
54
55 fn spread_factor(&self, u: f64) -> f64 {
57 if u <= 0.5 {
58 (2.0 * u).powf(1.0 / (self.eta + 1.0))
59 } else {
60 (1.0 / (2.0 * (1.0 - u))).powf(1.0 / (self.eta + 1.0))
61 }
62 }
63
64 fn apply_sbx<R: Rng>(
66 &self,
67 parent1: &[f64],
68 parent2: &[f64],
69 bounds: Option<&MultiBounds>,
70 rng: &mut R,
71 ) -> (Vec<f64>, Vec<f64>) {
72 let mut child1: Vec<f64> = parent1.to_vec();
73 let mut child2: Vec<f64> = parent2.to_vec();
74
75 for i in 0..parent1.len() {
76 if rng.gen::<f64>() < self.crossover_probability {
77 let x1 = parent1[i];
78 let x2 = parent2[i];
79
80 if (x1 - x2).abs() > 1e-14 {
82 let u = rng.gen::<f64>();
83 let beta = self.spread_factor(u);
84
85 child1[i] = 0.5 * ((1.0 + beta) * x1 + (1.0 - beta) * x2);
86 child2[i] = 0.5 * ((1.0 - beta) * x1 + (1.0 + beta) * x2);
87
88 if let Some(b) = bounds {
90 if let Some(bound) = b.get(i) {
91 child1[i] = bound.clamp(child1[i]);
92 child2[i] = bound.clamp(child2[i]);
93 }
94 }
95 }
96 }
97 }
98
99 (child1, child2)
100 }
101}
102
103impl CrossoverOperator<RealVector> for SbxCrossover {
104 fn crossover<R: Rng>(
105 &self,
106 parent1: &RealVector,
107 parent2: &RealVector,
108 rng: &mut R,
109 ) -> OperatorResult<(RealVector, RealVector)> {
110 if parent1.dimension() != parent2.dimension() {
111 return OperatorResult::Failed(OperatorError::CrossoverFailed(
112 "Parent dimensions do not match".to_string(),
113 ));
114 }
115
116 let (child1_genes, child2_genes) =
117 self.apply_sbx(parent1.genes(), parent2.genes(), None, rng);
118
119 let child1 = RealVector::from_genes(child1_genes).unwrap();
120 let child2 = RealVector::from_genes(child2_genes).unwrap();
121
122 OperatorResult::Success((child1, child2))
123 }
124
125 fn crossover_probability(&self) -> f64 {
126 self.crossover_probability
127 }
128}
129
130impl BoundedCrossoverOperator<RealVector> for SbxCrossover {
131 fn crossover_bounded<R: Rng>(
132 &self,
133 parent1: &RealVector,
134 parent2: &RealVector,
135 bounds: &MultiBounds,
136 rng: &mut R,
137 ) -> OperatorResult<(RealVector, RealVector)> {
138 if parent1.dimension() != parent2.dimension() {
139 return OperatorResult::Failed(OperatorError::CrossoverFailed(
140 "Parent dimensions do not match".to_string(),
141 ));
142 }
143
144 let (child1_genes, child2_genes) =
145 self.apply_sbx(parent1.genes(), parent2.genes(), Some(bounds), rng);
146
147 let child1 = RealVector::from_genes(child1_genes).unwrap();
148 let child2 = RealVector::from_genes(child2_genes).unwrap();
149
150 OperatorResult::Success((child1, child2))
151 }
152}
153
/// Blend crossover (BLX-alpha) configuration for real-valued genomes.
#[derive(Debug, Clone)]
pub struct BlxAlphaCrossover {
    /// Fraction of the parents' per-gene distance by which the sampling
    /// interval is extended on each side.
    pub alpha: f64,
}

impl BlxAlphaCrossover {
    /// Creates an operator with the given extension factor.
    ///
    /// # Panics
    ///
    /// Panics when `alpha >= 0.0` does not hold.
    pub fn new(alpha: f64) -> Self {
        assert!(alpha >= 0.0, "Alpha must be non-negative");
        BlxAlphaCrossover { alpha }
    }

    /// Creates an operator with `alpha = 0.5`.
    pub fn default_alpha() -> Self {
        let alpha = 0.5;
        Self::new(alpha)
    }
}
175
176impl CrossoverOperator<RealVector> for BlxAlphaCrossover {
177 fn crossover<R: Rng>(
178 &self,
179 parent1: &RealVector,
180 parent2: &RealVector,
181 rng: &mut R,
182 ) -> OperatorResult<(RealVector, RealVector)> {
183 if parent1.dimension() != parent2.dimension() {
184 return OperatorResult::Failed(OperatorError::CrossoverFailed(
185 "Parent dimensions do not match".to_string(),
186 ));
187 }
188
189 let mut child1_genes = Vec::with_capacity(parent1.dimension());
190 let mut child2_genes = Vec::with_capacity(parent2.dimension());
191
192 for i in 0..parent1.dimension() {
193 let x1 = parent1[i];
194 let x2 = parent2[i];
195
196 let min_val = x1.min(x2);
197 let max_val = x1.max(x2);
198 let range = max_val - min_val;
199
200 let low = min_val - self.alpha * range;
201 let high = max_val + self.alpha * range;
202
203 child1_genes.push(rng.gen_range(low..=high));
204 child2_genes.push(rng.gen_range(low..=high));
205 }
206
207 let child1 = RealVector::from_genes(child1_genes).unwrap();
208 let child2 = RealVector::from_genes(child2_genes).unwrap();
209
210 OperatorResult::Success((child1, child2))
211 }
212}
213
/// Uniform crossover configuration for bit strings.
#[derive(Debug, Clone)]
pub struct UniformCrossover {
    /// Per-position probability that the first child keeps the first
    /// parent's bit (and the second child keeps the second parent's bit).
    pub bias: f64,
}

impl UniformCrossover {
    /// Creates an unbiased operator (`bias = 0.5`).
    pub fn new() -> Self {
        Self::with_bias(0.5)
    }

    /// Creates an operator with a custom bias.
    ///
    /// # Panics
    ///
    /// Panics when `bias` lies outside `[0, 1]`.
    pub fn with_bias(bias: f64) -> Self {
        assert!((0.0..=1.0).contains(&bias), "Bias must be in [0, 1]");
        UniformCrossover { bias }
    }
}

impl Default for UniformCrossover {
    /// Same as [`UniformCrossover::new`].
    fn default() -> Self {
        UniformCrossover::new()
    }
}
241
242impl CrossoverOperator<BitString> for UniformCrossover {
243 fn crossover<R: Rng>(
244 &self,
245 parent1: &BitString,
246 parent2: &BitString,
247 rng: &mut R,
248 ) -> OperatorResult<(BitString, BitString)> {
249 if parent1.dimension() != parent2.dimension() {
250 return OperatorResult::Failed(OperatorError::CrossoverFailed(
251 "Parent dimensions do not match".to_string(),
252 ));
253 }
254
255 let mut child1_bits = Vec::with_capacity(parent1.dimension());
256 let mut child2_bits = Vec::with_capacity(parent2.dimension());
257
258 for i in 0..parent1.dimension() {
259 if rng.gen::<f64>() < self.bias {
260 child1_bits.push(parent1[i]);
261 child2_bits.push(parent2[i]);
262 } else {
263 child1_bits.push(parent2[i]);
264 child2_bits.push(parent1[i]);
265 }
266 }
267
268 let child1 = BitString::from_bits(child1_bits).unwrap();
269 let child2 = BitString::from_bits(child2_bits).unwrap();
270
271 OperatorResult::Success((child1, child2))
272 }
273}
274
/// One-point crossover for bit strings; carries no configuration.
#[derive(Debug, Clone, Default)]
pub struct OnePointCrossover;

impl OnePointCrossover {
    /// Creates the operator.
    pub fn new() -> Self {
        OnePointCrossover
    }
}
285
286impl CrossoverOperator<BitString> for OnePointCrossover {
287 fn crossover<R: Rng>(
288 &self,
289 parent1: &BitString,
290 parent2: &BitString,
291 rng: &mut R,
292 ) -> OperatorResult<(BitString, BitString)> {
293 if parent1.dimension() != parent2.dimension() {
294 return OperatorResult::Failed(OperatorError::CrossoverFailed(
295 "Parent dimensions do not match".to_string(),
296 ));
297 }
298
299 let n = parent1.dimension();
300 if n == 0 {
301 return OperatorResult::Success((parent1.clone(), parent2.clone()));
302 }
303
304 let crossover_point = rng.gen_range(0..n);
305
306 let mut child1_bits = Vec::with_capacity(n);
307 let mut child2_bits = Vec::with_capacity(n);
308
309 for i in 0..n {
310 if i < crossover_point {
311 child1_bits.push(parent1[i]);
312 child2_bits.push(parent2[i]);
313 } else {
314 child1_bits.push(parent2[i]);
315 child2_bits.push(parent1[i]);
316 }
317 }
318
319 let child1 = BitString::from_bits(child1_bits).unwrap();
320 let child2 = BitString::from_bits(child2_bits).unwrap();
321
322 OperatorResult::Success((child1, child2))
323 }
324}
325
/// Two-point crossover for bit strings; carries no configuration.
#[derive(Debug, Clone, Default)]
pub struct TwoPointCrossover;

impl TwoPointCrossover {
    /// Creates the operator.
    pub fn new() -> Self {
        TwoPointCrossover
    }
}
336
337impl CrossoverOperator<BitString> for TwoPointCrossover {
338 fn crossover<R: Rng>(
339 &self,
340 parent1: &BitString,
341 parent2: &BitString,
342 rng: &mut R,
343 ) -> OperatorResult<(BitString, BitString)> {
344 if parent1.dimension() != parent2.dimension() {
345 return OperatorResult::Failed(OperatorError::CrossoverFailed(
346 "Parent dimensions do not match".to_string(),
347 ));
348 }
349
350 let n = parent1.dimension();
351 if n < 2 {
352 return OperatorResult::Success((parent1.clone(), parent2.clone()));
353 }
354
355 let mut point1 = rng.gen_range(0..n);
356 let mut point2 = rng.gen_range(0..n);
357 if point1 > point2 {
358 std::mem::swap(&mut point1, &mut point2);
359 }
360
361 let mut child1_bits = Vec::with_capacity(n);
362 let mut child2_bits = Vec::with_capacity(n);
363
364 for i in 0..n {
365 if i < point1 || i >= point2 {
366 child1_bits.push(parent1[i]);
367 child2_bits.push(parent2[i]);
368 } else {
369 child1_bits.push(parent2[i]);
370 child2_bits.push(parent1[i]);
371 }
372 }
373
374 let child1 = BitString::from_bits(child1_bits).unwrap();
375 let child2 = BitString::from_bits(child2_bits).unwrap();
376
377 OperatorResult::Success((child1, child2))
378 }
379}
380
/// Arithmetic (weighted-average) crossover configuration for real vectors.
#[derive(Debug, Clone)]
pub struct ArithmeticCrossover {
    /// Weight given to the first parent in the first child (and to the
    /// second parent in the second child).
    pub weight: f64,
}

impl ArithmeticCrossover {
    /// Creates an operator with the given blend weight.
    ///
    /// # Panics
    ///
    /// Panics when `weight` lies outside `[0, 1]`.
    pub fn new(weight: f64) -> Self {
        assert!((0.0..=1.0).contains(&weight), "Weight must be in [0, 1]");
        ArithmeticCrossover { weight }
    }

    /// Creates an operator that averages the parents (`weight = 0.5`).
    pub fn uniform() -> Self {
        let weight = 0.5;
        Self::new(weight)
    }
}
402
403impl CrossoverOperator<RealVector> for ArithmeticCrossover {
404 fn crossover<R: Rng>(
405 &self,
406 parent1: &RealVector,
407 parent2: &RealVector,
408 _rng: &mut R,
409 ) -> OperatorResult<(RealVector, RealVector)> {
410 if parent1.dimension() != parent2.dimension() {
411 return OperatorResult::Failed(OperatorError::CrossoverFailed(
412 "Parent dimensions do not match".to_string(),
413 ));
414 }
415
416 let w = self.weight;
417 let mut child1_genes = Vec::with_capacity(parent1.dimension());
418 let mut child2_genes = Vec::with_capacity(parent2.dimension());
419
420 for i in 0..parent1.dimension() {
421 child1_genes.push(w * parent1[i] + (1.0 - w) * parent2[i]);
422 child2_genes.push((1.0 - w) * parent1[i] + w * parent2[i]);
423 }
424
425 let child1 = RealVector::from_genes(child1_genes).unwrap();
426 let child2 = RealVector::from_genes(child2_genes).unwrap();
427
428 OperatorResult::Success((child1, child2))
429 }
430}
431
/// Partially mapped crossover (PMX) for permutations; carries no configuration.
#[derive(Debug, Clone, Default)]
pub struct PmxCrossover;

impl PmxCrossover {
    /// Creates the operator.
    pub fn new() -> Self {
        PmxCrossover
    }
}
453
454impl CrossoverOperator<Permutation> for PmxCrossover {
455 fn crossover<R: Rng>(
456 &self,
457 parent1: &Permutation,
458 parent2: &Permutation,
459 rng: &mut R,
460 ) -> OperatorResult<(Permutation, Permutation)> {
461 let n = parent1.dimension();
462
463 if n != parent2.dimension() {
464 return OperatorResult::Failed(OperatorError::CrossoverFailed(
465 "Parent dimensions do not match".to_string(),
466 ));
467 }
468
469 if n < 2 {
470 return OperatorResult::Success((parent1.clone(), parent2.clone()));
471 }
472
473 let mut start = rng.gen_range(0..n);
475 let mut end = rng.gen_range(0..n);
476 if start > end {
477 std::mem::swap(&mut start, &mut end);
478 }
479
480 let p1 = parent1.permutation();
481 let p2 = parent2.permutation();
482
483 let mut child1 = vec![usize::MAX; n];
485 let mut child2 = vec![usize::MAX; n];
486
487 for i in start..=end {
489 child1[i] = p2[i];
490 child2[i] = p1[i];
491 }
492
493 let mut map1: HashMap<usize, usize> = HashMap::new();
495 let mut map2: HashMap<usize, usize> = HashMap::new();
496 for i in start..=end {
497 map1.insert(p2[i], p1[i]);
498 map2.insert(p1[i], p2[i]);
499 }
500
501 for i in (0..start).chain((end + 1)..n) {
503 let mut val1 = p1[i];
505 while child1[start..=end].contains(&val1) {
506 val1 = *map1.get(&val1).unwrap_or(&val1);
507 }
508 child1[i] = val1;
509
510 let mut val2 = p2[i];
512 while child2[start..=end].contains(&val2) {
513 val2 = *map2.get(&val2).unwrap_or(&val2);
514 }
515 child2[i] = val2;
516 }
517
518 let c1 = match Permutation::try_new(child1) {
519 Ok(p) => p,
520 Err(e) => {
521 return OperatorResult::Failed(OperatorError::CrossoverFailed(format!(
522 "PMX produced invalid child1: {}",
523 e
524 )))
525 }
526 };
527 let c2 = match Permutation::try_new(child2) {
528 Ok(p) => p,
529 Err(e) => {
530 return OperatorResult::Failed(OperatorError::CrossoverFailed(format!(
531 "PMX produced invalid child2: {}",
532 e
533 )))
534 }
535 };
536
537 OperatorResult::Success((c1, c2))
538 }
539}
540
/// Order crossover (OX) for permutations; carries no configuration.
#[derive(Debug, Clone, Default)]
pub struct OxCrossover;

impl OxCrossover {
    /// Creates the operator.
    pub fn new() -> Self {
        OxCrossover
    }
}
557
558impl CrossoverOperator<Permutation> for OxCrossover {
559 fn crossover<R: Rng>(
560 &self,
561 parent1: &Permutation,
562 parent2: &Permutation,
563 rng: &mut R,
564 ) -> OperatorResult<(Permutation, Permutation)> {
565 let n = parent1.dimension();
566
567 if n != parent2.dimension() {
568 return OperatorResult::Failed(OperatorError::CrossoverFailed(
569 "Parent dimensions do not match".to_string(),
570 ));
571 }
572
573 if n < 2 {
574 return OperatorResult::Success((parent1.clone(), parent2.clone()));
575 }
576
577 let mut start = rng.gen_range(0..n);
579 let mut end = rng.gen_range(0..n);
580 if start > end {
581 std::mem::swap(&mut start, &mut end);
582 }
583
584 let child1 = Self::ox_single(parent1, parent2, start, end);
585 let child2 = Self::ox_single(parent2, parent1, start, end);
586
587 let c1 = match Permutation::try_new(child1) {
588 Ok(p) => p,
589 Err(e) => {
590 return OperatorResult::Failed(OperatorError::CrossoverFailed(format!(
591 "OX produced invalid child1: {}",
592 e
593 )))
594 }
595 };
596 let c2 = match Permutation::try_new(child2) {
597 Ok(p) => p,
598 Err(e) => {
599 return OperatorResult::Failed(OperatorError::CrossoverFailed(format!(
600 "OX produced invalid child2: {}",
601 e
602 )))
603 }
604 };
605
606 OperatorResult::Success((c1, c2))
607 }
608}
609
610impl OxCrossover {
611 fn ox_single(
613 parent1: &Permutation,
614 parent2: &Permutation,
615 start: usize,
616 end: usize,
617 ) -> Vec<usize> {
618 let n = parent1.dimension();
619 let p1 = parent1.permutation();
620 let p2 = parent2.permutation();
621
622 let mut child = vec![usize::MAX; n];
623
624 let segment: HashSet<usize> = p1[start..=end].iter().copied().collect();
626 for i in start..=end {
627 child[i] = p1[i];
628 }
629
630 let mut pos = (end + 1) % n;
632 let mut p2_idx = (end + 1) % n;
633
634 while pos != start {
635 while segment.contains(&p2[p2_idx]) {
637 p2_idx = (p2_idx + 1) % n;
638 }
639
640 child[pos] = p2[p2_idx];
641 pos = (pos + 1) % n;
642 p2_idx = (p2_idx + 1) % n;
643 }
644
645 child
646 }
647}
648
/// Cycle crossover (CX) for permutations; carries no configuration.
#[derive(Debug, Clone, Default)]
pub struct CxCrossover;

impl CxCrossover {
    /// Creates the operator.
    pub fn new() -> Self {
        CxCrossover
    }
}
666
667impl CrossoverOperator<Permutation> for CxCrossover {
668 fn crossover<R: Rng>(
669 &self,
670 parent1: &Permutation,
671 parent2: &Permutation,
672 _rng: &mut R,
673 ) -> OperatorResult<(Permutation, Permutation)> {
674 let n = parent1.dimension();
675
676 if n != parent2.dimension() {
677 return OperatorResult::Failed(OperatorError::CrossoverFailed(
678 "Parent dimensions do not match".to_string(),
679 ));
680 }
681
682 if n == 0 {
683 return OperatorResult::Success((parent1.clone(), parent2.clone()));
684 }
685
686 let p1 = parent1.permutation();
687 let p2 = parent2.permutation();
688
689 let mut pos_in_p1: HashMap<usize, usize> = HashMap::new();
691 for (i, &val) in p1.iter().enumerate() {
692 pos_in_p1.insert(val, i);
693 }
694
695 let mut child1 = vec![usize::MAX; n];
697 let mut child2 = vec![usize::MAX; n];
698 let mut visited = vec![false; n];
699 let mut use_p1 = true; for start in 0..n {
702 if visited[start] {
703 continue;
704 }
705
706 let mut cycle_positions = Vec::new();
708 let mut pos = start;
709
710 loop {
711 cycle_positions.push(pos);
712 visited[pos] = true;
713
714 let val_in_p2 = p2[pos];
716 pos = *pos_in_p1.get(&val_in_p2).unwrap();
717
718 if pos == start {
719 break;
720 }
721 }
722
723 for &cycle_pos in &cycle_positions {
725 if use_p1 {
726 child1[cycle_pos] = p1[cycle_pos];
727 child2[cycle_pos] = p2[cycle_pos];
728 } else {
729 child1[cycle_pos] = p2[cycle_pos];
730 child2[cycle_pos] = p1[cycle_pos];
731 }
732 }
733
734 use_p1 = !use_p1; }
736
737 let c1 = match Permutation::try_new(child1) {
738 Ok(p) => p,
739 Err(e) => {
740 return OperatorResult::Failed(OperatorError::CrossoverFailed(format!(
741 "CX produced invalid child1: {}",
742 e
743 )))
744 }
745 };
746 let c2 = match Permutation::try_new(child2) {
747 Ok(p) => p,
748 Err(e) => {
749 return OperatorResult::Failed(OperatorError::CrossoverFailed(format!(
750 "CX produced invalid child2: {}",
751 e
752 )))
753 }
754 };
755
756 OperatorResult::Success((c1, c2))
757 }
758}
759
760#[derive(Clone, Debug, Default)]
765pub struct EdgeRecombinationCrossover;
766
767impl EdgeRecombinationCrossover {
768 pub fn new() -> Self {
770 Self
771 }
772
773 fn build_edge_table(
775 parent1: &Permutation,
776 parent2: &Permutation,
777 ) -> HashMap<usize, HashSet<usize>> {
778 let n = parent1.dimension();
779 let p1 = parent1.permutation();
780 let p2 = parent2.permutation();
781
782 let mut edges: HashMap<usize, HashSet<usize>> = HashMap::new();
783
784 for i in 0..n {
786 edges.insert(i, HashSet::new());
787 }
788
789 for i in 0..n {
791 let curr = p1[i];
792 let prev = p1[(i + n - 1) % n];
793 let next = p1[(i + 1) % n];
794 edges.get_mut(&curr).unwrap().insert(prev);
795 edges.get_mut(&curr).unwrap().insert(next);
796 }
797
798 for i in 0..n {
800 let curr = p2[i];
801 let prev = p2[(i + n - 1) % n];
802 let next = p2[(i + 1) % n];
803 edges.get_mut(&curr).unwrap().insert(prev);
804 edges.get_mut(&curr).unwrap().insert(next);
805 }
806
807 edges
808 }
809}
810
811impl CrossoverOperator<Permutation> for EdgeRecombinationCrossover {
812 fn crossover<R: Rng>(
813 &self,
814 parent1: &Permutation,
815 parent2: &Permutation,
816 rng: &mut R,
817 ) -> OperatorResult<(Permutation, Permutation)> {
818 let n = parent1.dimension();
819
820 if n != parent2.dimension() {
821 return OperatorResult::Failed(OperatorError::CrossoverFailed(
822 "Parent dimensions do not match".to_string(),
823 ));
824 }
825
826 if n < 2 {
827 return OperatorResult::Success((parent1.clone(), parent2.clone()));
828 }
829
830 let mut edges = Self::build_edge_table(parent1, parent2);
832
833 let mut child = Vec::with_capacity(n);
835 let mut remaining: HashSet<usize> = (0..n).collect();
836
837 let mut current = parent1.permutation()[0];
839 child.push(current);
840 remaining.remove(¤t);
841
842 for edge_set in edges.values_mut() {
844 edge_set.remove(¤t);
845 }
846
847 while child.len() < n {
848 let neighbors = edges.get(¤t).cloned().unwrap_or_default();
850
851 let next = if !neighbors.is_empty() {
853 let filtered: Vec<usize> = neighbors
854 .iter()
855 .filter(|x| remaining.contains(x))
856 .copied()
857 .collect();
858 if filtered.is_empty() {
859 let remaining_vec: Vec<usize> = remaining.iter().copied().collect();
861 remaining_vec[rng.gen_range(0..remaining_vec.len())]
862 } else {
863 *filtered
865 .iter()
866 .min_by_key(|&&x| edges.get(&x).map(|s| s.len()).unwrap_or(0))
867 .unwrap()
868 }
869 } else {
870 let remaining_vec: Vec<usize> = remaining.iter().copied().collect();
872 remaining_vec[rng.gen_range(0..remaining_vec.len())]
873 };
874
875 child.push(next);
876 remaining.remove(&next);
877 current = next;
878
879 for edge_set in edges.values_mut() {
881 edge_set.remove(¤t);
882 }
883 }
884
885 let mut edges2 = Self::build_edge_table(parent1, parent2);
887 let mut child2 = Vec::with_capacity(n);
888 let mut remaining2: HashSet<usize> = (0..n).collect();
889
890 let mut current2 = parent2.permutation()[0];
892 child2.push(current2);
893 remaining2.remove(¤t2);
894
895 for edge_set in edges2.values_mut() {
896 edge_set.remove(¤t2);
897 }
898
899 while child2.len() < n {
900 let neighbors = edges2.get(¤t2).cloned().unwrap_or_default();
901
902 let next2 = if !neighbors.is_empty() {
903 let filtered: Vec<usize> = neighbors
904 .iter()
905 .filter(|x| remaining2.contains(x))
906 .copied()
907 .collect();
908 if filtered.is_empty() {
909 let remaining_vec: Vec<usize> = remaining2.iter().copied().collect();
910 remaining_vec[rng.gen_range(0..remaining_vec.len())]
911 } else {
912 *filtered
913 .iter()
914 .min_by_key(|&&x| edges2.get(&x).map(|s| s.len()).unwrap_or(0))
915 .unwrap()
916 }
917 } else {
918 let remaining_vec: Vec<usize> = remaining2.iter().copied().collect();
919 remaining_vec[rng.gen_range(0..remaining_vec.len())]
920 };
921
922 child2.push(next2);
923 remaining2.remove(&next2);
924 current2 = next2;
925
926 for edge_set in edges2.values_mut() {
927 edge_set.remove(¤t2);
928 }
929 }
930
931 let c1 = match Permutation::try_new(child) {
932 Ok(p) => p,
933 Err(e) => {
934 return OperatorResult::Failed(OperatorError::CrossoverFailed(format!(
935 "ERX produced invalid child1: {}",
936 e
937 )))
938 }
939 };
940 let c2 = match Permutation::try_new(child2) {
941 Ok(p) => p,
942 Err(e) => {
943 return OperatorResult::Failed(OperatorError::CrossoverFailed(format!(
944 "ERX produced invalid child2: {}",
945 e
946 )))
947 }
948 };
949
950 OperatorResult::Success((c1, c2))
951 }
952}
953
954use crate::genome::tree::{Function, Terminal, TreeGenome, TreeNode};
959
/// Subtree-exchange crossover configuration for tree genomes.
#[derive(Debug, Clone)]
pub struct SubtreeCrossover {
    /// Maximum depth allowed for an offspring tree; `None` disables the check.
    pub max_depth: Option<usize>,
    /// Probability of preferring a function node over a terminal node when
    /// choosing a crossover point.
    pub function_probability: f64,
}

impl Default for SubtreeCrossover {
    /// Depth limit of 17 and a 0.9 preference for function nodes.
    fn default() -> Self {
        SubtreeCrossover {
            function_probability: 0.9,
            max_depth: Some(17),
        }
    }
}
980
981impl SubtreeCrossover {
982 pub fn new() -> Self {
984 Self::default()
985 }
986
987 pub fn with_max_depth(mut self, max_depth: usize) -> Self {
989 self.max_depth = Some(max_depth);
990 self
991 }
992
993 pub fn without_depth_limit(mut self) -> Self {
995 self.max_depth = None;
996 self
997 }
998
999 pub fn with_function_probability(mut self, prob: f64) -> Self {
1001 self.function_probability = prob.clamp(0.0, 1.0);
1002 self
1003 }
1004
1005 fn select_crossover_point<T: Terminal, F: Function, R: Rng>(
1007 &self,
1008 tree: &TreeNode<T, F>,
1009 rng: &mut R,
1010 ) -> Vec<usize> {
1011 let select_function = rng.gen::<f64>() < self.function_probability;
1013
1014 let positions = if select_function {
1015 let func_pos = tree.function_positions();
1016 if func_pos.is_empty() {
1017 tree.positions() } else {
1019 func_pos
1020 }
1021 } else {
1022 let term_pos = tree.terminal_positions();
1023 if term_pos.is_empty() {
1024 tree.positions()
1025 } else {
1026 term_pos
1027 }
1028 };
1029
1030 if positions.is_empty() {
1031 vec![] } else {
1033 positions[rng.gen_range(0..positions.len())].clone()
1034 }
1035 }
1036}
1037
1038impl<T: Terminal, F: Function> CrossoverOperator<TreeGenome<T, F>> for SubtreeCrossover {
1039 fn crossover<R: Rng>(
1040 &self,
1041 parent1: &TreeGenome<T, F>,
1042 parent2: &TreeGenome<T, F>,
1043 rng: &mut R,
1044 ) -> OperatorResult<(TreeGenome<T, F>, TreeGenome<T, F>)> {
1045 let point1 = self.select_crossover_point(&parent1.root, rng);
1047 let point2 = self.select_crossover_point(&parent2.root, rng);
1048
1049 let subtree1 = parent1
1051 .root
1052 .get_subtree(&point1)
1053 .cloned()
1054 .unwrap_or_else(|| parent1.root.clone());
1055 let subtree2 = parent2
1056 .root
1057 .get_subtree(&point2)
1058 .cloned()
1059 .unwrap_or_else(|| parent2.root.clone());
1060
1061 let mut child1_root = parent1.root.clone();
1063 let mut child2_root = parent2.root.clone();
1064
1065 if point1.is_empty() {
1067 child1_root = subtree2.clone();
1068 } else {
1069 child1_root.replace_subtree(&point1, subtree2.clone());
1070 }
1071
1072 if point2.is_empty() {
1073 child2_root = subtree1.clone();
1074 } else {
1075 child2_root.replace_subtree(&point2, subtree1);
1076 }
1077
1078 if let Some(max_depth) = self.max_depth {
1080 if child1_root.depth() > max_depth {
1081 return OperatorResult::Success((parent1.clone(), parent2.clone()));
1083 }
1084 if child2_root.depth() > max_depth {
1085 return OperatorResult::Success((parent1.clone(), parent2.clone()));
1086 }
1087 }
1088
1089 let child1 = TreeGenome::new(child1_root, parent1.max_depth);
1090 let child2 = TreeGenome::new(child2_root, parent2.max_depth);
1091
1092 OperatorResult::Success((child1, child2))
1093 }
1094}
1095
1096#[cfg(test)]
1097mod tests {
1098 use super::*;
1099 use approx::assert_relative_eq;
1100
1101 #[test]
1102 fn test_sbx_creates_valid_offspring() {
1103 let mut rng = rand::thread_rng();
1104 let parent1 = RealVector::new(vec![0.0, 0.0, 0.0]);
1105 let parent2 = RealVector::new(vec![1.0, 1.0, 1.0]);
1106
1107 let sbx = SbxCrossover::new(20.0);
1108 let result = sbx.crossover(&parent1, &parent2, &mut rng);
1109
1110 assert!(result.is_ok());
1111 let (child1, child2) = result.genome().unwrap();
1112 assert_eq!(child1.dimension(), 3);
1113 assert_eq!(child2.dimension(), 3);
1114 }
1115
1116 #[test]
1117 fn test_sbx_with_bounds() {
1118 let mut rng = rand::thread_rng();
1119 let parent1 = RealVector::new(vec![-0.3, -0.2]);
1121 let parent2 = RealVector::new(vec![0.3, 0.4]);
1122 let bounds = MultiBounds::symmetric(0.5, 2);
1123
1124 let sbx = SbxCrossover::new(2.0).with_probability(1.0); for _ in 0..100 {
1128 let result = sbx.crossover_bounded(&parent1, &parent2, &bounds, &mut rng);
1129 let (child1, child2) = result.genome().unwrap();
1130
1131 for gene in child1.genes() {
1132 assert!(
1133 *gene >= -0.5 && *gene <= 0.5,
1134 "gene {} out of bounds [-0.5, 0.5]",
1135 gene
1136 );
1137 }
1138 for gene in child2.genes() {
1139 assert!(
1140 *gene >= -0.5 && *gene <= 0.5,
1141 "gene {} out of bounds [-0.5, 0.5]",
1142 gene
1143 );
1144 }
1145 }
1146 }
1147
1148 #[test]
1149 fn test_sbx_spread_factor() {
1150 let sbx = SbxCrossover::new(20.0);
1151
1152 let beta = sbx.spread_factor(0.5);
1154 assert_relative_eq!(beta, 1.0, epsilon = 1e-10);
1155
1156 let beta_low = sbx.spread_factor(0.25);
1158 let beta_high = sbx.spread_factor(0.75);
1159 assert_relative_eq!(beta_low, 1.0 / beta_high, epsilon = 1e-10);
1160 }
1161
1162 #[test]
1163 fn test_sbx_identical_parents() {
1164 let mut rng = rand::thread_rng();
1165 let parent = RealVector::new(vec![1.0, 2.0, 3.0]);
1166
1167 let sbx = SbxCrossover::new(20.0);
1168 let result = sbx.crossover(&parent, &parent, &mut rng);
1169
1170 let (child1, child2) = result.genome().unwrap();
1171 assert_eq!(child1.genes(), parent.genes());
1173 assert_eq!(child2.genes(), parent.genes());
1174 }
1175
1176 #[test]
1177 fn test_sbx_dimension_mismatch() {
1178 let mut rng = rand::thread_rng();
1179 let parent1 = RealVector::new(vec![1.0, 2.0]);
1180 let parent2 = RealVector::new(vec![1.0, 2.0, 3.0]);
1181
1182 let sbx = SbxCrossover::new(20.0);
1183 let result = sbx.crossover(&parent1, &parent2, &mut rng);
1184
1185 assert!(!result.is_ok());
1186 }
1187
1188 #[test]
1189 fn test_blx_alpha_creates_valid_offspring() {
1190 let mut rng = rand::thread_rng();
1191 let parent1 = RealVector::new(vec![0.0, 0.0]);
1192 let parent2 = RealVector::new(vec![1.0, 1.0]);
1193
1194 let blx = BlxAlphaCrossover::new(0.5);
1195 let result = blx.crossover(&parent1, &parent2, &mut rng);
1196
1197 assert!(result.is_ok());
1198 let (child1, child2) = result.genome().unwrap();
1199 assert_eq!(child1.dimension(), 2);
1200 assert_eq!(child2.dimension(), 2);
1201 }
1202
1203 #[test]
1204 fn test_blx_alpha_range() {
1205 let mut rng = rand::thread_rng();
1206 let parent1 = RealVector::new(vec![0.0]);
1207 let parent2 = RealVector::new(vec![1.0]);
1208
1209 let blx = BlxAlphaCrossover::new(0.0); for _ in 0..100 {
1212 let result = blx.crossover(&parent1, &parent2, &mut rng);
1213 let (child1, child2) = result.genome().unwrap();
1214
1215 assert!(child1[0] >= 0.0 && child1[0] <= 1.0);
1217 assert!(child2[0] >= 0.0 && child2[0] <= 1.0);
1218 }
1219 }
1220
1221 #[test]
1222 fn test_uniform_crossover_creates_valid_offspring() {
1223 let mut rng = rand::thread_rng();
1224 let parent1 = BitString::new(vec![true, true, true, true]);
1225 let parent2 = BitString::new(vec![false, false, false, false]);
1226
1227 let ux = UniformCrossover::new();
1228 let result = ux.crossover(&parent1, &parent2, &mut rng);
1229
1230 assert!(result.is_ok());
1231 let (child1, child2) = result.genome().unwrap();
1232 assert_eq!(child1.len(), 4);
1233 assert_eq!(child2.len(), 4);
1234 }
1235
1236 #[test]
1237 fn test_uniform_crossover_complementary() {
1238 let mut rng = rand::thread_rng();
1239 let parent1 = BitString::new(vec![true, true, true, true]);
1240 let parent2 = BitString::new(vec![false, false, false, false]);
1241
1242 let ux = UniformCrossover::new();
1243 let result = ux.crossover(&parent1, &parent2, &mut rng);
1244
1245 let (child1, child2) = result.genome().unwrap();
1246
1247 for i in 0..4 {
1249 assert_ne!(child1[i], child2[i]);
1250 }
1251 }
1252
1253 #[test]
1254 fn test_one_point_crossover() {
1255 let mut rng = rand::thread_rng();
1256 let parent1 = BitString::ones(10);
1257 let parent2 = BitString::zeros(10);
1258
1259 let opx = OnePointCrossover::new();
1260 let result = opx.crossover(&parent1, &parent2, &mut rng);
1261
1262 assert!(result.is_ok());
1263 let (child1, child2) = result.genome().unwrap();
1264
1265 for i in 0..10 {
1268 assert_ne!(child1[i], child2[i]);
1269 }
1270 }
1271
1272 #[test]
1273 fn test_two_point_crossover() {
1274 let mut rng = rand::thread_rng();
1275 let parent1 = BitString::ones(10);
1276 let parent2 = BitString::zeros(10);
1277
1278 let tpx = TwoPointCrossover::new();
1279 let result = tpx.crossover(&parent1, &parent2, &mut rng);
1280
1281 assert!(result.is_ok());
1282 let (child1, child2) = result.genome().unwrap();
1283 assert_eq!(child1.len(), 10);
1284 assert_eq!(child2.len(), 10);
1285 }
1286
1287 #[test]
1288 fn test_arithmetic_crossover() {
1289 let mut rng = rand::thread_rng();
1290 let parent1 = RealVector::new(vec![0.0, 0.0]);
1291 let parent2 = RealVector::new(vec![1.0, 1.0]);
1292
1293 let ax = ArithmeticCrossover::new(0.5);
1294 let result = ax.crossover(&parent1, &parent2, &mut rng);
1295
1296 let (child1, child2) = result.genome().unwrap();
1297
1298 for gene in child1.genes() {
1300 assert_relative_eq!(*gene, 0.5);
1301 }
1302 for gene in child2.genes() {
1303 assert_relative_eq!(*gene, 0.5);
1304 }
1305 }
1306
1307 #[test]
1308 fn test_arithmetic_crossover_weighted() {
1309 let mut rng = rand::thread_rng();
1310 let parent1 = RealVector::new(vec![0.0]);
1311 let parent2 = RealVector::new(vec![1.0]);
1312
1313 let ax = ArithmeticCrossover::new(0.75);
1314 let result = ax.crossover(&parent1, &parent2, &mut rng);
1315
1316 let (child1, child2) = result.genome().unwrap();
1317
1318 assert_relative_eq!(child1[0], 0.25);
1321 assert_relative_eq!(child2[0], 0.75);
1322 }
1323
1324 #[test]
1329 fn test_pmx_creates_valid_permutations() {
1330 let mut rng = rand::thread_rng();
1331 let parent1 = Permutation::new(vec![0, 1, 2, 3, 4, 5, 6, 7]);
1332 let parent2 = Permutation::new(vec![7, 6, 5, 4, 3, 2, 1, 0]);
1333
1334 let pmx = PmxCrossover::new();
1335
1336 for _ in 0..100 {
1337 let result = pmx.crossover(&parent1, &parent2, &mut rng);
1338 assert!(result.is_ok());
1339 let (child1, child2) = result.genome().unwrap();
1340
1341 assert!(child1.is_valid_permutation());
1342 assert!(child2.is_valid_permutation());
1343 assert_eq!(child1.dimension(), 8);
1344 assert_eq!(child2.dimension(), 8);
1345 }
1346 }
1347
1348 #[test]
1349 fn test_pmx_identical_parents() {
1350 let mut rng = rand::thread_rng();
1351 let parent = Permutation::new(vec![0, 1, 2, 3, 4]);
1352
1353 let pmx = PmxCrossover::new();
1354 let result = pmx.crossover(&parent, &parent, &mut rng);
1355
1356 let (child1, child2) = result.genome().unwrap();
1357 assert_eq!(child1.as_slice(), parent.as_slice());
1359 assert_eq!(child2.as_slice(), parent.as_slice());
1360 }
1361
/// `PmxCrossover` on parents of length 4 and 5 must not report success.
#[test]
fn test_pmx_dimension_mismatch() {
    let shorter = Permutation::new(vec![0, 1, 2, 3]);
    let longer = Permutation::new(vec![0, 1, 2, 3, 4]);
    let operator = PmxCrossover::new();
    let mut rng = rand::thread_rng();

    let outcome = operator.crossover(&shorter, &longer, &mut rng);
    let succeeded = outcome.is_ok();

    assert!(!succeeded);
}
1373
/// Runs `OxCrossover` 100 times on an ascending and a descending 8-element
/// permutation; each run must succeed and produce two valid permutations of
/// dimension 8.
#[test]
fn test_ox_creates_valid_permutations() {
    let forward = Permutation::new(vec![0, 1, 2, 3, 4, 5, 6, 7]);
    let backward = Permutation::new(vec![7, 6, 5, 4, 3, 2, 1, 0]);
    let operator = OxCrossover::new();
    let mut rng = rand::thread_rng();

    for _attempt in 0..100 {
        let outcome = operator.crossover(&forward, &backward, &mut rng);
        assert!(outcome.is_ok());
        let (left, right) = outcome.genome().unwrap();

        let offspring = [&left, &right];
        // Both validity checks run before the dimension checks.
        for child in &offspring {
            assert!(child.is_valid_permutation());
        }
        for child in &offspring {
            assert_eq!(child.dimension(), 8);
        }
    }
}
1393
/// Runs `OxCrossover` once with a fixed seed (42) on an ascending and a
/// descending 8-element permutation and checks that both children are valid
/// permutations.
///
/// NOTE(review): despite the name, no segment-level property is asserted
/// here; only permutation validity is checked.
#[test]
fn test_ox_preserves_segment() {
    use rand::SeedableRng;

    let forward = Permutation::new(vec![0, 1, 2, 3, 4, 5, 6, 7]);
    let backward = Permutation::new(vec![7, 6, 5, 4, 3, 2, 1, 0]);
    let operator = OxCrossover::new();
    // Fixed seed so the single run is reproducible.
    let mut seeded_rng = rand::rngs::StdRng::seed_from_u64(42);

    let outcome = operator.crossover(&forward, &backward, &mut seeded_rng);
    let (left, right) = outcome.genome().unwrap();

    for child in &[&left, &right] {
        assert!(child.is_valid_permutation());
    }
}
1408
/// Runs `CxCrossover` 100 times on the ascending 8-element permutation and
/// its left rotation by one; each run must succeed and produce two valid
/// permutations of dimension 8.
#[test]
fn test_cx_creates_valid_permutations() {
    let ascending = Permutation::new(vec![0, 1, 2, 3, 4, 5, 6, 7]);
    let rotated = Permutation::new(vec![1, 2, 3, 4, 5, 6, 7, 0]);
    let operator = CxCrossover::new();
    let mut rng = rand::thread_rng();

    for _round in 0..100 {
        let outcome = operator.crossover(&ascending, &rotated, &mut rng);
        assert!(outcome.is_ok());
        let (left, right) = outcome.genome().unwrap();

        // Same assertion order as before: validity of both, then dimensions.
        let offspring = [&left, &right];
        for child in &offspring {
            assert!(child.is_valid_permutation());
        }
        for child in &offspring {
            assert_eq!(child.dimension(), 8);
        }
    }
}
1428
/// After one `CxCrossover` of an ascending and a descending 5-element
/// permutation, every gene of each child must equal the gene one of the
/// parents holds at that same index.
#[test]
fn test_cx_preserves_positions() {
    let ascending = Permutation::new(vec![0, 1, 2, 3, 4]);
    let descending = Permutation::new(vec![4, 3, 2, 1, 0]);
    let operator = CxCrossover::new();
    let mut rng = rand::thread_rng();

    let outcome = operator.crossover(&ascending, &descending, &mut rng);
    let (left, right) = outcome.genome().unwrap();

    for position in 0..5 {
        // The only values allowed at this index are the two parental genes.
        let allowed = [ascending[position], descending[position]];

        assert!(allowed.contains(&left[position]));
        assert!(allowed.contains(&right[position]));
    }
}
1450
/// Crossing a permutation with itself via `CxCrossover` must return two
/// children whose slices equal the parent's slice.
#[test]
fn test_cx_identical_parents() {
    let shared_parent = Permutation::new(vec![0, 1, 2, 3, 4]);
    let operator = CxCrossover::new();
    let mut rng = rand::thread_rng();

    let outcome = operator.crossover(&shared_parent, &shared_parent, &mut rng);
    let (left, right) = outcome.genome().unwrap();

    let expected = shared_parent.as_slice();
    assert_eq!(left.as_slice(), expected);
    assert_eq!(right.as_slice(), expected);
}
1464
/// Runs `EdgeRecombinationCrossover` 50 times on an ascending and a
/// descending 8-element permutation; each run must succeed and produce two
/// valid permutations of dimension 8.
#[test]
fn test_erx_creates_valid_permutations() {
    let forward = Permutation::new(vec![0, 1, 2, 3, 4, 5, 6, 7]);
    let backward = Permutation::new(vec![7, 6, 5, 4, 3, 2, 1, 0]);
    let operator = EdgeRecombinationCrossover::new();
    let mut rng = rand::thread_rng();

    for _trial in 0..50 {
        let outcome = operator.crossover(&forward, &backward, &mut rng);
        assert!(outcome.is_ok());
        let (left, right) = outcome.genome().unwrap();

        // Validity of both children is asserted before either dimension.
        for child in &[&left, &right] {
            assert!(child.is_valid_permutation());
        }
        for child in &[&left, &right] {
            assert_eq!(child.dimension(), 8);
        }
    }
}
1484
/// Runs `EdgeRecombinationCrossover` once with a fixed seed (42) on two
/// 5-element parents and checks that both offspring are valid permutations
/// of dimension 5.
///
/// The second child used to be bound and then ignored; it is now validated
/// as well, matching the checks in `test_erx_creates_valid_permutations`.
///
/// NOTE(review): despite the name, no edge-level property is asserted. What
/// counts as a preserved edge depends on the operator's adjacency convention,
/// which is not visible from this test - confirm before adding such a check.
#[test]
fn test_erx_preserves_some_edges() {
    use rand::SeedableRng;
    // Fixed seed so the single run is reproducible.
    let mut rng = rand::rngs::StdRng::seed_from_u64(42);

    let parent1 = Permutation::new(vec![0, 1, 2, 3, 4]);
    let parent2 = Permutation::new(vec![0, 1, 4, 3, 2]);

    let erx = EdgeRecombinationCrossover::new();
    let result = erx.crossover(&parent1, &parent2, &mut rng);

    let (child1, child2) = result.genome().unwrap();
    assert!(child1.is_valid_permutation());
    assert!(child2.is_valid_permutation());
    assert_eq!(child1.dimension(), 5);
    assert_eq!(child2.dimension(), 5);
}
1501}