Implementing the step method
The `step` method updates the model parameters (`weights`) according to the optimization strategy in use. It uses pattern matching to dispatch to the correct behavior depending on whether the optimizer is a `GradientDescent` or a `Momentum` variant.
```rust
impl Optimizer {
    /// Applies a single optimization step, depending on the variant.
    ///
    /// # Arguments
    /// - `weights`: Mutable slice of model parameters to be updated.
    /// - `grads`: Gradient slice with same shape as `weights`.
    pub fn step(&mut self, weights: &mut [f64], grads: &[f64]) {
        match self {
            Optimizer::GradientDescent { learning_rate } => {
                for (w, g) in weights.iter_mut().zip(grads.iter()) {
                    *w -= *learning_rate * *g;
                }
            }
            Optimizer::Momentum {
                learning_rate,
                momentum,
                velocity,
            } => {
                for ((w, g), v) in weights
                    .iter_mut()
                    .zip(grads.iter())
                    .zip(velocity.iter_mut())
                {
                    *v = *momentum * *v + *learning_rate * *g;
                    *w -= *v;
                }
            }
        }
    }
}
```
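To make the momentum update concrete, here is a minimal sketch that applies two steps with a fixed gradient. The enum definition is an assumption inferred from the fields matched in `step`, and `run_momentum_steps` is a hypothetical helper added only for illustration.

```rust
/// Assumed enum layout, inferred from the fields matched in `step`.
pub enum Optimizer {
    GradientDescent {
        learning_rate: f64,
    },
    Momentum {
        learning_rate: f64,
        momentum: f64,
        velocity: Vec<f64>,
    },
}

impl Optimizer {
    pub fn step(&mut self, weights: &mut [f64], grads: &[f64]) {
        match self {
            Optimizer::GradientDescent { learning_rate } => {
                for (w, g) in weights.iter_mut().zip(grads.iter()) {
                    *w -= *learning_rate * *g;
                }
            }
            Optimizer::Momentum {
                learning_rate,
                momentum,
                velocity,
            } => {
                for ((w, g), v) in weights
                    .iter_mut()
                    .zip(grads.iter())
                    .zip(velocity.iter_mut())
                {
                    *v = *momentum * *v + *learning_rate * *g;
                    *w -= *v;
                }
            }
        }
    }
}

/// Hypothetical helper: runs two momentum steps with the same gradient
/// and returns the updated weights.
pub fn run_momentum_steps() -> Vec<f64> {
    let mut weights = vec![1.0, 2.0];
    let grads = [0.5, -1.0];
    let mut opt = Optimizer::Momentum {
        learning_rate: 0.1,
        momentum: 0.9,
        // One velocity entry per weight, starting at rest.
        velocity: vec![0.0; weights.len()],
    };

    // Step 1: velocity = [0.05, -0.1], weights = [0.95, 2.1]
    opt.step(&mut weights, &grads);
    // Step 2: velocity = [0.095, -0.19], weights ~ [0.855, 2.29]
    opt.step(&mut weights, &grads);

    weights
}
```

Note that `zip` stops at the shortest iterator, so if `velocity` is shorter than `weights`, the trailing weights are silently left unchanged. Initializing `velocity` with `weights.len()` zeros, as above, avoids that.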
The `match` expression identifies which optimizer variant is in use. This enum-based pattern can be a clean alternative to trait-based designs when you want:
- A small number of well-known variants
- Built-in state encapsulation
- Exhaustive handling via pattern matching
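This keeps related logic grouped under one type. Adding a new optimizer means adding a variant and a corresponding arm in `step`; as long as the `match` has no wildcard arm, the compiler rejects any `match` that forgets to handle the new variant.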
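As an illustration of extending the enum, the sketch below adds a third variant. `AdaGrad` is a hypothetical addition that does not appear in the original code, and the enum layout is again assumed from the fields matched in `step`. The new variant carries its own state (`accum`, the running sum of squared gradients), just as `Momentum` carries `velocity`.

```rust
/// Assumed enum layout, extended with a hypothetical `AdaGrad` variant.
pub enum Optimizer {
    GradientDescent {
        learning_rate: f64,
    },
    Momentum {
        learning_rate: f64,
        momentum: f64,
        velocity: Vec<f64>,
    },
    // Hypothetical addition for illustration only.
    AdaGrad {
        learning_rate: f64,
        epsilon: f64,
        accum: Vec<f64>,
    },
}

impl Optimizer {
    pub fn step(&mut self, weights: &mut [f64], grads: &[f64]) {
        match self {
            Optimizer::GradientDescent { learning_rate } => {
                for (w, g) in weights.iter_mut().zip(grads.iter()) {
                    *w -= *learning_rate * *g;
                }
            }
            Optimizer::Momentum {
                learning_rate,
                momentum,
                velocity,
            } => {
                for ((w, g), v) in weights
                    .iter_mut()
                    .zip(grads.iter())
                    .zip(velocity.iter_mut())
                {
                    *v = *momentum * *v + *learning_rate * *g;
                    *w -= *v;
                }
            }
            // Removing this arm makes the `match` non-exhaustive,
            // which is a compile error rather than a silent no-op.
            Optimizer::AdaGrad {
                learning_rate,
                epsilon,
                accum,
            } => {
                for ((w, g), a) in weights
                    .iter_mut()
                    .zip(grads.iter())
                    .zip(accum.iter_mut())
                {
                    // Accumulate the squared gradient, then scale the step by it.
                    *a += *g * *g;
                    *w -= *learning_rate * *g / (a.sqrt() + *epsilon);
                }
            }
        }
    }
}
```

Callers of `step` do not change: they still hold an `Optimizer` and call the same method, whichever variant it is.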