Adding tests

To check that our optimizers behave as intended, let's look at how to write tests and how to run them.

How to write tests in Rust

Tests can live in the same file as the code they check, inside a module annotated with #[cfg(test)], which is compiled only when running tests. Each test function is annotated with #[test]. Inside a test, you use macros such as assert!, assert_eq!, or assert_ne! to check the expected behavior; a test fails if it panics.
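As a minimal sketch of this layout, here is a hypothetical function sgd_update (it is not part of the optimizer code) together with a test module that exercises it with both assert_eq! and assert!:

```rust
/// Hypothetical helper, used only to illustrate the test layout:
/// one plain gradient-descent update for a single weight.
pub fn sgd_update(weight: f64, grad: f64, learning_rate: f64) -> f64 {
    weight - learning_rate * grad
}

// Compiled only by `cargo test`, not by a normal `cargo build`.
#[cfg(test)]
mod tests {
    // Bring the parent module's items (here, `sgd_update`) into scope.
    use super::*;

    #[test]
    fn zero_gradient_leaves_weight_unchanged() {
        assert_eq!(sgd_update(1.5, 0.0, 0.1), 1.5);
    }

    #[test]
    fn positive_gradient_decreases_weight() {
        let updated = sgd_update(1.0, 0.5, 0.1);
        assert!(updated < 1.0);
        // Compare floats with a tolerance rather than exact equality.
        assert!((updated - 0.95).abs() < 1e-12);
    }
}
```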

What we test

We implemented a few tests to check:

  • That the constructors return the expected variant with the correct parameters
  • That the step method modifies weights as expected
  • That repeated calls to step update the internal state correctly (e.g., momentum's velocity)
#[cfg(test)]
mod tests {
    // Bring the optimizer types from the parent module into scope.
    use super::*;

    #[test]
    fn test_gradient_descent_constructor() {
        let opt = Optimizer::gradient_descent(0.01);
        match opt {
            Optimizer::GradientDescent { learning_rate } => {
                assert_eq!(learning_rate, 0.01);
            }
            _ => panic!("Expected GradientDescent optimizer"),
        }
    }

    #[test]
    fn test_momentum_constructor() {
        let opt = Optimizer::momentum(0.01, 0.9, 10);
        match opt {
            Optimizer::Momentum {
                learning_rate,
                momentum,
                velocity,
            } => {
                assert_eq!(learning_rate, 0.01);
                assert_eq!(momentum, 0.9);
                // One velocity entry per weight.
                assert_eq!(velocity.len(), 10);
            }
            _ => panic!("Expected Momentum optimizer"),
        }
    }

    #[test]
    fn test_step_gradient_descent() {
        let mut opt = Optimizer::gradient_descent(0.1);
        let mut weights = vec![1.0, 2.0, 3.0];
        let grads = vec![0.5, 0.5, 0.5];

        opt.step(&mut weights, &grads);

        // Each weight moves by learning_rate * grad = 0.1 * 0.5 = 0.05.
        assert_eq!(weights, vec![0.95, 1.95, 2.95]);
    }

    #[test]
    fn test_step_momentum() {
        let mut opt = Optimizer::momentum(0.1, 0.9, 3);
        let mut weights = vec![1.0, 2.0, 3.0];
        let grads = vec![0.5, 0.5, 0.5];

        // First step: no velocity has accumulated yet, so the update
        // equals plain gradient descent.
        opt.step(&mut weights, &grads);
        assert_eq!(weights, vec![0.95, 1.95, 2.95]);

        // Second step: the accumulated velocity enlarges the update to
        // 0.9 * 0.05 + 0.1 * 0.5 = 0.095.
        opt.step(&mut weights, &grads);
        // Compare with a tolerance: exact float equality is fragile here.
        assert!(
            weights
                .iter()
                .zip(vec![0.855, 1.855, 2.855])
                .all(|(a, b)| (*a - b).abs() < 1e-6)
        );
    }
}
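These tests rely on the Optimizer type implemented earlier in the chapter, which is not reproduced here. As a sketch only, the hypothetical enum below matches the constructors and the step signature the tests call. It assumes the momentum rule v ← momentum · v + learning_rate · g, then w ← w − v, which is one formulation that reproduces the expected values (0.95 after one step, 0.855 after two); the chapter's actual implementation may differ.

```rust
/// Hypothetical optimizer, reconstructed only from what the tests use.
#[derive(Debug, Clone)]
pub enum Optimizer {
    GradientDescent {
        learning_rate: f64,
    },
    Momentum {
        learning_rate: f64,
        momentum: f64,
        velocity: Vec<f64>, // one velocity entry per weight
    },
}

impl Optimizer {
    pub fn gradient_descent(learning_rate: f64) -> Self {
        Optimizer::GradientDescent { learning_rate }
    }

    /// `size` is the number of weights; the velocity starts at zero.
    pub fn momentum(learning_rate: f64, momentum: f64, size: usize) -> Self {
        Optimizer::Momentum {
            learning_rate,
            momentum,
            velocity: vec![0.0; size],
        }
    }

    /// Updates `weights` in place from `grads` (same length assumed).
    pub fn step(&mut self, weights: &mut [f64], grads: &[f64]) {
        match self {
            Optimizer::GradientDescent { learning_rate } => {
                for (w, g) in weights.iter_mut().zip(grads) {
                    *w -= *learning_rate * g;
                }
            }
            Optimizer::Momentum {
                learning_rate,
                momentum,
                velocity,
            } => {
                for ((w, g), v) in weights.iter_mut().zip(grads).zip(velocity.iter_mut()) {
                    // Assumed rule: v <- momentum * v + learning_rate * g, then w <- w - v.
                    *v = *momentum * *v + *learning_rate * g;
                    *w -= *v;
                }
            }
        }
    }
}
```

Placed in the same file as the test module, this sketch lets the four tests compile; with the assumed update rule, the first momentum step coincides with plain gradient descent because the velocity is still zero.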

Some notes:

  • This module is added in the same file where the optimizers are implemented.
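  • The line use super::*; brings every item of the parent module (the file where the optimizers are defined) into the scope of tests. This includes private items, because tests is nested inside that module.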
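To illustrate the second note, here is a small sketch with a hypothetical private helper clip (not part of the optimizer code). It is not marked pub, yet the nested tests module can still call it through use super::*;:

```rust
/// Hypothetical private helper: limits each gradient to [-max, max].
/// It is not `pub`, yet the child `tests` module can still call it.
fn clip(grads: &[f64], max: f64) -> Vec<f64> {
    grads.iter().map(|&g| g.clamp(-max, max)).collect()
}

#[cfg(test)]
mod tests {
    // Glob import of the parent module: private items such as `clip`
    // are visible here because `tests` is nested inside that module.
    use super::*;

    #[test]
    fn clip_limits_large_gradients() {
        assert_eq!(clip(&[-5.0, 0.5, 5.0], 1.0), vec![-1.0, 0.5, 1.0]);
    }
}
```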

How to run the tests

To run the tests from the command line, use:

cargo test

Cargo compiles the crate with tests enabled and runs every function annotated with #[test]. You should see output like:

running 4 tests
test tests::test_gradient_descent_constructor ... ok
test tests::test_momentum_constructor ... ok
test tests::test_step_gradient_descent ... ok
test tests::test_step_momentum ... ok

If a test fails, Cargo marks it as FAILED and prints its panic message; for assert_eq!, that message includes both the left and right values.
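A bare assert! like the final check in test_step_momentum only reports the expression that failed, not the values involved. As a sketch, a hypothetical helper assert_close (not part of the original code) could replace that check and report the index and both values of the first mismatch; the #[should_panic] test confirms that the helper really panics on a mismatch.

```rust
/// Hypothetical test helper: asserts that two float slices are equal
/// element-wise within `tol`, and reports the first mismatch on failure.
pub fn assert_close(actual: &[f64], expected: &[f64], tol: f64) {
    assert_eq!(actual.len(), expected.len(), "length mismatch");
    for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
        assert!(
            (a - e).abs() < tol,
            "index {i}: got {a}, expected {e} (tolerance {tol})"
        );
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn close_values_pass() {
        assert_close(&[0.855, 1.855], &[0.855 + 1e-9, 1.855], 1e-6);
    }

    #[test]
    #[should_panic(expected = "index 1")]
    fn mismatch_reports_index() {
        assert_close(&[0.0, 1.0], &[0.0, 2.0], 1e-6);
    }
}
```

In test_step_momentum, the final check could then read assert_close(&weights, &[0.855, 1.855, 2.855], 1e-6);.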