Skip to main content

fugue_evo/diagnostics/
convergence.rs

1//! Convergence detection for evolutionary algorithms
2//!
3//! This module provides various methods to detect when an evolutionary algorithm
4//! has converged or should terminate.
5
6use serde::{Deserialize, Serialize};
7
8/// Result of a convergence check
9#[derive(Clone, Debug, PartialEq, Eq)]
10pub enum ConvergenceStatus {
11    /// Algorithm has not converged
12    NotConverged,
13    /// Algorithm has converged with a reason
14    Converged(ConvergenceReason),
15}
16
17impl ConvergenceStatus {
18    /// Check if converged
19    pub fn is_converged(&self) -> bool {
20        matches!(self, Self::Converged(_))
21    }
22}
23
/// Reason for convergence
///
/// Floating-point payloads (`LowDiversity`, `TargetReached`, `RhatConverged`) are
/// stored as their raw IEEE-754 bit pattern (`f64::to_bits`) so the enum can
/// derive `Eq`; recover the value with `f64::from_bits`. The constructor helpers
/// `low_diversity`, `target_reached` and `rhat_converged` perform the conversion.
/// Because equality compares bits, `0.0` and `-0.0` compare unequal.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum ConvergenceReason {
    /// Fitness has not improved for many generations
    FitnessStagnation { generations: usize },
    /// Population diversity is below threshold
    LowDiversity { diversity: u64 }, // Store as bits for Eq
    /// Target fitness reached
    ///
    /// NOTE(review): the producers in this module store the best fitness that
    /// triggered the criterion here, not the configured target value.
    TargetReached { target: u64 }, // Store as bits for Eq
    /// Maximum generations reached
    MaxGenerations { generations: usize },
    /// Maximum evaluations reached
    MaxEvaluations { evaluations: usize },
    /// R-hat statistic indicates convergence
    RhatConverged { rhat: u64 }, // Store as bits for Eq
    /// Multiple criteria satisfied
    MultipleReasons(Vec<ConvergenceReason>),
    /// Custom termination
    Custom(String),
}
44
45impl ConvergenceReason {
46    /// Create a fitness stagnation reason
47    pub fn fitness_stagnation(generations: usize) -> Self {
48        Self::FitnessStagnation { generations }
49    }
50
51    /// Create a low diversity reason
52    pub fn low_diversity(diversity: f64) -> Self {
53        Self::LowDiversity {
54            diversity: diversity.to_bits(),
55        }
56    }
57
58    /// Create a target reached reason
59    pub fn target_reached(target: f64) -> Self {
60        Self::TargetReached {
61            target: target.to_bits(),
62        }
63    }
64
65    /// Create an R-hat converged reason
66    pub fn rhat_converged(rhat: f64) -> Self {
67        Self::RhatConverged {
68            rhat: rhat.to_bits(),
69        }
70    }
71}
72
/// Configuration for convergence detection
///
/// Every field is public and can be set directly or through the builder-style
/// methods. The `Default` configuration has no generation/evaluation budget, no
/// target fitness, and R-hat disabled.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ConvergenceConfig {
    /// Maximum generations before termination (met once the current generation
    /// is `>=` this value)
    pub max_generations: Option<usize>,
    /// Maximum fitness evaluations before termination (met once the evaluation
    /// count is `>=` this value)
    pub max_evaluations: Option<usize>,
    /// Target fitness to reach (maximization: a best fitness at or above the
    /// target also counts as reaching it)
    pub target_fitness: Option<f64>,
    /// Tolerance for target fitness comparison
    pub target_tolerance: f64,
    /// Number of generations without improvement before stagnation
    pub stagnation_generations: usize,
    /// Minimum improvement to not count as stagnation (a new best must exceed the
    /// previously recorded best by strictly more than this)
    pub stagnation_threshold: f64,
    /// Diversity threshold below which convergence is detected (strict `<`)
    pub diversity_threshold: f64,
    /// R-hat threshold for convergence (typically 1.1); R-hat must be strictly
    /// below it
    pub rhat_threshold: f64,
    /// Whether to use R-hat based convergence (only evaluated once at least 10
    /// generations have been recorded)
    pub use_rhat: bool,
}
95
96impl Default for ConvergenceConfig {
97    fn default() -> Self {
98        Self {
99            max_generations: None,
100            max_evaluations: None,
101            target_fitness: None,
102            target_tolerance: 1e-6,
103            stagnation_generations: 50,
104            stagnation_threshold: 1e-9,
105            diversity_threshold: 0.01,
106            rhat_threshold: 1.1,
107            use_rhat: false,
108        }
109    }
110}
111
112impl ConvergenceConfig {
113    /// Create a new config with max generations
114    pub fn with_max_generations(generations: usize) -> Self {
115        Self {
116            max_generations: Some(generations),
117            ..Default::default()
118        }
119    }
120
121    /// Set max generations
122    pub fn max_generations(mut self, generations: usize) -> Self {
123        self.max_generations = Some(generations);
124        self
125    }
126
127    /// Set max evaluations
128    pub fn max_evaluations(mut self, evaluations: usize) -> Self {
129        self.max_evaluations = Some(evaluations);
130        self
131    }
132
133    /// Set target fitness
134    pub fn target_fitness(mut self, target: f64) -> Self {
135        self.target_fitness = Some(target);
136        self
137    }
138
139    /// Set target tolerance
140    pub fn target_tolerance(mut self, tolerance: f64) -> Self {
141        self.target_tolerance = tolerance;
142        self
143    }
144
145    /// Set stagnation detection parameters
146    pub fn stagnation(mut self, generations: usize, threshold: f64) -> Self {
147        self.stagnation_generations = generations;
148        self.stagnation_threshold = threshold;
149        self
150    }
151
152    /// Set diversity threshold
153    pub fn diversity_threshold(mut self, threshold: f64) -> Self {
154        self.diversity_threshold = threshold;
155        self
156    }
157
158    /// Enable R-hat based convergence
159    pub fn with_rhat(mut self, threshold: f64) -> Self {
160        self.use_rhat = true;
161        self.rhat_threshold = threshold;
162        self
163    }
164}
165
/// Convergence detector that tracks evolution state
///
/// Call `update` once per generation, then `check` to see whether any configured
/// stopping condition holds. The histories grow by one entry per `update` call and
/// are only cleared by `reset`.
#[derive(Clone, Debug)]
pub struct ConvergenceDetector {
    /// Configuration
    config: ConvergenceConfig,
    /// History of best fitness values (one entry per `update` call)
    best_fitness_history: Vec<f64>,
    /// History of mean fitness values (for R-hat)
    mean_fitness_history: Vec<f64>,
    /// History of diversity values
    diversity_history: Vec<f64>,
    /// Current generation (the last value passed to `update`)
    current_generation: usize,
    /// Current evaluations (the last value passed to `update`; not accumulated)
    current_evaluations: usize,
    /// Best fitness seen so far
    ///
    /// Only replaced when a new value beats it by more than
    /// `config.stagnation_threshold`, so it can lag the true maximum by up to that
    /// amount.
    best_fitness_overall: f64,
    /// Generation when best fitness was last improved
    last_improvement_generation: usize,
}
186
187impl ConvergenceDetector {
188    /// Create a new convergence detector
189    pub fn new(config: ConvergenceConfig) -> Self {
190        Self {
191            config,
192            best_fitness_history: Vec::new(),
193            mean_fitness_history: Vec::new(),
194            diversity_history: Vec::new(),
195            current_generation: 0,
196            current_evaluations: 0,
197            best_fitness_overall: f64::NEG_INFINITY,
198            last_improvement_generation: 0,
199        }
200    }
201
202    /// Create with default config
203    pub fn with_defaults() -> Self {
204        Self::new(ConvergenceConfig::default())
205    }
206
207    /// Update with generation statistics
208    pub fn update(
209        &mut self,
210        generation: usize,
211        evaluations: usize,
212        best_fitness: f64,
213        mean_fitness: f64,
214        diversity: f64,
215    ) {
216        self.current_generation = generation;
217        self.current_evaluations = evaluations;
218        self.best_fitness_history.push(best_fitness);
219        self.mean_fitness_history.push(mean_fitness);
220        self.diversity_history.push(diversity);
221
222        // Track improvement
223        if best_fitness > self.best_fitness_overall + self.config.stagnation_threshold {
224            self.best_fitness_overall = best_fitness;
225            self.last_improvement_generation = generation;
226        }
227    }
228
229    /// Check if algorithm has converged
230    pub fn check(&self) -> ConvergenceStatus {
231        let mut reasons = Vec::new();
232
233        // Check max generations
234        if let Some(max_gen) = self.config.max_generations {
235            if self.current_generation >= max_gen {
236                reasons.push(ConvergenceReason::MaxGenerations {
237                    generations: self.current_generation,
238                });
239            }
240        }
241
242        // Check max evaluations
243        if let Some(max_eval) = self.config.max_evaluations {
244            if self.current_evaluations >= max_eval {
245                reasons.push(ConvergenceReason::MaxEvaluations {
246                    evaluations: self.current_evaluations,
247                });
248            }
249        }
250
251        // Check target fitness
252        if let Some(target) = self.config.target_fitness {
253            if let Some(&best) = self.best_fitness_history.last() {
254                if (best - target).abs() <= self.config.target_tolerance || best >= target {
255                    reasons.push(ConvergenceReason::target_reached(best));
256                }
257            }
258        }
259
260        // Check stagnation
261        let generations_since_improvement =
262            self.current_generation - self.last_improvement_generation;
263        if generations_since_improvement >= self.config.stagnation_generations {
264            reasons.push(ConvergenceReason::fitness_stagnation(
265                generations_since_improvement,
266            ));
267        }
268
269        // Check diversity
270        if let Some(&diversity) = self.diversity_history.last() {
271            if diversity < self.config.diversity_threshold {
272                reasons.push(ConvergenceReason::low_diversity(diversity));
273            }
274        }
275
276        // Check R-hat if enabled
277        if self.config.use_rhat && self.mean_fitness_history.len() >= 10 {
278            // Split history into "chains" for R-hat calculation
279            let rhat = self.compute_rhat();
280            if rhat < self.config.rhat_threshold {
281                reasons.push(ConvergenceReason::rhat_converged(rhat));
282            }
283        }
284
285        // Return result
286        match reasons.len() {
287            0 => ConvergenceStatus::NotConverged,
288            1 => ConvergenceStatus::Converged(reasons.pop().unwrap()),
289            _ => ConvergenceStatus::Converged(ConvergenceReason::MultipleReasons(reasons)),
290        }
291    }
292
293    /// Compute R-hat statistic from fitness history
294    fn compute_rhat(&self) -> f64 {
295        // Split history into two "chains"
296        let n = self.mean_fitness_history.len();
297        let half = n / 2;
298
299        if half < 5 {
300            return f64::INFINITY; // Not enough data
301        }
302
303        let chain1: Vec<f64> = self.mean_fitness_history[..half].to_vec();
304        let chain2: Vec<f64> = self.mean_fitness_history[half..].to_vec();
305
306        evolutionary_rhat(&[chain1, chain2])
307    }
308
309    /// Get the best fitness seen
310    pub fn best_fitness(&self) -> f64 {
311        self.best_fitness_overall
312    }
313
314    /// Get generations since last improvement
315    pub fn generations_without_improvement(&self) -> usize {
316        self.current_generation - self.last_improvement_generation
317    }
318
319    /// Get the latest diversity value
320    pub fn current_diversity(&self) -> Option<f64> {
321        self.diversity_history.last().copied()
322    }
323
324    /// Get the fitness history
325    pub fn fitness_history(&self) -> &[f64] {
326        &self.best_fitness_history
327    }
328
329    /// Get the diversity history
330    pub fn diversity_history(&self) -> &[f64] {
331        &self.diversity_history
332    }
333
334    /// Reset the detector
335    pub fn reset(&mut self) {
336        self.best_fitness_history.clear();
337        self.mean_fitness_history.clear();
338        self.diversity_history.clear();
339        self.current_generation = 0;
340        self.current_evaluations = 0;
341        self.best_fitness_overall = f64::NEG_INFINITY;
342        self.last_improvement_generation = 0;
343    }
344}
345
/// R-hat analog for evolutionary convergence
///
/// Compares fitness distributions across multiple runs/chains using the
/// Gelman–Rubin potential scale reduction factor. Values close to 1.0 indicate
/// convergence; larger values mean the chains disagree.
///
/// Every chain is truncated to the length of the shortest one, so the between-
/// and within-chain statistics are computed over the same number of samples.
///
/// Returns `f64::INFINITY` in three cases:
/// - fewer than two chains are given;
/// - the shortest chain has fewer than two samples;
/// - every chain is constant but the chains sit at different values (zero
///   within-chain variance with non-zero between-chain variance).
///
/// Returns `1.0` when all chains are constant and identical.
pub fn evolutionary_rhat(runs: &[Vec<f64>]) -> f64 {
    let len = runs.iter().map(Vec::len).min().unwrap_or(0);
    if runs.len() < 2 || len < 2 {
        return f64::INFINITY;
    }

    let m = runs.len() as f64;
    let n = len as f64;

    // Common-length view of every chain (extra trailing samples are ignored).
    let chains: Vec<&[f64]> = runs.iter().map(|r| &r[..len]).collect();

    // Between-chain variance
    let chain_means: Vec<f64> = chains
        .iter()
        .map(|c| c.iter().sum::<f64>() / n)
        .collect();
    let grand_mean = chain_means.iter().sum::<f64>() / m;
    let b = n / (m - 1.0)
        * chain_means
            .iter()
            .map(|cm| (cm - grand_mean).powi(2))
            .sum::<f64>();

    // Within-chain variance: mean of the per-chain sample variances
    let w: f64 = chains
        .iter()
        .zip(&chain_means)
        .map(|(c, &mean)| c.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (n - 1.0))
        .sum::<f64>()
        / m;

    if w == 0.0 {
        // Constant chains: converged only if they also agree with each other.
        return if b == 0.0 { 1.0 } else { f64::INFINITY };
    }

    // Pooled variance estimate
    let var_plus = ((n - 1.0) / n) * w + b / n;

    (var_plus / w).sqrt()
}
390
/// Effective Sample Size (ESS) for evolutionary SMC
///
/// Computes `1 / Σ pᵢ²` where `pᵢ = wᵢ / Σ w` are the normalized importance
/// weights. Returns `0.0` for an empty slice, and the number of weights when the
/// weights sum to zero or the squared normalized weights sum to zero.
pub fn evolutionary_ess(weights: &[f64]) -> f64 {
    let count = weights.len() as f64;
    if weights.is_empty() {
        return 0.0;
    }

    let total: f64 = weights.iter().sum();
    if total == 0.0 {
        return count;
    }

    // Σ pᵢ² without materializing the normalized vector.
    let concentration: f64 = weights
        .iter()
        .map(|&w| {
            let p = w / total;
            p * p
        })
        .sum();

    match concentration == 0.0 {
        true => count,
        false => concentration.recip(),
    }
}
414
415/// Effective Sample Size from log weights
416pub fn evolutionary_ess_log(log_weights: &[f64]) -> f64 {
417    if log_weights.is_empty() {
418        return 0.0;
419    }
420
421    // Use log-sum-exp trick for numerical stability
422    let max_log = log_weights
423        .iter()
424        .cloned()
425        .fold(f64::NEG_INFINITY, f64::max);
426
427    if max_log.is_infinite() {
428        return log_weights.len() as f64;
429    }
430
431    let weights: Vec<f64> = log_weights.iter().map(|lw| (lw - max_log).exp()).collect();
432    evolutionary_ess(&weights)
433}
434
/// Detect fitness stagnation in a history of fitness values
///
/// Returns the number of generations since the last *improvement*. An entry
/// counts as an improvement only if it exceeds the best value seen so far by
/// strictly more than `threshold`, which is the same rule
/// `ConvergenceDetector::update` applies. Histories with fewer than two entries
/// return `0`.
///
/// Stagnation is also reported when fitness peaked earlier and has since fallen
/// back (e.g. without elitism), not only when the tail sits exactly at the
/// maximum.
pub fn detect_stagnation(fitness_history: &[f64], threshold: f64) -> usize {
    if fitness_history.len() < 2 {
        return 0;
    }

    let mut best = f64::NEG_INFINITY;
    let mut last_improvement = 0;
    for (index, &fitness) in fitness_history.iter().enumerate() {
        if fitness > best + threshold {
            best = fitness;
            last_improvement = index;
        }
    }

    // Generations recorded after the last improving one.
    fitness_history.len() - 1 - last_improvement
}
460
/// Compute population convergence from fitness values
///
/// Returns a value between 0 (no convergence) and 1 (perfect convergence),
/// computed as `exp(-CV)` where `CV = std / |mean|` is the coefficient of
/// variation (sample standard deviation, `n - 1` denominator).
///
/// Fewer than two values return `1.0`. When the mean is (numerically) zero, the
/// CV is undefined. In that case the result is `1.0` only if the spread is also
/// negligible, and `0.0` otherwise, so a population like `[-100, 100]` is not
/// reported as converged.
pub fn fitness_convergence(fitness_values: &[f64]) -> f64 {
    if fitness_values.len() < 2 {
        return 1.0;
    }

    let count = fitness_values.len() as f64;
    let mean = fitness_values.iter().sum::<f64>() / count;

    let variance = fitness_values
        .iter()
        .map(|f| (f - mean).powi(2))
        .sum::<f64>()
        / (count - 1.0);
    let std = variance.sqrt();

    if mean.abs() < f64::EPSILON {
        // CV → ∞ unless there is no spread at all.
        return if std < f64::EPSILON { 1.0 } else { 0.0 };
    }

    // Coefficient of variation (CV)
    let cv = std / mean.abs();

    // Convert to convergence metric (higher = more converged); CV of 0 maps to 1
    (-cv).exp()
}
490
/// Termination criteria for evolutionary algorithms
///
/// A list of [`TerminationCriterion`] values evaluated by
/// `TerminationCriteria::should_terminate`. With `require_all == false` (as
/// created by `new`), any satisfied criterion stops the run. With
/// `require_all == true`, every criterion must be satisfied.
#[derive(Clone, Debug)]
pub struct TerminationCriteria {
    /// Criteria in insertion order (also the order satisfied reasons are reported in)
    criteria: Vec<TerminationCriterion>,
    /// Whether all criteria (rather than any one) must hold to terminate
    require_all: bool,
}
497
/// A single termination criterion
#[derive(Clone, Debug)]
pub enum TerminationCriterion {
    /// Maximum generations (met when the generation is `>=` the limit)
    MaxGenerations(usize),
    /// Maximum evaluations (met when the evaluation count is `>=` the limit)
    MaxEvaluations(usize),
    /// Target fitness (maximize)
    ///
    /// Met when the best fitness is within `tolerance` of `target` or above it.
    TargetFitness(f64, f64), // (target, tolerance)
    /// Fitness stagnation
    ///
    /// Met when the caller-supplied stagnation count is `>=` `generations`. The
    /// threshold is not consulted by `should_terminate`; presumably the caller
    /// applies it when counting stagnant generations — verify against callers.
    Stagnation(usize, f64), // (generations, threshold)
    /// Diversity threshold (met when diversity is strictly below it)
    DiversityThreshold(f64),
    /// Time limit in seconds (met when elapsed seconds are `>=` the limit)
    TimeLimit(f64),
    /// Custom predicate
    ///
    /// Never satisfied inside `should_terminate`.
    Custom(String), // Description only, evaluation handled externally
}
516
517impl TerminationCriteria {
518    /// Create new empty criteria (any criterion triggers termination)
519    pub fn new() -> Self {
520        Self {
521            criteria: Vec::new(),
522            require_all: false,
523        }
524    }
525
526    /// Create criteria where all must be satisfied
527    pub fn require_all() -> Self {
528        Self {
529            criteria: Vec::new(),
530            require_all: true,
531        }
532    }
533
534    /// Add a criterion
535    pub fn add(mut self, criterion: TerminationCriterion) -> Self {
536        self.criteria.push(criterion);
537        self
538    }
539
540    /// Add max generations criterion
541    pub fn max_generations(self, generations: usize) -> Self {
542        self.add(TerminationCriterion::MaxGenerations(generations))
543    }
544
545    /// Add max evaluations criterion
546    pub fn max_evaluations(self, evaluations: usize) -> Self {
547        self.add(TerminationCriterion::MaxEvaluations(evaluations))
548    }
549
550    /// Add target fitness criterion
551    pub fn target_fitness(self, target: f64, tolerance: f64) -> Self {
552        self.add(TerminationCriterion::TargetFitness(target, tolerance))
553    }
554
555    /// Add stagnation criterion
556    pub fn stagnation(self, generations: usize, threshold: f64) -> Self {
557        self.add(TerminationCriterion::Stagnation(generations, threshold))
558    }
559
560    /// Add diversity threshold criterion
561    pub fn diversity_threshold(self, threshold: f64) -> Self {
562        self.add(TerminationCriterion::DiversityThreshold(threshold))
563    }
564
565    /// Add time limit criterion
566    pub fn time_limit(self, seconds: f64) -> Self {
567        self.add(TerminationCriterion::TimeLimit(seconds))
568    }
569
570    /// Check if termination criteria are met
571    pub fn should_terminate(
572        &self,
573        generation: usize,
574        evaluations: usize,
575        best_fitness: f64,
576        diversity: f64,
577        stagnation_generations: usize,
578        elapsed_seconds: f64,
579    ) -> Option<ConvergenceReason> {
580        let mut satisfied = Vec::new();
581
582        for criterion in &self.criteria {
583            let met = match criterion {
584                TerminationCriterion::MaxGenerations(max) => generation >= *max,
585                TerminationCriterion::MaxEvaluations(max) => evaluations >= *max,
586                TerminationCriterion::TargetFitness(target, tolerance) => {
587                    (best_fitness - target).abs() <= *tolerance || best_fitness >= *target
588                }
589                TerminationCriterion::Stagnation(gens, _threshold) => {
590                    stagnation_generations >= *gens
591                }
592                TerminationCriterion::DiversityThreshold(thresh) => diversity < *thresh,
593                TerminationCriterion::TimeLimit(limit) => elapsed_seconds >= *limit,
594                TerminationCriterion::Custom(_) => false, // Handled externally
595            };
596
597            if met {
598                satisfied.push(criterion.to_reason(
599                    generation,
600                    evaluations,
601                    best_fitness,
602                    diversity,
603                ));
604            }
605        }
606
607        if satisfied.is_empty() {
608            return None;
609        }
610
611        if self.require_all && satisfied.len() < self.criteria.len() {
612            return None;
613        }
614
615        // Return the reason(s)
616        if satisfied.len() == 1 {
617            Some(satisfied.pop().unwrap())
618        } else {
619            Some(ConvergenceReason::MultipleReasons(satisfied))
620        }
621    }
622
623    /// Get all criteria
624    pub fn criteria(&self) -> &[TerminationCriterion] {
625        &self.criteria
626    }
627}
628
629impl Default for TerminationCriteria {
630    fn default() -> Self {
631        Self::new()
632    }
633}
634
635impl TerminationCriterion {
636    fn to_reason(
637        &self,
638        generation: usize,
639        evaluations: usize,
640        best_fitness: f64,
641        diversity: f64,
642    ) -> ConvergenceReason {
643        match self {
644            Self::MaxGenerations(_) => ConvergenceReason::MaxGenerations {
645                generations: generation,
646            },
647            Self::MaxEvaluations(_) => ConvergenceReason::MaxEvaluations { evaluations },
648            Self::TargetFitness(_, _) => ConvergenceReason::target_reached(best_fitness),
649            Self::Stagnation(gens, _) => ConvergenceReason::fitness_stagnation(*gens),
650            Self::DiversityThreshold(_) => ConvergenceReason::low_diversity(diversity),
651            Self::TimeLimit(t) => ConvergenceReason::Custom(format!("Time limit of {t}s reached")),
652            Self::Custom(desc) => ConvergenceReason::Custom(desc.clone()),
653        }
654    }
655}
656
// Unit tests for the convergence detector, the R-hat / ESS / stagnation /
// convergence statistics, and the termination-criteria combinator.
#[cfg(test)]
mod tests {
    use super::*;

    // 50 strictly improving generations under a 100-generation budget: nothing fires.
    #[test]
    fn test_convergence_detector_basic() {
        let config = ConvergenceConfig::with_max_generations(100);
        let mut detector = ConvergenceDetector::new(config);

        // Simulate improving fitness
        for i in 0..50 {
            detector.update(i, i * 10, i as f64, i as f64 * 0.5, 0.5);
        }

        let status = detector.check();
        assert!(!status.is_converged());
    }

    // Running past the generation budget must report MaxGenerations.
    #[test]
    fn test_convergence_detector_max_generations() {
        let config = ConvergenceConfig::with_max_generations(50);
        let mut detector = ConvergenceDetector::new(config);

        for i in 0..60 {
            detector.update(i, i * 10, i as f64, i as f64 * 0.5, 0.5);
        }

        let status = detector.check();
        assert!(status.is_converged());
        if let ConvergenceStatus::Converged(reason) = status {
            assert!(matches!(reason, ConvergenceReason::MaxGenerations { .. }));
        }
    }

    // 99.5 is within the 1.0 tolerance of the 100.0 target.
    #[test]
    fn test_convergence_detector_target_fitness() {
        let config = ConvergenceConfig::default()
            .target_fitness(100.0)
            .target_tolerance(1.0);
        let mut detector = ConvergenceDetector::new(config);

        detector.update(0, 10, 99.5, 50.0, 0.5);

        let status = detector.check();
        assert!(status.is_converged());
    }

    // Flat fitness for 19 generations after the first exceeds the 10-generation window.
    #[test]
    fn test_convergence_detector_stagnation() {
        let config = ConvergenceConfig::default().stagnation(10, 1e-9);
        let mut detector = ConvergenceDetector::new(config);

        // First improvement
        detector.update(0, 10, 50.0, 50.0, 0.5);

        // Then stagnation
        for i in 1..20 {
            detector.update(i, i * 10, 50.0, 50.0, 0.5);
        }

        let status = detector.check();
        assert!(status.is_converged());
        if let ConvergenceStatus::Converged(reason) = status {
            assert!(matches!(
                reason,
                ConvergenceReason::FitnessStagnation { .. }
            ));
        }
    }

    // Diversity 0.05 is below the configured 0.1 floor.
    #[test]
    fn test_convergence_detector_low_diversity() {
        let config = ConvergenceConfig::default().diversity_threshold(0.1);
        let mut detector = ConvergenceDetector::new(config);

        detector.update(0, 10, 50.0, 50.0, 0.05);

        let status = detector.check();
        assert!(status.is_converged());
        if let ConvergenceStatus::Converged(reason) = status {
            assert!(matches!(reason, ConvergenceReason::LowDiversity { .. }));
        }
    }

    #[test]
    fn test_evolutionary_rhat() {
        // Similar chains with some variation should give R-hat close to 1
        let chain1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let chain2 = vec![1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 10.5];
        let rhat = evolutionary_rhat(&[chain1, chain2]);
        // R-hat should be close to 1 for similar chains (typically < 1.1 for convergence)
        assert!(rhat < 1.2, "R-hat was {}, expected < 1.2", rhat);
    }

    #[test]
    fn test_evolutionary_rhat_divergent() {
        // Very different chains with internal variation should give high R-hat
        let chain1 = vec![1.0, 2.0, 1.5, 2.5, 1.2, 2.8, 1.8, 2.2, 1.3, 2.7];
        let chain2 = vec![
            100.0, 101.0, 100.5, 101.5, 100.2, 101.8, 100.8, 101.2, 100.3, 101.7,
        ];
        let rhat = evolutionary_rhat(&[chain1, chain2]);
        // R-hat should be high for divergent chains
        assert!(rhat > 1.5, "R-hat was {}, expected > 1.5", rhat);
    }

    #[test]
    fn test_evolutionary_ess() {
        // Equal weights should give ESS = n
        let weights = vec![1.0, 1.0, 1.0, 1.0];
        let ess = evolutionary_ess(&weights);
        assert!((ess - 4.0).abs() < 0.01);
    }

    #[test]
    fn test_evolutionary_ess_unequal() {
        // One dominant weight should give low ESS
        let weights = vec![1.0, 0.0, 0.0, 0.0];
        let ess = evolutionary_ess(&weights);
        assert!((ess - 1.0).abs() < 0.01);
    }

    // Equal log weights (all 0.0 → weight 1.0) behave like equal linear weights.
    #[test]
    fn test_evolutionary_ess_log() {
        let log_weights = vec![0.0, 0.0, 0.0, 0.0];
        let ess = evolutionary_ess_log(&log_weights);
        assert!((ess - 4.0).abs() < 0.01);
    }

    #[test]
    fn test_detect_stagnation() {
        let history = vec![10.0, 20.0, 30.0, 30.0, 30.0, 30.0];
        let stagnant = detect_stagnation(&history, 1e-9);
        assert_eq!(stagnant, 3); // 3 generations at max
    }

    #[test]
    fn test_detect_stagnation_improving() {
        let history = vec![10.0, 20.0, 30.0, 40.0, 50.0];
        let stagnant = detect_stagnation(&history, 1e-9);
        assert_eq!(stagnant, 0);
    }

    #[test]
    fn test_fitness_convergence() {
        // All same fitness = perfect convergence
        let fitness = vec![50.0, 50.0, 50.0, 50.0];
        let conv = fitness_convergence(&fitness);
        assert!((conv - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_fitness_convergence_diverse() {
        // Very diverse fitness = low convergence
        let fitness = vec![0.0, 100.0, 0.0, 100.0];
        let conv = fitness_convergence(&fitness);
        assert!(conv < 0.5);
    }

    // The generation limit is inclusive: generation 100 of a 100 budget terminates.
    #[test]
    fn test_termination_criteria_max_gen() {
        let criteria = TerminationCriteria::new().max_generations(100);

        let result = criteria.should_terminate(50, 500, 10.0, 0.5, 0, 10.0);
        assert!(result.is_none());

        let result = criteria.should_terminate(100, 1000, 10.0, 0.5, 0, 20.0);
        assert!(result.is_some());
    }

    #[test]
    fn test_termination_criteria_target() {
        let criteria = TerminationCriteria::new().target_fitness(100.0, 1.0);

        let result = criteria.should_terminate(10, 100, 50.0, 0.5, 0, 5.0);
        assert!(result.is_none());

        let result = criteria.should_terminate(10, 100, 99.5, 0.5, 0, 5.0);
        assert!(result.is_some());
    }

    // With the default "any" semantics, either criterion alone is enough.
    #[test]
    fn test_termination_criteria_multiple() {
        let criteria = TerminationCriteria::new()
            .max_generations(100)
            .target_fitness(100.0, 1.0);

        // Neither met
        let result = criteria.should_terminate(10, 100, 50.0, 0.5, 0, 5.0);
        assert!(result.is_none());

        // Target met
        let result = criteria.should_terminate(10, 100, 100.0, 0.5, 0, 5.0);
        assert!(result.is_some());

        // Max gen met
        let result = criteria.should_terminate(100, 1000, 50.0, 0.5, 0, 50.0);
        assert!(result.is_some());
    }

    // With require_all, a single satisfied criterion is not enough.
    #[test]
    fn test_termination_criteria_require_all() {
        let criteria = TerminationCriteria::require_all()
            .max_generations(100)
            .stagnation(10, 1e-9);

        // Only max gen met
        let result = criteria.should_terminate(100, 1000, 50.0, 0.5, 5, 50.0);
        assert!(result.is_none());

        // Both met
        let result = criteria.should_terminate(100, 1000, 50.0, 0.5, 10, 50.0);
        assert!(result.is_some());
    }

    // Every builder setter writes the field it names.
    #[test]
    fn test_convergence_config_builder() {
        let config = ConvergenceConfig::with_max_generations(500)
            .max_evaluations(10000)
            .target_fitness(1.0)
            .target_tolerance(0.01)
            .stagnation(100, 1e-6)
            .diversity_threshold(0.05)
            .with_rhat(1.05);

        assert_eq!(config.max_generations, Some(500));
        assert_eq!(config.max_evaluations, Some(10000));
        assert_eq!(config.target_fitness, Some(1.0));
        assert_eq!(config.target_tolerance, 0.01);
        assert_eq!(config.stagnation_generations, 100);
        assert_eq!(config.stagnation_threshold, 1e-6);
        assert_eq!(config.diversity_threshold, 0.05);
        assert!(config.use_rhat);
        assert_eq!(config.rhat_threshold, 1.05);
    }

    // reset() clears the history and restores the NEG_INFINITY best fitness.
    #[test]
    fn test_convergence_detector_reset() {
        let config = ConvergenceConfig::default();
        let mut detector = ConvergenceDetector::new(config);

        detector.update(0, 10, 50.0, 50.0, 0.5);
        detector.update(1, 20, 60.0, 55.0, 0.4);

        assert_eq!(detector.fitness_history().len(), 2);
        assert_eq!(detector.best_fitness(), 60.0);

        detector.reset();

        assert!(detector.fitness_history().is_empty());
        assert_eq!(detector.best_fitness(), f64::NEG_INFINITY);
    }

    #[test]
    fn test_convergence_status_is_converged() {
        let not_converged = ConvergenceStatus::NotConverged;
        assert!(!not_converged.is_converged());

        let converged =
            ConvergenceStatus::Converged(ConvergenceReason::MaxGenerations { generations: 100 });
        assert!(converged.is_converged());
    }
}