API and usage

We expose the training loop in lib.rs as the public API. The function run_optimization takes a mutable reference to any optimizer implementing the Optimizer trait, the weights as a mutable slice (updated in place), a gradient function, and the number of steps to run.

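// lib.rs
pub mod optimizers;
use optimizers::Optimizer;

/// Runs `num_steps` iterations of gradient-based optimization,
/// updating `weights` in place.
pub fn run_optimization<Opt: Optimizer>(
    optimizer: &mut Opt,
    weights: &mut [f64],
    grad_fn: impl Fn(&[f64]) -> Vec<f64>,
    num_steps: usize,
) {
    for _ in 0..num_steps {
        // Evaluate the gradient at the current weights...
        let grads = grad_fn(weights);
        // ...and let the optimizer apply its update rule.
        optimizer.step(weights, &grads);
    }
}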

Here, the trait bound <Opt: Optimizer> tells the compiler that the type of the given optimizer must implement the Optimizer trait. This guarantees that optimizer provides the step method the loop calls.
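The optimizers module is not shown in this section, so the sketch below assumes a minimal Optimizer trait with a single step(&mut self, weights: &mut [f64], grads: &[f64]) method, which matches the call run_optimization makes. Sgd is a hypothetical plain gradient-descent type, added only to show that any type implementing the trait satisfies the bound <Opt: Optimizer>.

```rust
/// Assumed shape of the `Optimizer` trait (the real `optimizers` module is
/// not shown): update `weights` in place from a gradient of the same length.
pub trait Optimizer {
    fn step(&mut self, weights: &mut [f64], grads: &[f64]);
}

/// Hypothetical plain gradient descent: w <- w - learning_rate * g.
pub struct Sgd {
    pub learning_rate: f64,
}

impl Optimizer for Sgd {
    fn step(&mut self, weights: &mut [f64], grads: &[f64]) {
        for (w, g) in weights.iter_mut().zip(grads) {
            *w -= self.learning_rate * g;
        }
    }
}

/// Same loop as `run_optimization` above; the bound `Opt: Optimizer`
/// is what allows the call to `optimizer.step(...)`.
pub fn run_optimization<Opt: Optimizer>(
    optimizer: &mut Opt,
    weights: &mut [f64],
    grad_fn: impl Fn(&[f64]) -> Vec<f64>,
    num_steps: usize,
) {
    for _ in 0..num_steps {
        let grads = grad_fn(weights);
        optimizer.step(weights, &grads);
    }
}

fn main() {
    let mut sgd = Sgd { learning_rate: 0.1 };
    let mut weights = vec![1.0, -2.0];
    // Gradient of f(w) = sum_i w_i^2 is 2 * w_i.
    run_optimization(&mut sgd, &mut weights, |w| w.iter().map(|wi| 2.0 * wi).collect(), 50);
    // Each step scales every weight by (1 - 2 * 0.1) = 0.8, so both shrink toward 0.
    println!("{:?}", weights);
}
```

Swapping Sgd for any other type that implements Optimizer requires no change to run_optimization itself.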

Example of usage

Here's a simple example where we minimize the function $f(w) = \sum_{i=1}^{n} w_i^2$ over $\mathbb{R}^n$, whose gradient has components $2 w_i$:

use traits_based::optimizers::Momentum;
use traits_based::run_optimization;

// Gradient of f(w) = sum_i w_i^2, which is 2 * w_i for each component.
fn grad_fn(w: &[f64]) -> Vec<f64> {
    w.iter().map(|wi| 2.0 * wi).collect()
}

fn main() {
    let n: usize = 10;
    let mut weights = vec![1.0; n];
    let mut optimizer = Momentum::new(0.01, 0.9, n);

    run_optimization(&mut optimizer, &mut weights, grad_fn, 100);

    // Final weights after optimization
    println!("{:?}", weights);
}
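Momentum itself lives in the optimizers module, which this section does not show. As a rough sketch, assuming Momentum::new takes (learning rate, momentum coefficient, number of weights) in the order used above and applies the classical heavy-ball update v ← βv + g, w ← w − ηv, an implementation of an assumed Optimizer trait could look like this. The book's actual version may differ.

```rust
/// Assumed `Optimizer` trait shape, as called by `run_optimization`.
pub trait Optimizer {
    fn step(&mut self, weights: &mut [f64], grads: &[f64]);
}

/// Hypothetical classical (heavy-ball) momentum optimizer.
pub struct Momentum {
    learning_rate: f64,
    momentum: f64,
    velocity: Vec<f64>, // one entry per weight, starting at zero
}

impl Momentum {
    /// Assumed argument order: (learning_rate, momentum, number of weights).
    pub fn new(learning_rate: f64, momentum: f64, dim: usize) -> Self {
        Self {
            learning_rate,
            momentum,
            velocity: vec![0.0; dim],
        }
    }
}

impl Optimizer for Momentum {
    fn step(&mut self, weights: &mut [f64], grads: &[f64]) {
        for ((w, v), g) in weights.iter_mut().zip(self.velocity.iter_mut()).zip(grads) {
            // v <- momentum * v + g, then w <- w - learning_rate * v
            *v = self.momentum * *v + g;
            *w -= self.learning_rate * *v;
        }
    }
}

fn main() {
    let mut optimizer = Momentum::new(0.01, 0.9, 2);
    let mut weights = vec![1.0, 1.0];
    for _ in 0..200 {
        // Gradient of f(w) = sum_i w_i^2.
        let grads: Vec<f64> = weights.iter().map(|wi| 2.0 * wi).collect();
        optimizer.step(&mut weights, &grads);
    }
    println!("{:?}", weights);
}
```

The velocity vector is internal state that changes on every call, which is why step takes &mut self.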

Some final reminders:

  • Both weights and optimizer must be mutable: the weights are updated in place, and the optimizer may update its own internal state on every step.
  • We pass mutable references into run_optimization, matching its function signature.
  • The example passes a named function as grad_fn, but any value implementing Fn(&[f64]) -> Vec<f64>, including a closure, works just as well.
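To illustrate the last point, the sketch below replaces the named function with a closure that captures a target vector from its environment, minimizing $\sum_i (w_i - t_i)^2$, whose gradient has components $2(w_i - t_i)$. The Optimizer trait signature and the hypothetical Sgd type are assumptions made so the block compiles on its own.

```rust
/// Assumed `Optimizer` trait shape, as called by `run_optimization`.
pub trait Optimizer {
    fn step(&mut self, weights: &mut [f64], grads: &[f64]);
}

/// Hypothetical plain gradient descent so the sketch compiles on its own.
pub struct Sgd {
    pub learning_rate: f64,
}

impl Optimizer for Sgd {
    fn step(&mut self, weights: &mut [f64], grads: &[f64]) {
        for (w, g) in weights.iter_mut().zip(grads) {
            *w -= self.learning_rate * g;
        }
    }
}

pub fn run_optimization<Opt: Optimizer>(
    optimizer: &mut Opt,
    weights: &mut [f64],
    grad_fn: impl Fn(&[f64]) -> Vec<f64>,
    num_steps: usize,
) {
    for _ in 0..num_steps {
        let grads = grad_fn(weights);
        optimizer.step(weights, &grads);
    }
}

fn main() {
    // The closure borrows `target` from the enclosing scope.
    let target = vec![3.0, -1.0, 0.5];
    // Gradient of f(w) = sum_i (w_i - t_i)^2 is 2 * (w_i - t_i).
    let grad_fn = |w: &[f64]| -> Vec<f64> {
        w.iter().zip(&target).map(|(wi, ti)| 2.0 * (wi - ti)).collect()
    };

    let mut weights = vec![0.0; target.len()];
    let mut optimizer = Sgd { learning_rate: 0.1 };
    run_optimization(&mut optimizer, &mut weights, grad_fn, 100);

    // The weights end up very close to `target`.
    println!("{:?}", weights);
}
```

A plain fn item cannot capture target, which is the main reason to reach for a closure here.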