Keyboard shortcuts

Press ← or → to navigate between chapters

Press S or / to search in the book

Press ? to show this help

Press Esc to hide this help

Adding tests

Finally, we can add tests to check our ndarray-based implementation.

#![allow(unused)]
fn main() {
// Unit tests for the ndarray-based optimizers (`GD` and `Momentum`).
// The outer `fn main` wrapper is the mdBook playground scaffolding; the tests
// themselves live in the `#[cfg(test)]` module and only compile under `cargo test`.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Absolute tolerance for comparing optimizer outputs.
    ///
    /// Far larger than the rounding noise of a few `w - lr * g` updates, far
    /// smaller than any meaningful step in these tests.
    const TOL: f64 = 1e-9;

    /// Asserts that `actual` and `expected` have the same length and that every
    /// pair of corresponding elements differs by less than `tol`.
    ///
    /// Exact `==` on `f64` results is brittle (an update can legitimately be off
    /// by an ulp), and a bare `zip` would silently accept arrays of different
    /// lengths, so both conditions are checked explicitly.
    ///
    /// # Panics
    /// Panics with a message naming the offending index and values when the
    /// lengths differ or any element is outside the tolerance.
    fn assert_all_close(actual: &Array1<f64>, expected: &Array1<f64>, tol: f64) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: got {} elements, expected {}",
            actual.len(),
            expected.len()
        );
        for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (a - e).abs() < tol,
                "element {}: got {}, expected {} (tol {}); full result {}",
                i,
                a,
                e,
                tol,
                actual
            );
        }
    }

    /// The constructor stores the requested step size unchanged.
    #[test]
    fn test_gradient_descent_constructor() {
        let optimizer = GD::new(1e-3);
        // Exact comparison is fine here: the stored value is the literal passed in.
        assert_eq!(
            optimizer.step_size, 1e-3,
            "Expected step size to be 0.001 but got {}",
            optimizer.step_size
        );
    }

    /// One plain GD step with a constant gradient: w <- w - 0.1 * 0.5.
    #[test]
    fn test_step_gradient_descent() {
        let opt = GD::new(0.1);
        let mut weights = array![1.0, 2.0, 3.0];
        let grad_fn = |_w: &Array1<f64>| array![0.5, 0.5, 0.5];
        opt.run(&mut weights, grad_fn, 1);

        assert_all_close(&weights, &array![0.95, 1.95, 2.95], TOL);
    }

    /// The constructor stores both hyper-parameters unchanged.
    #[test]
    fn test_momentum_constructor() {
        let opt = Momentum::new(0.01, 0.9);
        assert_eq!(
            opt.step_size, 0.01,
            "Expected step size to be 0.01 but got {}",
            opt.step_size
        );
        assert_eq!(
            opt.momentum, 0.9,
            "Expected momentum to be 0.9 but got {}",
            opt.momentum
        );
    }

    /// Two momentum steps with a constant gradient of 0.5.
    ///
    /// The expected values assume the update v <- momentum * v + step_size * g,
    /// w <- w - v, starting from v = 0:
    /// v1 = 0.05, v2 = 0.9 * 0.05 + 0.05 = 0.095, total shift 0.145.
    #[test]
    fn test_step_momentum() {
        let opt = Momentum::new(0.1, 0.9);
        let mut weights = array![1.0, 2.0, 3.0];
        let grad_fn = |_w: &Array1<f64>| array![0.5, 0.5, 0.5];

        opt.run(&mut weights, grad_fn, 2);
        assert_all_close(&weights, &array![0.855, 1.855, 2.855], TOL);
    }
}
}