Optimizers as enums with internal state and methods

This chapter builds on the previous enum-based optimizer design. We now give each variant its own internal state and attach behavior via methods. This pattern is useful when you want enum-based control flow with the related logic encapsulated in one type.

Defining the optimizer enum

Each optimizer variant includes its own parameters and, when needed, its internal state.

/// An enum representing different optimizers with built-in state and update rules.
///
/// Supports both gradient descent and momentum-based methods.
#[derive(Debug, Clone)]
pub enum Optimizer {
    /// Gradient Descent optimizer with a fixed learning rate.
    GradientDescent { learning_rate: f64 },
    /// Momentum-based optimizer with velocity tracking.
    Momentum {
        learning_rate: f64,
        momentum: f64,
        velocity: Vec<f64>,
    },
}

Here, GradientDescent stores only the learning rate, while Momentum also carries its momentum coefficient and a velocity vector that persists across update steps.
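
To make the stored state concrete, here is a hedged sketch of building a Momentum value directly with struct-variant syntax (the parameter values are only illustrative):

let opt = Optimizer::Momentum {
    learning_rate: 0.1,
    momentum: 0.9,
    // One velocity entry per model parameter, starting at zero.
    velocity: vec![0.0; 3],
};

Writing this out by hand at every call site gets repetitive, which motivates the constructors below.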

Constructors

We define convenience constructors for each variant. These keep call sites simple and avoid writing struct-variant initializers by hand; the momentum constructor also takes care of zero-initializing the velocity vector.

impl Optimizer {
    /// Creates a new Gradient Descent optimizer.
    ///
    /// # Arguments
    /// - `learning_rate`: Step size for the parameter updates.
    pub fn gradient_descent(learning_rate: f64) -> Self {
        Self::GradientDescent { learning_rate }
    }

    /// Creates a new Momentum optimizer.
    ///
    /// # Arguments
    /// - `learning_rate`: Step size for the updates.
    /// - `momentum`: Momentum coefficient.
    /// - `dim`: Number of parameters (used to initialize velocity vector).
    pub fn momentum(learning_rate: f64, momentum: f64, dim: usize) -> Self {
        Self::Momentum {
            learning_rate,
            momentum,
            velocity: vec![0.0; dim],
        }
    }
}

These constructors make creating optimizers cleaner and more idiomatic, as shown below.
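
A quick usage sketch (the hyperparameter values are arbitrary, chosen only for illustration):

let sgd = Optimizer::gradient_descent(0.1);

// Momentum for a model with 3 parameters; the velocity
// vector is initialized to [0.0, 0.0, 0.0] internally.
let momentum = Optimizer::momentum(0.1, 0.9, 3);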

Implementing the step method

The step method applies a single optimization update. It pattern-matches on the enum variant to select the update rule and to borrow each variant's internal state mutably.

impl Optimizer {
    /// Applies a single optimization step, depending on the variant.
    ///
    /// # Arguments
    /// - `weights`: Mutable slice of model parameters to be updated.
    /// - `grads`: Gradient slice with the same length as `weights`. Note that
    ///   `zip` stops at the shorter slice, so a length mismatch is silently
    ///   truncated rather than caught.
    pub fn step(&mut self, weights: &mut [f64], grads: &[f64]) {
        match self {
            Optimizer::GradientDescent { learning_rate } => {
                for (w, g) in weights.iter_mut().zip(grads.iter()) {
                    *w -= *learning_rate * *g;
                }
            }
            Optimizer::Momentum {
                learning_rate,
                momentum,
                velocity,
            } => {
                for ((w, g), v) in weights
                    .iter_mut()
                    .zip(grads.iter())
                    .zip(velocity.iter_mut())
                {
                    *v = *momentum * *v + *learning_rate * *g;
                    *w -= *v;
                }
            }
        }
    }
}

Here, GradientDescent applies the plain update w -= learning_rate * g to each parameter. The Momentum variant first folds the new gradient into its stored velocity (v = momentum * v + learning_rate * g) and then subtracts the velocity from the weights (w -= v), so earlier gradients keep influencing later steps. Because step takes &mut self, the updated velocity persists inside the enum between calls.
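
A short worked example (assuming the definitions above; the numbers are arbitrary) shows the velocity accumulating across two identical steps:

let mut opt = Optimizer::momentum(0.1, 0.9, 2);
let mut weights = vec![1.0, 2.0];
let grads = vec![0.5, -0.5];

// Step 1: v = 0.9 * 0.0 + 0.1 * 0.5 = 0.05, so w[0] = 1.0 - 0.05 = 0.95.
opt.step(&mut weights, &grads);

// Step 2: v = 0.9 * 0.05 + 0.1 * 0.5 = 0.095, so w[0] = 0.95 - 0.095 = 0.855.
opt.step(&mut weights, &grads);

assert!((weights[0] - 0.855).abs() < 1e-12);

Because the velocity is stored inside the enum, the second step is larger than the first even though the gradient is unchanged.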

Summary

This pattern can be a clean alternative to trait-based designs when you want:

  • A small number of well-known variants
  • Built-in state encapsulation
  • Exhaustive handling via pattern matching

It keeps related logic grouped under one type and can be extended easily with new optimizers.
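
As a hedged illustration of that last point, extending the type with a hypothetical AdaGrad-style variant (not part of this chapter's code) would touch exactly three places: a new variant, a constructor, and a match arm in step. A fragment sketch of the first and last:

// New variant added to the enum:
AdaGrad {
    learning_rate: f64,
    epsilon: f64,
    grad_sq_sum: Vec<f64>,
},

// New arm added to the match in `step`:
Optimizer::AdaGrad { learning_rate, epsilon, grad_sq_sum } => {
    for ((w, g), s) in weights.iter_mut().zip(grads.iter()).zip(grad_sq_sum.iter_mut()) {
        *s += *g * *g; // accumulate squared gradients
        *w -= *learning_rate * *g / (s.sqrt() + *epsilon);
    }
}

The compiler's exhaustiveness check then points at every other match on Optimizer that needs a new arm.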