# Optimizers using traits
This chapter illustrates how to use traits for implementing a module of optimizers. This approach is useful when you want polymorphism or when each optimizer requires its own state and logic.
It's similar to what you might do in other languages such as Python or C++, and it's a good fit for applications that involve multiple algorithm variants.
## Trait definition
We define a common trait, `Optimizer`, which describes the shared behavior of any optimizer. Let's assume that our optimizers only need a `step` function.
```rust
/// A trait representing an optimization algorithm that can update weights using gradients.
///
/// Optimizers must implement the `step` method, which modifies weights in place.
pub trait Optimizer {
    /// Performs a single optimization step.
    ///
    /// # Arguments
    /// - `weights`: Mutable slice of parameters to be updated.
    /// - `grads`: Slice of gradients corresponding to the weights.
    fn step(&mut self, weights: &mut [f64], grads: &[f64]);
}
```
Any type that implements this trait must provide a `step` method. Let's illustrate how to use this by implementing two optimizers: gradient descent with and without momentum.
## Gradient descent
We first define the structure for the gradient descent algorithm. It only stores the learning rate as an `f64`.
```rust
/// Basic gradient descent optimizer.
///
/// Updates each weight by subtracting the gradient scaled by a learning rate.
pub struct GradientDescent {
    pub learning_rate: f64,
}
```
We then implement a constructor. In this case, it simply consists of choosing the learning rate.
```rust
impl GradientDescent {
    /// Creates a new gradient descent optimizer.
    ///
    /// # Arguments
    /// - `learning_rate`: Step size used to update weights.
    pub fn new(learning_rate: f64) -> Self {
        Self { learning_rate }
    }
}
```
Next, we implement the `step` method required by the `Optimizer` trait:
```rust
impl Optimizer for GradientDescent {
    /// Applies the gradient descent step to each weight.
    ///
    /// Each weight is updated as: `w ← w - learning_rate * grad`
    fn step(&mut self, weights: &mut [f64], grads: &[f64]) {
        for (w, g) in weights.iter_mut().zip(grads.iter()) {
            *w -= self.learning_rate * g;
        }
    }
}
```
This function updates each entry of `weights` by looping over the elements and applying the gradient descent update. We use explicit elementwise operations because slices and `Vec` don't provide built-in arithmetic methods. External crates such as `ndarray` or `nalgebra` can express this more concisely, as sketched below.
## Gradient descent with momentum
Now let’s implement gradient descent with momentum. The structure stores the learning rate, the momentum factor, and an internal velocity buffer:
```rust
/// Momentum-based gradient descent optimizer.
///
/// Combines current gradients with a velocity term to smooth updates.
pub struct Momentum {
    pub learning_rate: f64,
    pub momentum: f64,
    pub velocity: Vec<f64>,
}
```
We define the constructor by taking the required parameters, and we initialize the velocity to a zero vector:
```rust
impl Momentum {
    /// Creates a new momentum optimizer.
    ///
    /// # Arguments
    /// - `learning_rate`: Step size used to update weights.
    /// - `momentum`: Momentum coefficient (typically between 0.8 and 0.99).
    /// - `dim`: Dimension of the parameter vector, used to initialize velocity.
    pub fn new(learning_rate: f64, momentum: f64, dim: usize) -> Self {
        Self {
            learning_rate,
            momentum,
            velocity: vec![0.0; dim],
        }
    }
}
```
The `step` function is slightly more complex, as it performs elementwise operations over the weights, velocity, and gradients:
````rust
impl Optimizer for Momentum {
    /// Applies the momentum update step.
    ///
    /// Each step uses the update rule:
    /// ```text
    /// v ← momentum * v + learning_rate * grad
    /// w ← w - v
    /// ```
    fn step(&mut self, weights: &mut [f64], grads: &[f64]) {
        for ((w, g), v) in weights
            .iter_mut()
            .zip(grads.iter())
            .zip(self.velocity.iter_mut())
        {
            *v = self.momentum * *v + self.learning_rate * *g;
            *w -= *v;
        }
    }
}
````
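To see the update rule in action, here is a small sanity check (a hypothetical test, assuming `Momentum` and the `Optimizer` trait are in scope). With `learning_rate = 0.1`, `momentum = 0.9`, and a constant gradient of `1.0`, the velocity grows from `0.1` to `0.19` over the first two steps:

```rust
#[test]
fn momentum_accumulates_velocity() {
    let mut opt = Momentum::new(0.1, 0.9, 1);
    let mut w = [1.0];

    // Step 1: v = 0.9 * 0.0 + 0.1 * 1.0 = 0.1, so w = 1.0 - 0.1 = 0.9.
    opt.step(&mut w, &[1.0]);
    assert!((w[0] - 0.9).abs() < 1e-12);

    // Step 2: v = 0.9 * 0.1 + 0.1 * 1.0 = 0.19, so w = 0.9 - 0.19 = 0.71.
    opt.step(&mut w, &[1.0]);
    assert!((w[0] - 0.71).abs() < 1e-12);
}
```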
At this point, we've defined two optimizers using structs and a shared trait. To complete the module, we define a training loop that uses any optimizer implementing the trait.
## Public API
We expose the training loop in `lib.rs` as the public API. The function `run_optimization` takes a generic `Optimizer`, a gradient function, an initial weight vector, and a maximum number of iterations.
```rust
pub mod enum_based;
pub mod traits_and_ndarray;
pub mod traits_based;

use traits_based::optimizers::Optimizer;

pub fn run_optimization<O: Optimizer>(
    optimizer: &mut O,
    weights: &mut [f64],
    grad_fn: impl Fn(&[f64]) -> Vec<f64>,
    num_steps: usize,
) {
    for _ in 0..num_steps {
        let grads = grad_fn(weights);
        optimizer.step(weights, &grads);
    }
}
```
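Because `grad_fn` is any `impl Fn(&[f64]) -> Vec<f64>`, a closure works just as well as a named function. A minimal sketch, assuming the items above are in scope:

```rust
let mut weights = vec![1.0; 3];
let mut optimizer = GradientDescent::new(0.1);

// Gradient of f(w) = Σ wᵢ², written inline as a closure.
run_optimization(
    &mut optimizer,
    &mut weights,
    |w| w.iter().map(|wi| 2.0 * wi).collect(),
    50,
);
```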
## Example of usage
Here’s a simple example where we minimize the function \( f(w) = \sum_{i=1}^{n} w_i^2 \):
```rust
use traits_based::optimizers::Momentum;

// Gradient of f(w) = Σ wᵢ²: each component is 2 wᵢ.
fn grad_fn(w: &[f64]) -> Vec<f64> {
    w.iter().map(|wi| 2.0 * wi).collect()
}

let n: usize = 10;
let mut weights = vec![1.0; n];
let mut optimizer = Momentum::new(0.01, 0.9, n);

run_optimization(&mut optimizer, &mut weights, grad_fn, 100);

// Final weights after optimization
println!("{:?}", weights);
```
Some notes:

- Both `weights` and `optimizer` must be mutable, because we perform in-place updates.
- We pass mutable references into `run_optimization`, matching its function signature.
- The gradient function here is a named function, but any closure implementing `Fn(&[f64]) -> Vec<f64>` can be passed instead.