Keyboard shortcuts

Press ← or → to navigate between chapters

Press S or / to search in the book

Press ? to show this help

Press Esc to hide this help

Adding tests

We can easily adapt the tests we implemented for the trait-based version of the optimizers. Here, we rely on pattern matching to check the constructors.

#![allow(unused)]
fn main() {
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `Optimizer::gradient_descent` builds the `GradientDescent`
    /// variant and stores the learning rate it was given.
    #[test]
    fn test_gradient_descent_constructor() {
        let built = Optimizer::gradient_descent(0.01);

        // Any other variant means the constructor is wired to the wrong case.
        let Optimizer::GradientDescent { learning_rate } = built else {
            panic!("Expected GradientDescent optimizer");
        };

        assert_eq!(learning_rate, 0.01);
    }

    /// Checks that `Optimizer::momentum` builds the `Momentum` variant with
    /// the given learning rate and momentum factor, and a velocity buffer
    /// whose length matches the third argument.
    #[test]
    fn test_momentum_constructor() {
        let built = Optimizer::momentum(0.01, 0.9, 10);

        // Any other variant means the constructor is wired to the wrong case.
        let Optimizer::Momentum {
            learning_rate,
            momentum,
            velocity,
        } = built
        else {
            panic!("Expected Momentum optimizer");
        };

        assert_eq!(learning_rate, 0.01);
        assert_eq!(momentum, 0.9);
        assert_eq!(velocity.len(), 10);
    }

    /// Checks a single gradient-descent step.
    ///
    /// With a learning rate of 0.1 and gradients of 0.5, every weight is
    /// expected to decrease by 0.05 (`0.1 * 0.5`).
    #[test]
    fn test_step_gradient_descent() {
        let mut opt = Optimizer::gradient_descent(0.1);
        let mut weights = vec![1.0, 2.0, 3.0];
        let grads = vec![0.5, 0.5, 0.5];

        opt.step(&mut weights, &grads);

        let expected = [0.95, 1.95, 2.95];

        // `zip` stops at the shorter side, so pin the length explicitly.
        assert_eq!(weights.len(), expected.len());

        // The updated weights are computed floats: compare with a tolerance
        // rather than bit-for-bit, so the test does not depend on the exact
        // order of operations inside `step`.
        for (i, (w, e)) in weights.iter().zip(expected.iter()).enumerate() {
            assert!(
                (*w - *e).abs() < 1e-6,
                "weight {} was {}, expected {}",
                i,
                w,
                e
            );
        }
    }

    /// Checks two consecutive momentum steps with constant gradients.
    ///
    /// The expected values are consistent with a velocity that starts at zero
    /// and accumulates as `v = 0.9 * v + 0.1 * 0.5`, the weights then moving
    /// by `v`: a change of 0.05 on the first step and 0.095 on the second.
    /// NOTE(review): the exact update rule lives in `Optimizer::step`; confirm
    /// there if these values ever need to change.
    #[test]
    fn test_step_momentum() {
        let mut opt = Optimizer::momentum(0.1, 0.9, 3);
        let mut weights = vec![1.0, 2.0, 3.0];
        let grads = vec![0.5, 0.5, 0.5];

        // Expected weights after step 1 and step 2, in order.
        let expected_per_step = [[0.95, 1.95, 2.95], [0.855, 1.855, 2.855]];

        for (step, expected) in expected_per_step.iter().enumerate() {
            opt.step(&mut weights, &grads);

            // `zip` stops at the shorter side, so pin the length explicitly;
            // otherwise a truncated `weights` would pass vacuously.
            assert_eq!(weights.len(), expected.len());

            // Computed floats: compare with a tolerance on every step rather
            // than mixing exact and approximate checks.
            for (i, (w, e)) in weights.iter().zip(expected.iter()).enumerate() {
                assert!(
                    (*w - *e).abs() < 1e-6,
                    "after step {}: weight {} was {}, expected {}",
                    step + 1,
                    i,
                    w,
                    e
                );
            }
        }
    }
}
}