
Gradient descent with momentum

Recall that the gradient descent algorithm with momentum is given by:

\\[
v_{t+1} = \mu v_t + \eta \nabla L(w_t), \qquad w_{t+1} = w_t - v_{t+1}
\\]

where \\( v \\), \\( \mu \\), and \\( \eta \\) denote the velocity, momentum, and step size, respectively. The structure we define stores the learning rate, the momentum factor, and an internal velocity buffer:

```rust
pub struct Momentum {
    pub learning_rate: f64,
    pub momentum: f64,
    pub velocity: Vec<f64>,
}
```

We define the constructor to take the required parameters and initialize the velocity to a zero vector:

```rust
impl Momentum {
    /// Creates a new momentum optimizer.
    ///
    /// # Arguments
    /// - `learning_rate`: Step size used to update weights.
    /// - `momentum`: Momentum coefficient (typically between 0.8 and 0.99).
    /// - `dim`: Dimension of the parameter vector, used to initialize velocity.
    pub fn new(learning_rate: f64, momentum: f64, dim: usize) -> Self {
        Self {
            learning_rate,
            momentum,
            velocity: vec![0.0; dim],
        }
    }
}
```

The step function is slightly more complex, as it performs elementwise operations over the weights, velocity, and gradients:

```rust
impl Optimizer for Momentum {
    /// Applies the momentum update step.
    ///
    /// Each step uses the update rule:
    /// ```text
    /// v ← momentum * v + learning_rate * grad
    /// w ← w - v
    /// ```
    fn step(&mut self, weights: &mut [f64], grads: &[f64]) {
        for ((w, g), v) in weights
            .iter_mut()
            .zip(grads.iter())
            .zip(self.velocity.iter_mut())
        {
            *v = self.momentum * *v + self.learning_rate * *g;
            *w -= *v;
        }
    }
}
```

The internal state of the velocity is updated as well, which is possible because the method takes a mutable reference, `&mut self`. At this point, we've defined two optimizers using structs and a shared trait. To complete the module, we define a training loop that works with any optimizer implementing the trait.