API and usage

Once you have implemented the optimizers and their step logic, it's time to expose a public API for running an optimization from your crate's lib.rs. This typically means defining a helper function such as run_optimization.

Define run_optimization in lib.rs

As in the trait-based implementation, we can define a function run_optimization that performs the optimization. Here, however, Optimizer is an enum rather than a trait, so there is no generic type parameter to declare with <Opt: Optimizer> (see the trait-based run_optimization function for comparison). Instead, we pass a mutable reference to the concrete type: &mut Optimizer.

// src/lib.rs
pub mod optimizers;

use optimizers::Optimizer;

/// Runs an optimization loop over multiple steps.
///
/// # Arguments
/// - `optimizer`: The optimizer to use (e.g., a GradientDescent or Momentum variant).
/// - `weights`: Mutable slice of weights, updated in place.
/// - `grad_fn`: A closure that computes the gradient given the current weights.
/// - `num_steps`: Number of iterations to run.
pub fn run_optimization(
    optimizer: &mut Optimizer,
    weights: &mut [f64],
    grad_fn: impl Fn(&[f64]) -> Vec<f64>,
    num_steps: usize,
) {
    for _ in 0..num_steps {
        // `weights` is reborrowed immutably to compute the gradient...
        let grads = grad_fn(weights);
        // ...then mutably so the optimizer can update it in place.
        optimizer.step(weights, &grads);
    }
}
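The optimizers module itself comes from the earlier chapters and isn't reproduced here. To make the enum-based dispatch concrete, here is a minimal self-contained sketch. The variant fields (`learning_rate`, `momentum`, `velocity`), the `momentum` constructor, and the update rules are assumptions for illustration, not the book's exact code. Because every variant is the same concrete type, run_optimization needs no generic parameter, and optimizers of different kinds can even share one `Vec`.

```rust
/// Hypothetical enum-based optimizer; field names are illustrative only.
#[derive(Debug)]
pub enum Optimizer {
    GradientDescent { learning_rate: f64 },
    Momentum { learning_rate: f64, momentum: f64, velocity: Vec<f64> },
}

impl Optimizer {
    pub fn gradient_descent(learning_rate: f64) -> Self {
        Optimizer::GradientDescent { learning_rate }
    }

    /// Hypothetical constructor; `dim` sizes the velocity buffer (one entry per weight).
    pub fn momentum(learning_rate: f64, momentum: f64, dim: usize) -> Self {
        Optimizer::Momentum { learning_rate, momentum, velocity: vec![0.0; dim] }
    }

    /// Updates `weights` in place; the `match` plays the role of trait dispatch.
    pub fn step(&mut self, weights: &mut [f64], grads: &[f64]) {
        match self {
            Optimizer::GradientDescent { learning_rate } => {
                for (w, g) in weights.iter_mut().zip(grads) {
                    *w -= *learning_rate * g;
                }
            }
            Optimizer::Momentum { learning_rate, momentum, velocity } => {
                for ((w, g), v) in weights.iter_mut().zip(grads).zip(velocity.iter_mut()) {
                    // Accumulate a decaying running sum of gradients, then step along it.
                    *v = *momentum * *v + *learning_rate * g;
                    *w -= *v;
                }
            }
        }
    }
}

pub fn run_optimization(
    optimizer: &mut Optimizer,
    weights: &mut [f64],
    grad_fn: impl Fn(&[f64]) -> Vec<f64>,
    num_steps: usize,
) {
    for _ in 0..num_steps {
        let grads = grad_fn(weights);
        optimizer.step(weights, &grads);
    }
}

fn main() {
    // Gradient of L(w) = (w - 3)^2.
    let grad_fn = |w: &[f64]| vec![2.0 * (w[0] - 3.0)];

    // Different variants, one concrete type: they fit in a single Vec.
    let mut optimizers = vec![
        Optimizer::gradient_descent(0.1),
        Optimizer::momentum(0.05, 0.9, 1),
    ];

    for opt in optimizers.iter_mut() {
        let mut weights = vec![0.0];
        // `&grad_fn` also implements Fn, so the closure can be reused each iteration.
        run_optimization(opt, &mut weights, &grad_fn, 200);
        println!("{:?} -> {:.6}", opt, weights[0]);
    }
}
```

This is the usual enum trade-off: adding a new optimizer means adding a variant and a new arm to every `match` on Optimizer, whereas the trait-based version lets new optimizers live in separate types without touching existing code.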

Example of usage

Here's a basic example that uses run_optimization to minimize the quadratic loss L(w) = (w - 3)², whose gradient is 2(w - 3) and whose minimum is at w = 3.

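// `my_crate` is a placeholder: replace it with your crate's name.
use my_crate::optimizers::Optimizer;
use my_crate::run_optimization;

fn main() {
    // Gradient of L(w) = (w - 3)^2.
    let grad_fn = |w: &[f64]| vec![2.0 * (w[0] - 3.0)];

    let mut weights = vec![0.0];
    let mut optimizer = Optimizer::gradient_descent(0.1);

    run_optimization(&mut optimizer, &mut weights, grad_fn, 100);

    // Should print a value very close to 3.0.
    println!("Optimized weight: {}", weights[0]);
}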
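Because grad_fn accepts any `impl Fn(&[f64]) -> Vec<f64>`, the closure can also capture data from its environment. The sketch below fits a line y = a·x + b to a few points by minimizing mean squared error, using the gradients dL/da = (2/n) Σ rᵢxᵢ and dL/db = (2/n) Σ rᵢ with residual rᵢ = a·xᵢ + b - yᵢ. The one-variant Optimizer here is a hypothetical stand-in so the block compiles on its own, and `fit_line` is a hypothetical helper, not part of the book's crate.

```rust
/// Hypothetical one-variant stand-in for the book's Optimizer enum.
pub enum Optimizer {
    GradientDescent { learning_rate: f64 },
}

impl Optimizer {
    pub fn gradient_descent(learning_rate: f64) -> Self {
        Optimizer::GradientDescent { learning_rate }
    }

    pub fn step(&mut self, weights: &mut [f64], grads: &[f64]) {
        match self {
            Optimizer::GradientDescent { learning_rate } => {
                for (w, g) in weights.iter_mut().zip(grads) {
                    *w -= *learning_rate * g;
                }
            }
        }
    }
}

pub fn run_optimization(
    optimizer: &mut Optimizer,
    weights: &mut [f64],
    grad_fn: impl Fn(&[f64]) -> Vec<f64>,
    num_steps: usize,
) {
    for _ in 0..num_steps {
        let grads = grad_fn(weights);
        optimizer.step(weights, &grads);
    }
}

/// Hypothetical helper: fits y = a*x + b by minimizing MSE and returns [a, b].
fn fit_line(xs: &[f64], ys: &[f64]) -> Vec<f64> {
    let n = xs.len() as f64;
    // The closure borrows `xs` and `ys`; the weights are [a, b].
    let grad_fn = |w: &[f64]| {
        let (mut da, mut db) = (0.0, 0.0);
        for (x, y) in xs.iter().zip(ys) {
            let residual = w[0] * x + w[1] - y;
            da += 2.0 * residual * x / n;
            db += 2.0 * residual / n;
        }
        vec![da, db]
    };

    let mut weights = vec![0.0, 0.0];
    let mut optimizer = Optimizer::gradient_descent(0.1);
    run_optimization(&mut optimizer, &mut weights, grad_fn, 1000);
    weights
}

fn main() {
    // Points lying exactly on y = 2x + 1.
    let xs = [0.0, 1.0, 2.0, 3.0];
    let ys = [1.0, 3.0, 5.0, 7.0];
    let w = fit_line(&xs, &ys);
    println!("a = {:.4}, b = {:.4}", w[0], w[1]);
}
```

With these points and a learning rate of 0.1, gradient descent converges well within 1000 steps and the program prints `a = 2.0000, b = 1.0000`. The learning rate matters: for this data the loss curvature is large enough that rates much above roughly 0.2 would make the iteration diverge.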