Implementing the step method

The step method is responsible for updating the model parameters (weights) according to the optimization strategy in use. It pattern-matches on self to dispatch to the correct update rule, depending on whether the optimizer is the GradientDescent or the Momentum variant.
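
Because the method matches on two variants, it assumes an Optimizer enum shaped roughly like the sketch below. The field names are taken directly from the match arms; the actual definition presumably appears earlier in the book, so treat this as a reconstruction.

enum Optimizer {
    /// Plain gradient descent; only needs a step size.
    GradientDescent { learning_rate: f64 },
    /// Gradient descent with momentum; `velocity` holds one entry of
    /// accumulated state per model parameter.
    Momentum {
        learning_rate: f64,
        momentum: f64,
        velocity: Vec<f64>,
    },
}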

impl Optimizer {
    /// Applies a single optimization step, depending on the variant.
    ///
    /// # Arguments
    /// - `weights`: Mutable slice of model parameters to be updated.
    /// - `grads`: Gradient slice with the same length as `weights`.
    pub fn step(&mut self, weights: &mut [f64], grads: &[f64]) {
        match self {
            Optimizer::GradientDescent { learning_rate } => {
                // Plain gradient descent: w -= learning_rate * g
                for (w, g) in weights.iter_mut().zip(grads.iter()) {
                    *w -= *learning_rate * *g;
                }
            }
            Optimizer::Momentum {
                learning_rate,
                momentum,
                velocity,
            } => {
                for ((w, g), v) in weights
                    .iter_mut()
                    .zip(grads.iter())
                    .zip(velocity.iter_mut())
                {
                    // Update the running velocity: v = momentum * v + learning_rate * g,
                    // then move the weight against it: w -= v.
                    *v = *momentum * *v + *learning_rate * *g;
                    *w -= *v;
                }
            }
        }
    }
}

The match expression selects the update rule for whichever optimizer variant is in use. This pattern can be a clean alternative to trait-based designs (a minimal trait version is sketched after the list) when you want:

  • A small number of well-known variants
  • Built-in state encapsulation
  • Exhaustive handling via pattern matching
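
For contrast, here is what the equivalent trait-based design might look like for one optimizer. The Step trait name and overall structure are illustrative assumptions, not taken from the book:

trait Step {
    fn step(&mut self, weights: &mut [f64], grads: &[f64]);
}

struct GradientDescent {
    learning_rate: f64,
}

impl Step for GradientDescent {
    fn step(&mut self, weights: &mut [f64], grads: &[f64]) {
        // Same update rule as the GradientDescent arm above.
        for (w, g) in weights.iter_mut().zip(grads.iter()) {
            *w -= self.learning_rate * *g;
        }
    }
}

The trait version lets code outside the crate define new optimizers, but it scatters the variants across separate types and gives up exhaustive matching.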

It keeps related logic grouped under one type, and adding a new optimizer is straightforward: add a variant, and the compiler's exhaustiveness checking flags every match that needs a new arm.
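
As a quick usage sketch (the values are illustrative, and the constructor syntax follows the field names used in the match arms):

fn main() {
    let mut weights = vec![1.0, -2.0];
    let grads = vec![0.5, 0.25];

    let mut opt = Optimizer::Momentum {
        learning_rate: 0.1,
        momentum: 0.9,
        velocity: vec![0.0; 2], // one velocity entry per weight
    };

    // First step from zero velocity: v = 0.1 * g, so each weight moves by 0.1 * g.
    opt.step(&mut weights, &grads);
    println!("{:?}", weights); // [0.95, -2.025]
}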