Gradient descent with momentum
The Momentum
struct extends gradient descent with a classical momentum term:
#![allow(unused)] fn main() { pub struct Momentum { step_size: f64, momentum: f64, } }
It has a constructor:
#![allow(unused)] fn main() { impl Momentum { pub fn new(step_size: f64, momentum: f64) -> Self { Self { step_size, momentum, } } } }
This algorithm adds momentum to classical gradient descent. Instead of updating weights using just the current gradient, it maintains a velocity vector that accumulates the influence of past gradients. This helps smooth the trajectory and accelerates convergence on convex problems.
#![allow(unused)] fn main() { impl Optimizer for Momentum { fn run( &self, weights: &mut Array1<f64>, grad_fn: impl Fn(&Array1<f64>) -> Array1<f64>, n_steps: usize, ) { let n: usize = weights.len(); let mut velocity: Array1<f64> = Array::zeros(n); for _ in 0..n_steps { let grads = grad_fn(weights); for ((w, g), v) in weights .iter_mut() .zip(grads.iter()) .zip(velocity.iter_mut()) { *v = self.momentum * *v - self.step_size * g; *w += *v; } } } } }
Some notes:
-
We initialize a vector of zeros to track the momentum (velocity) across steps. It has the same length as the weights. This is achieved with:
let mut velocity: Array1<f64> = Array::zeros(n)
. Note that we could have defined the velocity as an internal state variable within the struct defintion. -
We use a triple nested zip to unpack the values of the weights, gradients, and velocity:
for ((w, g), v) in weights.iter_mut().zip(grads.iter()).zip(velocity.iter_mut())
. Here,weights.iter_mut()
gives a mutable reference to each weight,grads.iter()
provides read-only access to each gradient,velocity.iter_mut()
allows in-place updates of the velocity vector.
This pattern allows us to update everything in one pass, element-wise.
-
Within the nested zip closure, we update the velocity using the momentum term and current gradient:
*v = self.momentum * *v - self.step_size * g;
-
The weight is updated using the new velocity:
*w += *v;
. Again, we dereferencew
because it's a mutable reference.