1use std::f64::consts::PI;
6
/// A rule mapping a generation index to a parameter value (e.g. a mutation or
/// crossover rate that changes over the course of a run).
///
/// Implementors must be `Send + Sync`, so schedules can be shared between threads.
pub trait ParameterSchedule: Send + Sync {
    /// Returns the parameter value at `generation`, given a planned run length of
    /// `max_generations`.
    ///
    /// Not every implementation uses the horizon; some depend on `generation` alone.
    fn value_at(&self, generation: usize, max_generations: usize) -> f64;
}
14
15#[derive(Clone, Debug)]
17pub struct ConstantSchedule {
18 pub value: f64,
20}
21
22impl ConstantSchedule {
23 pub fn new(value: f64) -> Self {
25 Self { value }
26 }
27}
28
29impl ParameterSchedule for ConstantSchedule {
30 fn value_at(&self, _generation: usize, _max_generations: usize) -> f64 {
31 self.value
32 }
33}
34
35#[derive(Clone, Debug)]
37pub struct LinearAnnealing {
38 pub start: f64,
40 pub end: f64,
42}
43
44impl LinearAnnealing {
45 pub fn new(start: f64, end: f64) -> Self {
47 Self { start, end }
48 }
49
50 pub fn decreasing(start: f64, end: f64) -> Self {
52 Self::new(start, end)
53 }
54
55 pub fn increasing(start: f64, end: f64) -> Self {
57 Self::new(start, end)
58 }
59}
60
61impl ParameterSchedule for LinearAnnealing {
62 fn value_at(&self, generation: usize, max_generations: usize) -> f64 {
63 if max_generations == 0 {
64 return self.start;
65 }
66 let t = generation as f64 / max_generations as f64;
67 self.start + (self.end - self.start) * t
68 }
69}
70
71#[derive(Clone, Debug)]
73pub struct ExponentialDecay {
74 pub initial: f64,
76 pub decay_rate: f64,
78 pub minimum: f64,
80}
81
82impl ExponentialDecay {
83 pub fn new(initial: f64, decay_rate: f64) -> Self {
85 Self {
86 initial,
87 decay_rate,
88 minimum: 0.0,
89 }
90 }
91
92 pub fn with_minimum(mut self, minimum: f64) -> Self {
94 self.minimum = minimum;
95 self
96 }
97}
98
99impl ParameterSchedule for ExponentialDecay {
100 fn value_at(&self, generation: usize, _max_generations: usize) -> f64 {
101 (self.initial * (-self.decay_rate * generation as f64).exp()).max(self.minimum)
102 }
103}
104
105#[derive(Clone, Debug)]
109pub struct CosineAnnealing {
110 pub max_value: f64,
112 pub min_value: f64,
114 pub period: Option<usize>,
116}
117
118impl CosineAnnealing {
119 pub fn new(max_value: f64, min_value: f64) -> Self {
121 Self {
122 max_value,
123 min_value,
124 period: None,
125 }
126 }
127
128 pub fn with_warm_restarts(mut self, period: usize) -> Self {
130 self.period = Some(period);
131 self
132 }
133}
134
135impl ParameterSchedule for CosineAnnealing {
136 fn value_at(&self, generation: usize, max_generations: usize) -> f64 {
137 let effective_gen = match self.period {
138 Some(period) if period > 0 => generation % period,
139 _ => generation,
140 };
141 let effective_max = match self.period {
142 Some(period) if period > 0 => period,
143 _ => max_generations,
144 };
145
146 if effective_max == 0 {
147 return self.max_value;
148 }
149
150 let t = effective_gen as f64 / effective_max as f64;
151 self.min_value + 0.5 * (self.max_value - self.min_value) * (1.0 + (PI * t).cos())
152 }
153}
154
155#[derive(Clone, Debug)]
157pub struct StepSchedule {
158 pub steps: Vec<(usize, f64)>,
160 pub initial: f64,
162}
163
164impl StepSchedule {
165 pub fn new(initial: f64, steps: Vec<(usize, f64)>) -> Self {
167 let mut steps = steps;
168 steps.sort_by_key(|(gen, _)| *gen);
169 Self { steps, initial }
170 }
171
172 pub fn single_step(initial: f64, step_gen: usize, step_value: f64) -> Self {
174 Self::new(initial, vec![(step_gen, step_value)])
175 }
176}
177
178impl ParameterSchedule for StepSchedule {
179 fn value_at(&self, generation: usize, _max_generations: usize) -> f64 {
180 let mut value = self.initial;
181 for &(step_gen, step_value) in &self.steps {
182 if generation >= step_gen {
183 value = step_value;
184 } else {
185 break;
186 }
187 }
188 value
189 }
190}
191
192#[derive(Clone, Debug)]
194pub struct PolynomialDecay {
195 pub initial: f64,
197 pub power: f64,
199 pub minimum: f64,
201}
202
203impl PolynomialDecay {
204 pub fn new(initial: f64, power: f64) -> Self {
206 Self {
207 initial,
208 power,
209 minimum: 0.0,
210 }
211 }
212
213 pub fn with_minimum(mut self, minimum: f64) -> Self {
215 self.minimum = minimum;
216 self
217 }
218}
219
220impl ParameterSchedule for PolynomialDecay {
221 fn value_at(&self, generation: usize, max_generations: usize) -> f64 {
222 if max_generations == 0 {
223 return self.initial;
224 }
225 let t = generation as f64 / max_generations as f64;
226 let decay = (1.0 - t).max(0.0).powf(self.power);
227 self.minimum + (self.initial - self.minimum) * decay
228 }
229}
230
231#[derive(Clone, Debug)]
233pub struct CyclicalSchedule {
234 pub base: f64,
236 pub max_value: f64,
238 pub step_size: usize,
240}
241
242impl CyclicalSchedule {
243 pub fn new(base: f64, max_value: f64, step_size: usize) -> Self {
245 Self {
246 base,
247 max_value,
248 step_size,
249 }
250 }
251}
252
253impl ParameterSchedule for CyclicalSchedule {
254 fn value_at(&self, generation: usize, _max_generations: usize) -> f64 {
255 if self.step_size == 0 {
256 return self.base;
257 }
258
259 let cycle = generation / (2 * self.step_size);
260 let x = (generation as f64 / self.step_size as f64) - 2.0 * cycle as f64;
261 let scale = (1.0 - (x - 1.0).abs()).max(0.0);
262 self.base + (self.max_value - self.base) * scale
263 }
264}
265
266#[derive(Clone, Debug)]
268pub enum DynamicSchedule {
269 Constant(ConstantSchedule),
270 Linear(LinearAnnealing),
271 Exponential(ExponentialDecay),
272 Cosine(CosineAnnealing),
273 Step(StepSchedule),
274 Polynomial(PolynomialDecay),
275 Cyclical(CyclicalSchedule),
276}
277
278impl ParameterSchedule for DynamicSchedule {
279 fn value_at(&self, generation: usize, max_generations: usize) -> f64 {
280 match self {
281 Self::Constant(s) => s.value_at(generation, max_generations),
282 Self::Linear(s) => s.value_at(generation, max_generations),
283 Self::Exponential(s) => s.value_at(generation, max_generations),
284 Self::Cosine(s) => s.value_at(generation, max_generations),
285 Self::Step(s) => s.value_at(generation, max_generations),
286 Self::Polynomial(s) => s.value_at(generation, max_generations),
287 Self::Cyclical(s) => s.value_at(generation, max_generations),
288 }
289 }
290}
291
292impl From<ConstantSchedule> for DynamicSchedule {
293 fn from(s: ConstantSchedule) -> Self {
294 Self::Constant(s)
295 }
296}
297
298impl From<LinearAnnealing> for DynamicSchedule {
299 fn from(s: LinearAnnealing) -> Self {
300 Self::Linear(s)
301 }
302}
303
304impl From<ExponentialDecay> for DynamicSchedule {
305 fn from(s: ExponentialDecay) -> Self {
306 Self::Exponential(s)
307 }
308}
309
310impl From<CosineAnnealing> for DynamicSchedule {
311 fn from(s: CosineAnnealing) -> Self {
312 Self::Cosine(s)
313 }
314}
315
316impl From<StepSchedule> for DynamicSchedule {
317 fn from(s: StepSchedule) -> Self {
318 Self::Step(s)
319 }
320}
321
322impl From<PolynomialDecay> for DynamicSchedule {
323 fn from(s: PolynomialDecay) -> Self {
324 Self::Polynomial(s)
325 }
326}
327
328impl From<CyclicalSchedule> for DynamicSchedule {
329 fn from(s: CyclicalSchedule) -> Self {
330 Self::Cyclical(s)
331 }
332}
333
334#[derive(Clone, Debug)]
336pub struct CompositeSchedule {
337 pub phases: Vec<(usize, DynamicSchedule)>,
339}
340
341impl CompositeSchedule {
342 pub fn new() -> Self {
344 Self { phases: Vec::new() }
345 }
346
347 pub fn add_phase<S: Into<DynamicSchedule>>(mut self, end_gen: usize, schedule: S) -> Self {
349 self.phases.push((end_gen, schedule.into()));
350 self.phases.sort_by_key(|(gen, _)| *gen);
351 self
352 }
353}
354
355impl Default for CompositeSchedule {
356 fn default() -> Self {
357 Self::new()
358 }
359}
360
361impl ParameterSchedule for CompositeSchedule {
362 fn value_at(&self, generation: usize, _max_generations: usize) -> f64 {
363 let mut prev_end = 0;
364 for (end_gen, schedule) in &self.phases {
365 if generation < *end_gen {
366 let phase_duration = end_gen - prev_end;
367 let phase_gen = generation - prev_end;
368 return schedule.value_at(phase_gen, phase_duration);
369 }
370 prev_end = *end_gen;
371 }
372 if let Some((end_gen, schedule)) = self.phases.last() {
374 let phase_duration = end_gen
375 - self
376 .phases
377 .get(self.phases.len().saturating_sub(2))
378 .map(|(e, _)| *e)
379 .unwrap_or(0);
380 schedule.value_at(phase_duration, phase_duration)
381 } else {
382 0.0
383 }
384 }
385}
386
// Unit tests covering each schedule type, the enum wrapper, and composition.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_constant_schedule() {
        let schedule = ConstantSchedule::new(0.5);
        assert_relative_eq!(schedule.value_at(0, 100), 0.5);
        assert_relative_eq!(schedule.value_at(50, 100), 0.5);
        assert_relative_eq!(schedule.value_at(100, 100), 0.5);
    }

    #[test]
    fn test_linear_annealing() {
        let schedule = LinearAnnealing::new(1.0, 0.0);
        assert_relative_eq!(schedule.value_at(0, 100), 1.0);
        assert_relative_eq!(schedule.value_at(50, 100), 0.5);
        assert_relative_eq!(schedule.value_at(100, 100), 0.0);
    }

    #[test]
    fn test_linear_annealing_increasing() {
        let schedule = LinearAnnealing::increasing(0.1, 0.9);
        assert_relative_eq!(schedule.value_at(0, 100), 0.1);
        assert_relative_eq!(schedule.value_at(100, 100), 0.9);
    }

    #[test]
    fn test_exponential_decay() {
        let schedule = ExponentialDecay::new(1.0, 0.1);
        assert_relative_eq!(schedule.value_at(0, 100), 1.0);
        assert!(schedule.value_at(10, 100) < 1.0);
        assert!(schedule.value_at(50, 100) < schedule.value_at(10, 100));
    }

    #[test]
    fn test_exponential_decay_with_minimum() {
        let schedule = ExponentialDecay::new(1.0, 0.1).with_minimum(0.1);
        assert!(schedule.value_at(1000, 100) >= 0.1);
    }

    #[test]
    fn test_cosine_annealing() {
        let schedule = CosineAnnealing::new(1.0, 0.0);
        assert_relative_eq!(schedule.value_at(0, 100), 1.0);
        assert_relative_eq!(schedule.value_at(100, 100), 0.0, epsilon = 1e-10);
        assert_relative_eq!(schedule.value_at(50, 100), 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_cosine_annealing_warm_restarts() {
        let schedule = CosineAnnealing::new(1.0, 0.0).with_warm_restarts(50);
        assert_relative_eq!(schedule.value_at(0, 100), 1.0);
        assert_relative_eq!(schedule.value_at(50, 100), 1.0);
        assert_relative_eq!(schedule.value_at(25, 100), 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_step_schedule() {
        let schedule = StepSchedule::new(1.0, vec![(25, 0.5), (75, 0.1)]);
        assert_relative_eq!(schedule.value_at(0, 100), 1.0);
        assert_relative_eq!(schedule.value_at(24, 100), 1.0);
        assert_relative_eq!(schedule.value_at(25, 100), 0.5);
        assert_relative_eq!(schedule.value_at(74, 100), 0.5);
        assert_relative_eq!(schedule.value_at(75, 100), 0.1);
    }

    #[test]
    fn test_step_schedule_unsorted_input() {
        // `StepSchedule::new` sorts its steps, so input order must not matter.
        let schedule = StepSchedule::new(1.0, vec![(75, 0.1), (25, 0.5)]);
        assert_relative_eq!(schedule.value_at(10, 100), 1.0);
        assert_relative_eq!(schedule.value_at(50, 100), 0.5);
        assert_relative_eq!(schedule.value_at(80, 100), 0.1);
    }

    #[test]
    fn test_polynomial_decay() {
        let schedule = PolynomialDecay::new(1.0, 2.0).with_minimum(0.0);
        assert_relative_eq!(schedule.value_at(0, 100), 1.0);
        assert_relative_eq!(schedule.value_at(100, 100), 0.0);
        assert_relative_eq!(schedule.value_at(50, 100), 0.25);
    }

    #[test]
    fn test_cyclical_schedule() {
        let schedule = CyclicalSchedule::new(0.0, 1.0, 10);
        assert_relative_eq!(schedule.value_at(0, 100), 0.0);
        assert_relative_eq!(schedule.value_at(10, 100), 1.0);
        assert_relative_eq!(schedule.value_at(20, 100), 0.0);
        assert_relative_eq!(schedule.value_at(30, 100), 1.0);
    }

    #[test]
    fn test_linear_annealing_decreasing() {
        let schedule = LinearAnnealing::decreasing(0.9, 0.1);
        assert_relative_eq!(schedule.value_at(0, 100), 0.9);
        assert_relative_eq!(schedule.value_at(100, 100), 0.1);
    }

    #[test]
    fn test_linear_annealing_zero_max_generations() {
        let schedule = LinearAnnealing::new(1.0, 0.0);
        assert_relative_eq!(schedule.value_at(0, 0), 1.0);
    }

    #[test]
    fn test_step_schedule_single_step() {
        let schedule = StepSchedule::single_step(1.0, 50, 0.5);
        assert_relative_eq!(schedule.value_at(0, 100), 1.0);
        assert_relative_eq!(schedule.value_at(49, 100), 1.0);
        assert_relative_eq!(schedule.value_at(50, 100), 0.5);
        assert_relative_eq!(schedule.value_at(100, 100), 0.5);
    }

    #[test]
    fn test_polynomial_decay_zero_max_generations() {
        let schedule = PolynomialDecay::new(1.0, 2.0);
        assert_relative_eq!(schedule.value_at(0, 0), 1.0);
    }

    #[test]
    fn test_cyclical_schedule_zero_step_size() {
        let schedule = CyclicalSchedule::new(0.5, 1.0, 0);
        assert_relative_eq!(schedule.value_at(0, 100), 0.5);
        assert_relative_eq!(schedule.value_at(50, 100), 0.5);
    }

    #[test]
    fn test_cosine_annealing_zero_max_generations() {
        let schedule = CosineAnnealing::new(1.0, 0.0);
        assert_relative_eq!(schedule.value_at(0, 0), 1.0);
    }

    #[test]
    fn test_cosine_annealing_warm_restarts_zero_period() {
        let schedule = CosineAnnealing::new(1.0, 0.0).with_warm_restarts(0);
        assert_relative_eq!(schedule.value_at(50, 100), 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_dynamic_schedule_from_conversions() {
        let constant: DynamicSchedule = ConstantSchedule::new(0.5).into();
        assert_relative_eq!(constant.value_at(50, 100), 0.5);

        let linear: DynamicSchedule = LinearAnnealing::new(1.0, 0.0).into();
        assert_relative_eq!(linear.value_at(50, 100), 0.5);

        let exponential: DynamicSchedule = ExponentialDecay::new(1.0, 0.1).into();
        assert!(exponential.value_at(10, 100) < 1.0);

        let cosine: DynamicSchedule = CosineAnnealing::new(1.0, 0.0).into();
        assert_relative_eq!(cosine.value_at(50, 100), 0.5, epsilon = 1e-10);

        let step: DynamicSchedule = StepSchedule::new(1.0, vec![(50, 0.5)]).into();
        assert_relative_eq!(step.value_at(50, 100), 0.5);

        let polynomial: DynamicSchedule = PolynomialDecay::new(1.0, 2.0).into();
        assert_relative_eq!(polynomial.value_at(50, 100), 0.25);

        let cyclical: DynamicSchedule = CyclicalSchedule::new(0.0, 1.0, 10).into();
        assert_relative_eq!(cyclical.value_at(10, 100), 1.0);
    }

    #[test]
    fn test_schedules_usable_as_trait_objects() {
        let schedules: Vec<Box<dyn ParameterSchedule>> = vec![
            Box::new(ConstantSchedule::new(0.5)),
            Box::new(LinearAnnealing::new(1.0, 0.0)),
            Box::new(CyclicalSchedule::new(0.0, 1.0, 10)),
        ];
        let values: Vec<f64> = schedules.iter().map(|s| s.value_at(10, 20)).collect();
        assert_relative_eq!(values[0], 0.5);
        assert_relative_eq!(values[1], 0.5);
        assert_relative_eq!(values[2], 1.0);
    }

    #[test]
    fn test_composite_schedule() {
        let schedule = CompositeSchedule::new()
            .add_phase(50, ConstantSchedule::new(1.0))
            .add_phase(100, LinearAnnealing::new(1.0, 0.0));

        assert_relative_eq!(schedule.value_at(0, 100), 1.0);
        assert_relative_eq!(schedule.value_at(25, 100), 1.0);

        assert_relative_eq!(schedule.value_at(50, 100), 1.0);
        assert_relative_eq!(schedule.value_at(75, 100), 0.5);
    }

    #[test]
    fn test_composite_schedule_empty() {
        let schedule = CompositeSchedule::default();
        assert_relative_eq!(schedule.value_at(50, 100), 0.0);
    }

    #[test]
    fn test_composite_schedule_past_all_phases() {
        let schedule = CompositeSchedule::new().add_phase(50, ConstantSchedule::new(0.5));
        assert_relative_eq!(schedule.value_at(100, 100), 0.5);
    }

    #[test]
    fn test_composite_schedule_past_all_phases_multi() {
        // Past the final phase, the last phase's end-of-phase value is held.
        let schedule = CompositeSchedule::new()
            .add_phase(50, ConstantSchedule::new(1.0))
            .add_phase(100, LinearAnnealing::new(1.0, 0.0));
        assert_relative_eq!(schedule.value_at(150, 100), 0.0, epsilon = 1e-10);
    }
}