Skip to main content

fugue_evo/hyperparameter/
schedules.rs

1//! Parameter schedules for deterministic control
2//!
3//! Schedules provide predetermined parameter values based on generation number.
4
5use std::f64::consts::PI;
6
/// Parameter schedule trait
///
/// Defines how a parameter changes over the course of evolution.
/// Implementors must be `Send + Sync`.
pub trait ParameterSchedule: Send + Sync {
    /// Get the parameter value at a given generation
    ///
    /// `generation` is the generation being queried; `max_generations` is
    /// the horizon passed alongside it. Progress-based implementations in
    /// this file use the ratio of the two, while others ignore
    /// `max_generations` entirely.
    fn value_at(&self, generation: usize, max_generations: usize) -> f64;
}
14
/// Schedule that reports one fixed value (no change over time)
#[derive(Clone, Debug)]
pub struct ConstantSchedule {
    /// Value reported for every generation
    pub value: f64,
}

impl ConstantSchedule {
    /// Build a schedule pinned at `value`
    pub fn new(value: f64) -> Self {
        ConstantSchedule { value }
    }
}
28
29impl ParameterSchedule for ConstantSchedule {
30    fn value_at(&self, _generation: usize, _max_generations: usize) -> f64 {
31        self.value
32    }
33}
34
/// Linear annealing: p(t) = p_start + (p_end - p_start) * t / T
#[derive(Clone, Debug)]
pub struct LinearAnnealing {
    /// Value at generation 0
    pub start: f64,
    /// Value reached when `t == T`
    pub end: f64,
}

impl LinearAnnealing {
    /// Build a schedule that interpolates from `start` to `end`
    pub fn new(start: f64, end: f64) -> Self {
        LinearAnnealing { start, end }
    }

    /// Named constructor for a downward ramp.
    ///
    /// The arguments are stored exactly as given; no ordering check is made.
    pub fn decreasing(start: f64, end: f64) -> Self {
        LinearAnnealing { start, end }
    }

    /// Named constructor for an upward ramp.
    ///
    /// The arguments are stored exactly as given; no ordering check is made.
    pub fn increasing(start: f64, end: f64) -> Self {
        LinearAnnealing { start, end }
    }
}
60
61impl ParameterSchedule for LinearAnnealing {
62    fn value_at(&self, generation: usize, max_generations: usize) -> f64 {
63        if max_generations == 0 {
64            return self.start;
65        }
66        let t = generation as f64 / max_generations as f64;
67        self.start + (self.end - self.start) * t
68    }
69}
70
/// Exponential decay: p(t) = p₀ * e^(-λt)
#[derive(Clone, Debug)]
pub struct ExponentialDecay {
    /// Value at generation 0 (p₀)
    pub initial: f64,
    /// Decay rate (λ)
    pub decay_rate: f64,
    /// Floor applied to the decayed value
    pub minimum: f64,
}

impl ExponentialDecay {
    /// Build a decay schedule whose floor is 0.0
    pub fn new(initial: f64, decay_rate: f64) -> Self {
        let minimum = 0.0;
        ExponentialDecay {
            initial,
            decay_rate,
            minimum,
        }
    }

    /// Return the schedule with its floor replaced by `minimum`
    pub fn with_minimum(self, minimum: f64) -> Self {
        ExponentialDecay { minimum, ..self }
    }
}
98
99impl ParameterSchedule for ExponentialDecay {
100    fn value_at(&self, generation: usize, _max_generations: usize) -> f64 {
101        (self.initial * (-self.decay_rate * generation as f64).exp()).max(self.minimum)
102    }
103}
104
/// Cosine annealing with optional warm restarts
///
/// p(t) = p_min + 0.5 * (p_max - p_min) * (1 + cos(π * t / T))
#[derive(Clone, Debug)]
pub struct CosineAnnealing {
    /// Upper bound of the schedule (p_max)
    pub max_value: f64,
    /// Lower bound of the schedule (p_min)
    pub min_value: f64,
    /// Restart period in generations (`None` = single annealing)
    pub period: Option<usize>,
}

impl CosineAnnealing {
    /// Build a single-cycle cosine schedule (no warm restarts)
    pub fn new(max_value: f64, min_value: f64) -> Self {
        let period = None;
        CosineAnnealing {
            max_value,
            min_value,
            period,
        }
    }

    /// Return the schedule with warm restarts every `period` generations
    pub fn with_warm_restarts(self, period: usize) -> Self {
        CosineAnnealing {
            period: Some(period),
            ..self
        }
    }
}
134
135impl ParameterSchedule for CosineAnnealing {
136    fn value_at(&self, generation: usize, max_generations: usize) -> f64 {
137        let effective_gen = match self.period {
138            Some(period) if period > 0 => generation % period,
139            _ => generation,
140        };
141        let effective_max = match self.period {
142            Some(period) if period > 0 => period,
143            _ => max_generations,
144        };
145
146        if effective_max == 0 {
147            return self.max_value;
148        }
149
150        let t = effective_gen as f64 / effective_max as f64;
151        self.min_value + 0.5 * (self.max_value - self.min_value) * (1.0 + (PI * t).cos())
152    }
153}
154
/// Step schedule: changes at specific generations
#[derive(Clone, Debug)]
pub struct StepSchedule {
    /// `(generation, value)` pairs; `new` orders them by generation
    pub steps: Vec<(usize, f64)>,
    /// Value in effect before the first step
    pub initial: f64,
}

impl StepSchedule {
    /// Build a step schedule, ordering `steps` by generation.
    ///
    /// The sort is stable, so steps sharing a generation keep their
    /// relative input order.
    pub fn new(initial: f64, mut steps: Vec<(usize, f64)>) -> Self {
        steps.sort_by(|left, right| left.0.cmp(&right.0));
        StepSchedule { steps, initial }
    }

    /// Build a schedule that switches to `step_value` at `step_gen`
    pub fn single_step(initial: f64, step_gen: usize, step_value: f64) -> Self {
        let only_step = (step_gen, step_value);
        Self::new(initial, vec![only_step])
    }
}
177
178impl ParameterSchedule for StepSchedule {
179    fn value_at(&self, generation: usize, _max_generations: usize) -> f64 {
180        let mut value = self.initial;
181        for &(step_gen, step_value) in &self.steps {
182            if generation >= step_gen {
183                value = step_value;
184            } else {
185                break;
186            }
187        }
188        value
189    }
190}
191
/// Polynomial decay: p(t) = p_min + (p₀ - p_min) * max(0, 1 - t/T)^power
///
/// The value starts at `initial` (p₀) and reaches `minimum` (p_min) when
/// `t == T`.
#[derive(Clone, Debug)]
pub struct PolynomialDecay {
    /// Initial value (p₀)
    pub initial: f64,
    /// Power of the polynomial
    pub power: f64,
    /// Minimum value at the end (p_min)
    pub minimum: f64,
}

impl PolynomialDecay {
    /// Create a new polynomial decay schedule with a minimum of 0.0
    pub fn new(initial: f64, power: f64) -> Self {
        Self {
            initial,
            power,
            minimum: 0.0,
        }
    }

    /// Return the schedule with its minimum replaced by `minimum`
    pub fn with_minimum(mut self, minimum: f64) -> Self {
        self.minimum = minimum;
        self
    }
}
219
220impl ParameterSchedule for PolynomialDecay {
221    fn value_at(&self, generation: usize, max_generations: usize) -> f64 {
222        if max_generations == 0 {
223            return self.initial;
224        }
225        let t = generation as f64 / max_generations as f64;
226        let decay = (1.0 - t).max(0.0).powf(self.power);
227        self.minimum + (self.initial - self.minimum) * decay
228    }
229}
230
/// Cyclical schedule with triangular waves
#[derive(Clone, Debug)]
pub struct CyclicalSchedule {
    /// Base (minimum) value of the wave
    pub base: f64,
    /// Peak value of the wave
    pub max_value: f64,
    /// Generations per half cycle
    pub step_size: usize,
}

impl CyclicalSchedule {
    /// Build a triangular schedule between `base` and `max_value`
    pub fn new(base: f64, max_value: f64, step_size: usize) -> Self {
        CyclicalSchedule {
            base,
            max_value,
            step_size,
        }
    }
}
252
253impl ParameterSchedule for CyclicalSchedule {
254    fn value_at(&self, generation: usize, _max_generations: usize) -> f64 {
255        if self.step_size == 0 {
256            return self.base;
257        }
258
259        let cycle = generation / (2 * self.step_size);
260        let x = (generation as f64 / self.step_size as f64) - 2.0 * cycle as f64;
261        let scale = (1.0 - (x - 1.0).abs()).max(0.0);
262        self.base + (self.max_value - self.base) * scale
263    }
264}
265
/// Enum-based schedule for when you need to combine different schedule types
///
/// Each variant wraps one of the concrete schedule structs from this file.
#[derive(Clone, Debug)]
pub enum DynamicSchedule {
    /// Wraps a [`ConstantSchedule`]
    Constant(ConstantSchedule),
    /// Wraps a [`LinearAnnealing`]
    Linear(LinearAnnealing),
    /// Wraps an [`ExponentialDecay`]
    Exponential(ExponentialDecay),
    /// Wraps a [`CosineAnnealing`]
    Cosine(CosineAnnealing),
    /// Wraps a [`StepSchedule`]
    Step(StepSchedule),
    /// Wraps a [`PolynomialDecay`]
    Polynomial(PolynomialDecay),
    /// Wraps a [`CyclicalSchedule`]
    Cyclical(CyclicalSchedule),
}
277
278impl ParameterSchedule for DynamicSchedule {
279    fn value_at(&self, generation: usize, max_generations: usize) -> f64 {
280        match self {
281            Self::Constant(s) => s.value_at(generation, max_generations),
282            Self::Linear(s) => s.value_at(generation, max_generations),
283            Self::Exponential(s) => s.value_at(generation, max_generations),
284            Self::Cosine(s) => s.value_at(generation, max_generations),
285            Self::Step(s) => s.value_at(generation, max_generations),
286            Self::Polynomial(s) => s.value_at(generation, max_generations),
287            Self::Cyclical(s) => s.value_at(generation, max_generations),
288        }
289    }
290}
291
292impl From<ConstantSchedule> for DynamicSchedule {
293    fn from(s: ConstantSchedule) -> Self {
294        Self::Constant(s)
295    }
296}
297
298impl From<LinearAnnealing> for DynamicSchedule {
299    fn from(s: LinearAnnealing) -> Self {
300        Self::Linear(s)
301    }
302}
303
304impl From<ExponentialDecay> for DynamicSchedule {
305    fn from(s: ExponentialDecay) -> Self {
306        Self::Exponential(s)
307    }
308}
309
310impl From<CosineAnnealing> for DynamicSchedule {
311    fn from(s: CosineAnnealing) -> Self {
312        Self::Cosine(s)
313    }
314}
315
316impl From<StepSchedule> for DynamicSchedule {
317    fn from(s: StepSchedule) -> Self {
318        Self::Step(s)
319    }
320}
321
322impl From<PolynomialDecay> for DynamicSchedule {
323    fn from(s: PolynomialDecay) -> Self {
324        Self::Polynomial(s)
325    }
326}
327
328impl From<CyclicalSchedule> for DynamicSchedule {
329    fn from(s: CyclicalSchedule) -> Self {
330        Self::Cyclical(s)
331    }
332}
333
334/// Composite schedule using enum phases
335#[derive(Clone, Debug)]
336pub struct CompositeSchedule {
337    /// List of (end_generation, schedule) pairs
338    pub phases: Vec<(usize, DynamicSchedule)>,
339}
340
341impl CompositeSchedule {
342    /// Create a new composite schedule
343    pub fn new() -> Self {
344        Self { phases: Vec::new() }
345    }
346
347    /// Add a phase
348    pub fn add_phase<S: Into<DynamicSchedule>>(mut self, end_gen: usize, schedule: S) -> Self {
349        self.phases.push((end_gen, schedule.into()));
350        self.phases.sort_by_key(|(gen, _)| *gen);
351        self
352    }
353}
354
355impl Default for CompositeSchedule {
356    fn default() -> Self {
357        Self::new()
358    }
359}
360
361impl ParameterSchedule for CompositeSchedule {
362    fn value_at(&self, generation: usize, _max_generations: usize) -> f64 {
363        let mut prev_end = 0;
364        for (end_gen, schedule) in &self.phases {
365            if generation < *end_gen {
366                let phase_duration = end_gen - prev_end;
367                let phase_gen = generation - prev_end;
368                return schedule.value_at(phase_gen, phase_duration);
369            }
370            prev_end = *end_gen;
371        }
372        // If past all phases, use the last phase's final value
373        if let Some((end_gen, schedule)) = self.phases.last() {
374            let phase_duration = end_gen
375                - self
376                    .phases
377                    .get(self.phases.len().saturating_sub(2))
378                    .map(|(e, _)| *e)
379                    .unwrap_or(0);
380            schedule.value_at(phase_duration, phase_duration)
381        } else {
382            0.0
383        }
384    }
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // Most tests are table-driven: each row is (generation, expected value)
    // evaluated against a horizon of 100 generations unless noted.

    #[test]
    fn test_constant_schedule() {
        let constant = ConstantSchedule::new(0.5);
        for &generation in &[0, 50, 100] {
            assert_relative_eq!(constant.value_at(generation, 100), 0.5);
        }
    }

    #[test]
    fn test_linear_annealing() {
        let ramp = LinearAnnealing::new(1.0, 0.0);
        for &(generation, expected) in &[(0, 1.0), (50, 0.5), (100, 0.0)] {
            assert_relative_eq!(ramp.value_at(generation, 100), expected);
        }
    }

    #[test]
    fn test_linear_annealing_increasing() {
        let ramp = LinearAnnealing::increasing(0.1, 0.9);
        for &(generation, expected) in &[(0, 0.1), (100, 0.9)] {
            assert_relative_eq!(ramp.value_at(generation, 100), expected);
        }
    }

    #[test]
    fn test_exponential_decay() {
        let decay = ExponentialDecay::new(1.0, 0.1);
        let at_start = decay.value_at(0, 100);
        let at_ten = decay.value_at(10, 100);
        let at_fifty = decay.value_at(50, 100);
        assert_relative_eq!(at_start, 1.0);
        assert!(at_ten < 1.0);
        assert!(at_fifty < at_ten);
    }

    #[test]
    fn test_exponential_decay_with_minimum() {
        let floored = ExponentialDecay::new(1.0, 0.1).with_minimum(0.1);
        let far_out = floored.value_at(1000, 100);
        assert!(far_out >= 0.1);
    }

    #[test]
    fn test_cosine_annealing() {
        let cosine = CosineAnnealing::new(1.0, 0.0);
        assert_relative_eq!(cosine.value_at(0, 100), 1.0);
        // End of the cycle, then the mid-point (halfway between max and min)
        for &(generation, expected) in &[(100, 0.0), (50, 0.5)] {
            assert_relative_eq!(cosine.value_at(generation, 100), expected, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_cosine_annealing_warm_restarts() {
        let restarting = CosineAnnealing::new(1.0, 0.0).with_warm_restarts(50);
        // Generation 50 is the first restart, so it matches generation 0.
        for &generation in &[0, 50] {
            assert_relative_eq!(restarting.value_at(generation, 100), 1.0);
        }
        assert_relative_eq!(restarting.value_at(25, 100), 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_step_schedule() {
        let stepped = StepSchedule::new(1.0, vec![(25, 0.5), (75, 0.1)]);
        let table = [(0, 1.0), (24, 1.0), (25, 0.5), (74, 0.5), (75, 0.1)];
        for &(generation, expected) in &table {
            assert_relative_eq!(stepped.value_at(generation, 100), expected);
        }
    }

    #[test]
    fn test_polynomial_decay() {
        let quadratic = PolynomialDecay::new(1.0, 2.0).with_minimum(0.0);
        // Quadratic decay: at t=0.5, value = (1-0.5)^2 = 0.25
        for &(generation, expected) in &[(0, 1.0), (100, 0.0), (50, 0.25)] {
            assert_relative_eq!(quadratic.value_at(generation, 100), expected);
        }
    }

    #[test]
    fn test_cyclical_schedule() {
        let wave = CyclicalSchedule::new(0.0, 1.0, 10);
        for &(generation, expected) in &[(0, 0.0), (10, 1.0), (20, 0.0), (30, 1.0)] {
            assert_relative_eq!(wave.value_at(generation, 100), expected);
        }
    }

    #[test]
    fn test_linear_annealing_decreasing() {
        let ramp = LinearAnnealing::decreasing(0.9, 0.1);
        for &(generation, expected) in &[(0, 0.9), (100, 0.1)] {
            assert_relative_eq!(ramp.value_at(generation, 100), expected);
        }
    }

    #[test]
    fn test_linear_annealing_zero_max_generations() {
        let ramp = LinearAnnealing::new(1.0, 0.0);
        assert_relative_eq!(ramp.value_at(0, 0), 1.0);
    }

    #[test]
    fn test_step_schedule_single_step() {
        let stepped = StepSchedule::single_step(1.0, 50, 0.5);
        for &(generation, expected) in &[(0, 1.0), (49, 1.0), (50, 0.5), (100, 0.5)] {
            assert_relative_eq!(stepped.value_at(generation, 100), expected);
        }
    }

    #[test]
    fn test_polynomial_decay_zero_max_generations() {
        let quadratic = PolynomialDecay::new(1.0, 2.0);
        assert_relative_eq!(quadratic.value_at(0, 0), 1.0);
    }

    #[test]
    fn test_cyclical_schedule_zero_step_size() {
        let flat = CyclicalSchedule::new(0.5, 1.0, 0);
        for &generation in &[0, 50] {
            assert_relative_eq!(flat.value_at(generation, 100), 0.5);
        }
    }

    #[test]
    fn test_cosine_annealing_zero_max_generations() {
        let cosine = CosineAnnealing::new(1.0, 0.0);
        assert_relative_eq!(cosine.value_at(0, 0), 1.0);
    }

    #[test]
    fn test_cosine_annealing_warm_restarts_zero_period() {
        // Period 0 should be treated same as no warm restarts
        let cosine = CosineAnnealing::new(1.0, 0.0).with_warm_restarts(0);
        assert_relative_eq!(cosine.value_at(50, 100), 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_dynamic_schedule_from_conversions() {
        let constant = DynamicSchedule::from(ConstantSchedule::new(0.5));
        assert_relative_eq!(constant.value_at(50, 100), 0.5);

        let linear = DynamicSchedule::from(LinearAnnealing::new(1.0, 0.0));
        assert_relative_eq!(linear.value_at(50, 100), 0.5);

        let exponential = DynamicSchedule::from(ExponentialDecay::new(1.0, 0.1));
        assert!(exponential.value_at(10, 100) < 1.0);

        let cosine = DynamicSchedule::from(CosineAnnealing::new(1.0, 0.0));
        assert_relative_eq!(cosine.value_at(50, 100), 0.5, epsilon = 1e-10);

        let step = DynamicSchedule::from(StepSchedule::new(1.0, vec![(50, 0.5)]));
        assert_relative_eq!(step.value_at(50, 100), 0.5);

        let polynomial = DynamicSchedule::from(PolynomialDecay::new(1.0, 2.0));
        assert_relative_eq!(polynomial.value_at(50, 100), 0.25);

        let cyclical = DynamicSchedule::from(CyclicalSchedule::new(0.0, 1.0, 10));
        assert_relative_eq!(cyclical.value_at(10, 100), 1.0);
    }

    #[test]
    fn test_composite_schedule() {
        let composite = CompositeSchedule::new()
            .add_phase(50, ConstantSchedule::new(1.0))
            .add_phase(100, LinearAnnealing::new(1.0, 0.0));

        // Generations 0 and 25 fall in the constant phase (1.0);
        // generations 50 and 75 fall in the linear 1.0 -> 0.0 phase.
        for &(generation, expected) in &[(0, 1.0), (25, 1.0), (50, 1.0), (75, 0.5)] {
            assert_relative_eq!(composite.value_at(generation, 100), expected);
        }
    }

    #[test]
    fn test_composite_schedule_empty() {
        let empty = CompositeSchedule::default();
        assert_relative_eq!(empty.value_at(50, 100), 0.0);
    }

    #[test]
    fn test_composite_schedule_past_all_phases() {
        let composite = CompositeSchedule::new().add_phase(50, ConstantSchedule::new(0.5));

        // Past the end of phases
        assert_relative_eq!(composite.value_at(100, 100), 0.5);
    }
}