# Optimizers as enums with internal state and methods
This chapter builds on the previous enum-based optimizer design: we now give each variant its own internal state and encapsulate its update rule in methods. This pattern is useful when you want enum-based control flow with encapsulated logic.
## Defining the optimizer enum
Each optimizer variant includes its own parameters and, when needed, its internal state.
```rust
/// An enum representing different optimizers with built-in state and update rules.
///
/// Supports both gradient descent and momentum-based methods.
#[derive(Debug, Clone)]
pub enum Optimizer {
    /// Gradient Descent optimizer with a fixed learning rate.
    GradientDescent { learning_rate: f64 },
    /// Momentum-based optimizer with velocity tracking.
    Momentum {
        learning_rate: f64,
        momentum: f64,
        velocity: Vec<f64>,
    },
}
```
Here, `GradientDescent` stores only the learning rate, while `Momentum` additionally stores its velocity vector.
## Constructors
We define convenience constructors for each optimizer. These make construction simpler and avoid writing out struct-variant literals by hand.
```rust
impl Optimizer {
    /// Creates a new Gradient Descent optimizer.
    ///
    /// # Arguments
    /// - `learning_rate`: Step size for the parameter updates.
    pub fn gradient_descent(learning_rate: f64) -> Self {
        Self::GradientDescent { learning_rate }
    }

    /// Creates a new Momentum optimizer.
    ///
    /// # Arguments
    /// - `learning_rate`: Step size for the updates.
    /// - `momentum`: Momentum coefficient.
    /// - `dim`: Number of parameters (used to initialize the velocity vector).
    pub fn momentum(learning_rate: f64, momentum: f64, dim: usize) -> Self {
        Self::Momentum {
            learning_rate,
            momentum,
            velocity: vec![0.0; dim],
        }
    }
}
```
These constructors make creating optimizers cleaner and more idiomatic.
## Implementing the step method
The `step` method applies one optimization update depending on the variant, using pattern matching to select the variant-specific behavior.
```rust
impl Optimizer {
    /// Applies a single optimization step, depending on the variant.
    ///
    /// # Arguments
    /// - `weights`: Mutable slice of model parameters to be updated.
    /// - `grads`: Gradient slice with the same shape as `weights`.
    pub fn step(&mut self, weights: &mut [f64], grads: &[f64]) {
        match self {
            Optimizer::GradientDescent { learning_rate } => {
                for (w, g) in weights.iter_mut().zip(grads.iter()) {
                    *w -= *learning_rate * *g;
                }
            }
            Optimizer::Momentum {
                learning_rate,
                momentum,
                velocity,
            } => {
                for ((w, g), v) in weights
                    .iter_mut()
                    .zip(grads.iter())
                    .zip(velocity.iter_mut())
                {
                    *v = *momentum * *v + *learning_rate * *g;
                    *w -= *v;
                }
            }
        }
    }
}
```
Here, `GradientDescent` simply subtracts the learning rate times the gradient from each weight. The `Momentum` variant updates and stores the velocity vector, then subtracts it from the weights.
## Summary
This pattern can be a clean alternative to trait-based designs when you want:
- A small number of well-known variants
- Built-in state encapsulation
- Exhaustive handling via pattern matching
It keeps related logic grouped under one type and can be extended easily with new optimizers.