
fugue_evo/hyperparameter/bayesian.rs

//! Bayesian hyperparameter learning
//!
//! Online Bayesian inference for learning optimal hyperparameters using
//! conjugate prior distributions.

use rand::Rng;
use rand_distr::{Beta, Distribution, Gamma};
use std::collections::VecDeque;

/// Beta distribution posterior for probability parameters (e.g., mutation rate)
#[derive(Clone, Debug)]
pub struct BetaPosterior {
    /// Alpha parameter (pseudo-count of successes)
    pub alpha: f64,
    /// Beta parameter (pseudo-count of failures)
    pub beta: f64,
}

impl BetaPosterior {
    /// Create with uniform prior (α = β = 1)
    pub fn uniform() -> Self {
        Self {
            alpha: 1.0,
            beta: 1.0,
        }
    }

    /// Create with Jeffreys prior (α = β = 0.5)
    pub fn jeffreys() -> Self {
        Self {
            alpha: 0.5,
            beta: 0.5,
        }
    }

    /// Create with custom prior
    pub fn new(alpha: f64, beta: f64) -> Self {
        Self { alpha, beta }
    }

    /// Update posterior with a success observation
    pub fn observe_success(&mut self) {
        self.alpha += 1.0;
    }

    /// Update posterior with a failure observation
    pub fn observe_failure(&mut self) {
        self.beta += 1.0;
    }

    /// Update based on boolean outcome
    pub fn observe(&mut self, success: bool) {
        if success {
            self.observe_success();
        } else {
            self.observe_failure();
        }
    }

    /// Posterior mean
    pub fn mean(&self) -> f64 {
        self.alpha / (self.alpha + self.beta)
    }

    /// Posterior mode (for α, β > 1)
    pub fn mode(&self) -> Option<f64> {
        if self.alpha > 1.0 && self.beta > 1.0 {
            Some((self.alpha - 1.0) / (self.alpha + self.beta - 2.0))
        } else {
            None
        }
    }

    /// Posterior variance
    pub fn variance(&self) -> f64 {
        let sum = self.alpha + self.beta;
        (self.alpha * self.beta) / (sum * sum * (sum + 1.0))
    }

    /// Posterior standard deviation
    pub fn std_dev(&self) -> f64 {
        self.variance().sqrt()
    }

    /// Sample from the posterior
    pub fn sample<R: Rng>(&self, rng: &mut R) -> f64 {
        Beta::new(self.alpha, self.beta)
            .expect("Invalid Beta parameters")
            .sample(rng)
    }

    /// Credible interval at the given probability level (normal approximation; accurate for large counts)
    pub fn credible_interval(&self, probability: f64) -> (f64, f64) {
        let mean = self.mean();
        let std = self.std_dev();
        let z = normal_quantile((1.0 + probability) / 2.0);
        let lower = (mean - z * std).max(0.0);
        let upper = (mean + z * std).min(1.0);
        (lower, upper)
    }

    /// Number of observations (relative to the uniform prior's pseudo-counts)
    pub fn observations(&self) -> f64 {
        self.alpha + self.beta - 2.0 // Subtract the uniform prior's pseudo-counts (α = β = 1)
    }

    /// Apply decay to move toward prior (for non-stationary environments)
    pub fn decay(&mut self, factor: f64) {
        // Move α and β toward 1 (uniform prior)
        self.alpha = 1.0 + factor * (self.alpha - 1.0);
        self.beta = 1.0 + factor * (self.beta - 1.0);
    }
}

impl Default for BetaPosterior {
    fn default() -> Self {
        Self::uniform()
    }
}
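
// Illustrative sketch (not part of the crate's public API): the success/failure
// counting above is the standard conjugate Beta–Bernoulli update, so one
// `BetaPosterior` per candidate operator supports Thompson sampling directly.
// The helper name `thompson_pick` is hypothetical.
#[allow(dead_code)]
fn thompson_pick<R: Rng>(posteriors: &[BetaPosterior], rng: &mut R) -> Option<usize> {
    // Draw one sample from each posterior and pick the index with the largest draw.
    posteriors
        .iter()
        .map(|p| p.sample(rng))
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(i, _)| i)
}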

/// Gamma distribution posterior for positive rate parameters (e.g., temperature)
#[derive(Clone, Debug)]
pub struct GammaPosterior {
    /// Shape parameter (α)
    pub shape: f64,
    /// Rate parameter (β)
    pub rate: f64,
}

impl GammaPosterior {
    /// Create with vague prior
    pub fn vague() -> Self {
        Self {
            shape: 1.0,
            rate: 0.01,
        }
    }

    /// Create with custom prior
    pub fn new(shape: f64, rate: f64) -> Self {
        Self { shape, rate }
    }

    /// Update with an observation (conjugate update for exponential likelihood)
    pub fn observe(&mut self, value: f64) {
        self.shape += 1.0;
        self.rate += value;
    }

    /// Posterior mean
    pub fn mean(&self) -> f64 {
        self.shape / self.rate
    }

    /// Posterior mode (for shape ≥ 1)
    pub fn mode(&self) -> Option<f64> {
        if self.shape >= 1.0 {
            Some((self.shape - 1.0) / self.rate)
        } else {
            None
        }
    }

    /// Posterior variance
    pub fn variance(&self) -> f64 {
        self.shape / (self.rate * self.rate)
    }

    /// Sample from the posterior
    pub fn sample<R: Rng>(&self, rng: &mut R) -> f64 {
        // rand_distr's Gamma is parameterized by shape and scale, so scale = 1 / rate
        Gamma::new(self.shape, 1.0 / self.rate)
            .expect("Invalid Gamma parameters")
            .sample(rng)
    }

    /// Apply decay
    pub fn decay(&mut self, factor: f64) {
        self.shape = 1.0 + factor * (self.shape - 1.0);
        self.rate = 0.01 + factor * (self.rate - 0.01);
    }
}

impl Default for GammaPosterior {
    fn default() -> Self {
        Self::vague()
    }
}
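
// Sanity-check sketch (hypothetical helper, not part of the crate): for an
// Exponential(λ) likelihood with a Gamma(shape, rate) prior on λ, observing x
// yields Gamma(shape + 1, rate + x), which is exactly what `observe` does; the
// posterior mean shape / rate then concentrates around λ.
#[allow(dead_code)]
fn gamma_posterior_demo() -> f64 {
    let mut posterior = GammaPosterior::vague();
    // Feed in observations averaging 0.5, i.e. samples consistent with rate λ ≈ 2.
    for &x in &[0.4, 0.6, 0.5, 0.5, 0.4, 0.6] {
        posterior.observe(x);
    }
    // Posterior mean = (1 + 6) / (0.01 + 3.0) ≈ 2.33, pulled toward λ = 2 as data accumulate.
    posterior.mean()
}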

/// Log-normal posterior approximation for step sizes
#[derive(Clone, Debug)]
pub struct LogNormalPosterior {
    /// Mean of log(σ)
    pub mu: f64,
    /// Variance of log(σ)
    pub sigma_sq: f64,
    /// Number of observations
    pub n: usize,
}

impl LogNormalPosterior {
    /// Create with vague prior
    pub fn vague() -> Self {
        Self {
            mu: 0.0,
            sigma_sq: 1.0,
            n: 0,
        }
    }

    /// Create with informative prior
    pub fn new(mu: f64, sigma_sq: f64) -> Self {
        Self { mu, sigma_sq, n: 0 }
    }

    /// Update with an observation of a step size
    pub fn observe(&mut self, sigma: f64) {
        if sigma <= 0.0 {
            return;
        }

        let log_sigma = sigma.ln();
        self.n += 1;

        // Online update of mean and variance
        let delta = log_sigma - self.mu;
        self.mu += delta / self.n as f64;
        // Note: This is a simplified update, not fully Bayesian
        if self.n > 1 {
            self.sigma_sq = (self.sigma_sq * (self.n - 1) as f64 + delta * (log_sigma - self.mu))
                / self.n as f64;
        }
    }

    /// Mean of the distribution (in original space)
    pub fn mean(&self) -> f64 {
        (self.mu + self.sigma_sq / 2.0).exp()
    }

    /// Mode of the distribution (in original space)
    pub fn mode(&self) -> f64 {
        (self.mu - self.sigma_sq).exp()
    }

    /// Sample from the posterior
    pub fn sample<R: Rng>(&self, rng: &mut R) -> f64 {
        use rand_distr::StandardNormal;
        let z: f64 = rng.sample(StandardNormal);
        (self.mu + self.sigma_sq.sqrt() * z).exp()
    }
}

impl Default for LogNormalPosterior {
    fn default() -> Self {
        Self::vague()
    }
}
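
// Interpretation note (illustrative, hypothetical helper): after the first
// observation, `mu` is the running mean of log(σ), so exp(mu) is both the
// geometric mean of the observed step sizes and the median of the fitted
// log-normal; `mean()` = exp(mu + σ²/2) is always at least as large.
#[allow(dead_code)]
fn log_normal_median(posterior: &LogNormalPosterior) -> f64 {
    posterior.mu.exp()
}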

/// Collection of hyperparameter posteriors
#[derive(Clone, Debug, Default)]
pub struct HyperparameterPosteriors {
    /// Mutation rate posterior
    pub mutation_rate: BetaPosterior,
    /// Crossover probability posterior
    pub crossover_prob: BetaPosterior,
    /// Selection temperature posterior
    pub temperature: GammaPosterior,
    /// SBX distribution index posterior
    pub sbx_eta: GammaPosterior,
    /// Polynomial mutation eta posterior
    pub pm_eta: GammaPosterior,
    /// Step size posteriors
    pub step_sizes: Vec<LogNormalPosterior>,
}

impl HyperparameterPosteriors {
    /// Create with default (vague) priors
    pub fn new() -> Self {
        Self::default()
    }

    /// Create with specified number of step size parameters
    pub fn with_step_sizes(n: usize) -> Self {
        Self {
            step_sizes: vec![LogNormalPosterior::vague(); n],
            ..Default::default()
        }
    }

    /// Apply decay to all posteriors
    pub fn decay_all(&mut self, factor: f64) {
        self.mutation_rate.decay(factor);
        self.crossover_prob.decay(factor);
        self.temperature.decay(factor);
        self.sbx_eta.decay(factor);
        self.pm_eta.decay(factor);
    }
}

/// Operator parameters that can be learned
#[derive(Clone, Debug)]
pub struct OperatorParams {
    /// Mutation rate
    pub mutation_rate: f64,
    /// Crossover probability
    pub crossover_prob: f64,
    /// Selection temperature
    pub temperature: f64,
    /// SBX distribution index
    pub sbx_eta: f64,
    /// Polynomial mutation distribution index
    pub pm_eta: f64,
}

impl Default for OperatorParams {
    fn default() -> Self {
        Self {
            mutation_rate: 0.1,
            crossover_prob: 0.9,
            temperature: 1.0,
            sbx_eta: 20.0,
            pm_eta: 20.0,
        }
    }
}

impl OperatorParams {
    /// Sample parameters from posteriors
    pub fn sample_from<R: Rng>(posteriors: &HyperparameterPosteriors, rng: &mut R) -> Self {
        Self {
            mutation_rate: posteriors.mutation_rate.sample(rng),
            crossover_prob: posteriors.crossover_prob.sample(rng),
            temperature: posteriors.temperature.sample(rng).max(0.01),
            sbx_eta: posteriors.sbx_eta.sample(rng).max(1.0),
            pm_eta: posteriors.pm_eta.sample(rng).max(1.0),
        }
    }

    /// Get MAP (maximum a posteriori) estimate from posteriors
    pub fn map_estimate(posteriors: &HyperparameterPosteriors) -> Self {
        Self {
            mutation_rate: posteriors
                .mutation_rate
                .mode()
                .unwrap_or(posteriors.mutation_rate.mean()),
            crossover_prob: posteriors
                .crossover_prob
                .mode()
                .unwrap_or(posteriors.crossover_prob.mean()),
            temperature: posteriors
                .temperature
                .mode()
                .unwrap_or(posteriors.temperature.mean())
                .max(0.01),
            sbx_eta: posteriors
                .sbx_eta
                .mode()
                .unwrap_or(posteriors.sbx_eta.mean())
                .max(1.0),
            pm_eta: posteriors
                .pm_eta
                .mode()
                .unwrap_or(posteriors.pm_eta.mean())
                .max(1.0),
        }
    }
}

/// Online Bayesian hyperparameter learner
#[derive(Clone, Debug)]
pub struct BayesianHyperparameterLearner {
    /// Hyperparameter posteriors
    pub posteriors: HyperparameterPosteriors,
    /// Sliding window of observations
    history: VecDeque<(OperatorParams, f64)>,
    /// Maximum history size
    window_size: usize,
    /// Decay factor for non-stationary environments
    decay_factor: f64,
}

impl BayesianHyperparameterLearner {
    /// Create a new learner
    pub fn new() -> Self {
        Self {
            posteriors: HyperparameterPosteriors::new(),
            history: VecDeque::new(),
            window_size: 100,
            decay_factor: 1.0, // No decay by default
        }
    }

    /// Set window size
    pub fn with_window_size(mut self, size: usize) -> Self {
        self.window_size = size;
        self
    }

    /// Set decay factor (< 1.0 for non-stationary environments)
    pub fn with_decay(mut self, factor: f64) -> Self {
        self.decay_factor = factor;
        self
    }

    /// Observe the outcome of applying operators with given parameters
    pub fn observe(&mut self, params: OperatorParams, parent_fitness: f64, child_fitness: f64) {
        let improvement = child_fitness - parent_fitness;
        let success = improvement > 0.0;

        // Update posteriors
        self.posteriors.mutation_rate.observe(success);
        self.posteriors.crossover_prob.observe(success);

        // For continuous parameters, observe them when successful
        if success {
            self.posteriors.temperature.observe(params.temperature);
            self.posteriors.sbx_eta.observe(params.sbx_eta);
            self.posteriors.pm_eta.observe(params.pm_eta);
        }

        // Maintain history
        self.history.push_back((params, improvement));
        if self.history.len() > self.window_size {
            self.history.pop_front();
        }

        // Apply decay if configured
        if self.decay_factor < 1.0 {
            self.posteriors.decay_all(self.decay_factor);
        }
    }

    /// Sample parameters from current posteriors
    pub fn sample_params<R: Rng>(&self, rng: &mut R) -> OperatorParams {
        OperatorParams::sample_from(&self.posteriors, rng)
    }

    /// Get MAP estimate of parameters
    pub fn map_params(&self) -> OperatorParams {
        OperatorParams::map_estimate(&self.posteriors)
    }

    /// Get current posteriors
    pub fn posteriors(&self) -> &HyperparameterPosteriors {
        &self.posteriors
    }

    /// Get average improvement in history
    pub fn average_improvement(&self) -> Option<f64> {
        if self.history.is_empty() {
            return None;
        }
        let sum: f64 = self.history.iter().map(|(_, imp)| *imp).sum();
        Some(sum / self.history.len() as f64)
    }

    /// Reset the learner
    pub fn reset(&mut self) {
        self.posteriors = HyperparameterPosteriors::new();
        self.history.clear();
    }
}

impl Default for BayesianHyperparameterLearner {
    fn default() -> Self {
        Self::new()
    }
}
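
// Usage sketch (hypothetical driver, not part of the crate): a Thompson-sampling
// loop in which each generation samples operator parameters from the current
// posteriors, applies them, and feeds the fitness outcome back to the learner.
// `evaluate_with` stands in for whatever variation-plus-evaluation step the
// surrounding evolutionary algorithm performs.
#[allow(dead_code)]
fn tuning_loop_sketch<R: Rng>(
    rng: &mut R,
    generations: usize,
    mut evaluate_with: impl FnMut(&OperatorParams, &mut R) -> (f64, f64),
) -> OperatorParams {
    // Decay slightly each step so old evidence fades in a non-stationary search.
    let mut learner = BayesianHyperparameterLearner::new().with_decay(0.99);
    for _ in 0..generations {
        // Explore: draw parameters from the posteriors (Thompson sampling).
        let params = learner.sample_params(rng);
        let (parent_fitness, child_fitness) = evaluate_with(&params, rng);
        // Learn: success/failure and successful parameter values update the posteriors.
        learner.observe(params, parent_fitness, child_fitness);
    }
    // Exploit: return the MAP estimate once the budget is spent.
    learner.map_params()
}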

/// Approximate standard normal quantile function
fn normal_quantile(p: f64) -> f64 {
    // Abramowitz & Stegun 26.2.23 rational approximation (absolute error < 4.5e-4),
    // good enough for credible interval computation
    if p <= 0.0 {
        return f64::NEG_INFINITY;
    }
    if p >= 1.0 {
        return f64::INFINITY;
    }

    let t = if p < 0.5 {
        (-2.0 * p.ln()).sqrt()
    } else {
        (-2.0 * (1.0 - p).ln()).sqrt()
    };

    let c0 = 2.515517;
    let c1 = 0.802853;
    let c2 = 0.010328;
    let d1 = 1.432788;
    let d2 = 0.189269;
    let d3 = 0.001308;

    let q = t - (c0 + c1 * t + c2 * t * t) / (1.0 + d1 * t + d2 * t * t + d3 * t * t * t);

    if p < 0.5 {
        -q
    } else {
        q
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_beta_posterior_uniform_prior() {
        let posterior = BetaPosterior::uniform();
        assert!((posterior.mean() - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_beta_posterior_update() {
        let mut posterior = BetaPosterior::uniform();

        // Observe 7 successes, 3 failures
        for _ in 0..7 {
            posterior.observe_success();
        }
        for _ in 0..3 {
            posterior.observe_failure();
        }

        // Expected mean: (1 + 7) / ((1 + 7) + (1 + 3)) = 8/12 ≈ 0.667
        assert!((posterior.mean() - 0.667).abs() < 0.01);
    }

    #[test]
    fn test_beta_posterior_sample() {
        let posterior = BetaPosterior::new(5.0, 5.0);
        let mut rng = rand::thread_rng();

        for _ in 0..100 {
            let sample = posterior.sample(&mut rng);
            assert!((0.0..=1.0).contains(&sample));
        }
    }

    #[test]
    fn test_gamma_posterior() {
        let mut posterior = GammaPosterior::vague();

        for _ in 0..10 {
            posterior.observe(1.0);
        }

        // Mean should be around 1.0 after observing 1.0s
        let mean = posterior.mean();
        assert!(mean > 0.5 && mean < 2.0);
    }

    #[test]
    fn test_log_normal_posterior() {
        let mut posterior = LogNormalPosterior::vague();

        // Observe step sizes around 0.1
        for _ in 0..10 {
            posterior.observe(0.1);
        }

        // Mean should be close to 0.1
        let mean = posterior.mean();
        assert!(mean > 0.05 && mean < 0.5);
    }

    #[test]
    fn test_bayesian_learner() {
        let mut learner = BayesianHyperparameterLearner::new();
        let mut rng = rand::thread_rng();

        // Simulate some observations
        for i in 0..20 {
            let params = OperatorParams::default();
            let parent_fitness = 0.0;
            let child_fitness = if i % 3 == 0 { 1.0 } else { -1.0 };

            learner.observe(params, parent_fitness, child_fitness);
        }

        // Should be able to sample params
        let sampled = learner.sample_params(&mut rng);
        assert!(sampled.mutation_rate >= 0.0 && sampled.mutation_rate <= 1.0);
        assert!(sampled.crossover_prob >= 0.0 && sampled.crossover_prob <= 1.0);
    }

    #[test]
    fn test_operator_params_sample() {
        let posteriors = HyperparameterPosteriors::new();
        let mut rng = rand::thread_rng();

        let params = OperatorParams::sample_from(&posteriors, &mut rng);
        assert!(params.mutation_rate >= 0.0);
        assert!(params.temperature > 0.0);
        assert!(params.sbx_eta >= 1.0);
    }

    #[test]
    fn test_credible_interval() {
        let posterior = BetaPosterior::new(50.0, 50.0);
        let (lower, upper) = posterior.credible_interval(0.95);

        assert!(lower < 0.5);
        assert!(upper > 0.5);
        assert!(lower > 0.0);
        assert!(upper < 1.0);
    }
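
    // Additional sketch test (assumption: moving back toward the uniform prior is
    // the intended semantics of `BetaPosterior::decay`): after heavy decay the
    // posterior mean should return toward 0.5.
    #[test]
    fn test_beta_posterior_decay_moves_toward_prior() {
        let mut posterior = BetaPosterior::new(20.0, 5.0);
        let mean_before = posterior.mean();

        // Repeated decay with factor 0.5 shrinks the pseudo-counts toward α = β = 1.
        for _ in 0..10 {
            posterior.decay(0.5);
        }

        let mean_after = posterior.mean();
        assert!((mean_after - 0.5).abs() < (mean_before - 0.5).abs());
        assert!((mean_after - 0.5).abs() < 0.05);
    }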
}