Adding tests
Finally, we can add tests to check our ndarray-based implementation.
// NOTE(review): the `#![allow(unused)]` attribute and the `fn main` wrapper look like
// documentation-tool scaffolding around the snippet — confirm before removing them.
#![allow(unused)]
fn main() {
    /// Unit tests for the `GD` and `Momentum` optimizers.
    #[cfg(test)]
    mod tests {
        use super::*;
        use ndarray::array;

        /// Absolute tolerance used when comparing floating-point weight vectors.
        const TOLERANCE: f64 = 1e-6;

        /// Asserts that `actual` and `expected` have the same length and that every
        /// pair of elements differs by less than `tol` in absolute value.
        ///
        /// Comparing with a tolerance keeps the tests independent of the exact order
        /// of floating-point operations inside the optimizers, and the explicit
        /// length check prevents `zip` from silently truncating a mismatched pair.
        ///
        /// # Panics
        ///
        /// Panics with the offending index and values when the lengths differ or an
        /// element is out of tolerance.
        fn assert_all_close(actual: &Array1<f64>, expected: &Array1<f64>, tol: f64) {
            assert_eq!(
                actual.len(),
                expected.len(),
                "length mismatch: got {} elements, expected {}",
                actual.len(),
                expected.len()
            );
            for (index, (got, want)) in actual.iter().zip(expected.iter()).enumerate() {
                assert!(
                    (got - want).abs() < tol,
                    "element {index}: got {got}, expected {want} (tolerance {tol})"
                );
            }
        }

        /// `GD::new` stores the step size it was given.
        #[test]
        fn test_gradient_descent_constructor() {
            let optimizer = GD::new(1e-3);
            assert_eq!(optimizer.step_size, 1e-3);
        }

        /// One run of a single step with step size 0.1 and a constant gradient of
        /// 0.5 is expected to lower every weight by 0.05.
        #[test]
        fn test_step_gradient_descent() {
            let opt = GD::new(0.1);
            let mut weights = array![1.0, 2.0, 3.0];
            let grad_fn = |_w: &Array1<f64>| array![0.5, 0.5, 0.5];

            opt.run(&mut weights, grad_fn, 1);

            // Tolerance-based comparison instead of exact `f64` equality.
            assert_all_close(&weights, &array![0.95, 1.95, 2.95], TOLERANCE);
        }

        /// `Momentum::new` stores the step size and momentum it was given.
        #[test]
        fn test_momentum_constructor() {
            let opt = Momentum::new(0.01, 0.9);
            assert_eq!(
                opt.step_size, 0.01,
                "Expected step size to be 0.01 but got {}",
                opt.step_size
            );
            assert_eq!(
                opt.momentum, 0.9,
                "Expected momentum to be 0.9 but got {}",
                opt.momentum
            );
        }

        /// Two steps with step size 0.1, momentum 0.9 and a constant gradient of 0.5.
        ///
        /// NOTE(review): the expected values are consistent with a velocity of 0.05
        /// after the first step and 0.9 * 0.05 + 0.05 = 0.095 after the second
        /// (total shift 0.145); the update rule itself is defined outside this
        /// snippet — confirm against the `Momentum` implementation.
        #[test]
        fn test_step_momentum() {
            let opt = Momentum::new(0.1, 0.9);
            let mut weights = array![1.0, 2.0, 3.0];
            let grad_fn = |_w: &Array1<f64>| array![0.5, 0.5, 0.5];

            opt.run(&mut weights, grad_fn, 2);

            assert_all_close(&weights, &array![0.855, 1.855, 2.855], TOLERANCE);
        }
    }
}