# API and usage
We expose the training loop in `lib.rs` as the public API. The function `run_optimization` takes a generic optimizer (any type implementing `Optimizer`), a gradient function, an initial weight vector, and the number of optimization steps to run.
```rust
pub mod optimizers;

use optimizers::Optimizer;

pub fn run_optimization<Opt: Optimizer>(
    optimizer: &mut Opt,
    weights: &mut [f64],
    grad_fn: impl Fn(&[f64]) -> Vec<f64>,
    num_steps: usize,
) {
    for _ in 0..num_steps {
        // Evaluate the gradient at the current weights, then let the
        // optimizer update the weights in place.
        let grads = grad_fn(weights);
        optimizer.step(weights, &grads);
    }
}
```
Here, the trait bound `<Opt: Optimizer>` tells the compiler that the type of the given optimizer must implement the `Optimizer` trait. This guarantees that `optimizer` has the required `step` method.
## Example of usage
Here’s a simple example where we minimize the quadratic function f(w) = Σᵢ wᵢ², whose gradient is ∇f(w) = 2w:
```rust
use traits_based::{optimizers::Momentum, run_optimization};

fn grad_fn(w: &[f64]) -> Vec<f64> {
    // Gradient of f(w) = sum_i w_i^2 is 2 * w_i for each component.
    w.iter().map(|wi| 2.0 * wi).collect()
}

fn main() {
    let n: usize = 10;
    let mut weights = vec![1.0; n];
    let mut optimizer = Momentum::new(0.01, 0.9, n);

    run_optimization(&mut optimizer, &mut weights, grad_fn, 100);

    // Final weights after optimization
    println!("{:?}", weights);
}
```
Some final reminders:

- Both `weights` and `optimizer` must be mutable, because we perform in-place updates.
- We pass mutable references into `run_optimization`, matching its function signature.
- The example uses a standalone gradient function; any closure with the same `Fn(&[f64]) -> Vec<f64>` signature works as well (see the sketch below).