Adding tests
To make sure our optimizers behave as intended, let's look at how to write tests and how to run them.
How to write tests in Rust
Tests can live in the same file as the code, inside a module annotated with #[cfg(test)], which is compiled only when running tests. Each test function is annotated with #[test]. Inside a test, you can use assert_eq!, assert!, or similar macros to check the expected behavior. A test passes if it runs to completion without panicking.
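As a minimal illustration of these pieces, independent of the optimizers, here is a sketch with a hypothetical scale function (not part of the original code) and a test module for it:

```rust
// Hypothetical function, used only to illustrate the test syntax.
fn scale(values: &[f64], factor: f64) -> Vec<f64> {
    values.iter().map(|v| v * factor).collect()
}

// This module is compiled only when running `cargo test`.
#[cfg(test)]
mod tests {
    // Bring `scale` from the parent module into scope.
    use super::*;

    #[test]
    fn scale_doubles_each_value() {
        // assert_eq! compares two values and prints both on failure.
        assert_eq!(scale(&[1.0, 2.0], 2.0), vec![2.0, 4.0]);
    }

    #[test]
    fn scale_by_zero_gives_zeros() {
        // assert! checks that a boolean condition holds.
        assert!(scale(&[3.0, -1.0], 0.0).iter().all(|&v| v == 0.0));
    }
}
```

Each function marked #[test] is run as a separate test, so one failing assertion does not stop the others from running.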
What we test
We implemented a few tests to check:
- That the constructors return the expected variant with the correct parameters
- That the step method modifies the weights as expected
- That repeated calls to step update the internal state correctly (e.g., the momentum optimizer's velocity)
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gradient_descent_constructor() {
        let opt = Optimizer::gradient_descent(0.01);
        match opt {
            Optimizer::GradientDescent { learning_rate } => {
                assert_eq!(learning_rate, 0.01);
            }
            _ => panic!("Expected GradientDescent optimizer"),
        }
    }

    #[test]
    fn test_momentum_constructor() {
        let opt = Optimizer::momentum(0.01, 0.9, 10);
        match opt {
            Optimizer::Momentum {
                learning_rate,
                momentum,
                velocity,
            } => {
                assert_eq!(learning_rate, 0.01);
                assert_eq!(momentum, 0.9);
                // One velocity entry per weight.
                assert_eq!(velocity.len(), 10);
            }
            _ => panic!("Expected Momentum optimizer"),
        }
    }

    #[test]
    fn test_step_gradient_descent() {
        let mut opt = Optimizer::gradient_descent(0.1);
        let mut weights = vec![1.0, 2.0, 3.0];
        let grads = vec![0.5, 0.5, 0.5];
        opt.step(&mut weights, &grads);
        // Each weight moves by lr * g = 0.1 * 0.5 = 0.05. Compare with a
        // tolerance, since exact floating-point equality is fragile.
        assert!(weights
            .iter()
            .zip(vec![0.95, 1.95, 2.95])
            .all(|(a, b)| (*a - b).abs() < 1e-6));
    }

    #[test]
    fn test_step_momentum() {
        let mut opt = Optimizer::momentum(0.1, 0.9, 3);
        let mut weights = vec![1.0, 2.0, 3.0];
        let grads = vec![0.5, 0.5, 0.5];

        // First step: the velocity starts at zero, so each weight moves
        // by 0.05, exactly as with plain gradient descent.
        opt.step(&mut weights, &grads);
        assert!(weights
            .iter()
            .zip(vec![0.95, 1.95, 2.95])
            .all(|(a, b)| (*a - b).abs() < 1e-6));

        // Second step: the accumulated velocity makes each weight move
        // by 0.095 instead of 0.05.
        opt.step(&mut weights, &grads);
        assert!(weights
            .iter()
            .zip(vec![0.855, 1.855, 2.855])
            .all(|(a, b)| (*a - b).abs() < 1e-6));
    }
}
Some notes:
- This module is added in the same file where the optimizers are implemented.
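- The line use super::*; brings every item of the parent module (the file containing the optimizers) into scope, so the tests can refer to Optimizer directly. Because tests is a child module, this includes private items as well.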
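The tests rely on an Optimizer type implemented elsewhere in the same file. For readers who want to run them in isolation, here is a minimal sketch reconstructed from what the tests expect. The variant and field names come from the tests; the f64 element type and the momentum update rule (v = momentum * v + learning_rate * g, then w = w - v) are assumptions, chosen because they reproduce the expected values 0.95 and 0.855. The actual implementation may differ.

```rust
// Hypothetical sketch of the optimizer under test, inferred from the tests.
#[derive(Debug)]
pub enum Optimizer {
    GradientDescent { learning_rate: f64 },
    Momentum { learning_rate: f64, momentum: f64, velocity: Vec<f64> },
}

impl Optimizer {
    pub fn gradient_descent(learning_rate: f64) -> Self {
        Optimizer::GradientDescent { learning_rate }
    }

    // `size` is the number of weights; the velocity starts at zero.
    pub fn momentum(learning_rate: f64, momentum: f64, size: usize) -> Self {
        Optimizer::Momentum { learning_rate, momentum, velocity: vec![0.0; size] }
    }

    // Updates `weights` in place from `grads` (assumed to have equal length).
    pub fn step(&mut self, weights: &mut [f64], grads: &[f64]) {
        debug_assert_eq!(weights.len(), grads.len());
        match self {
            Optimizer::GradientDescent { learning_rate } => {
                for (w, g) in weights.iter_mut().zip(grads) {
                    *w -= *learning_rate * g;
                }
            }
            Optimizer::Momentum { learning_rate, momentum, velocity } => {
                for ((w, g), v) in weights.iter_mut().zip(grads).zip(velocity.iter_mut()) {
                    // Assumed rule: accumulate velocity, then move against it.
                    *v = *momentum * *v + *learning_rate * g;
                    *w -= *v;
                }
            }
        }
    }
}
```

Storing the velocity inside the Momentum variant is what makes the second call to step produce a larger update than the first, which is exactly what test_step_momentum checks.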
How to run the tests
To run the tests from the command line, use:
cargo test
Cargo builds the project in test mode and runs every function annotated with #[test]. You should see output like the following (the order of the lines may vary, since tests run in parallel by default):
running 4 tests
test tests::test_gradient_descent_constructor ... ok
test tests::test_momentum_constructor ... ok
test tests::test_step_gradient_descent ... ok
test tests::test_step_momentum ... ok
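If a test fails, it is reported as FAILED and the panic message of the failing assertion is printed; for assert_eq!, this includes both the left and right values that differed.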
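A bare assert! on an iterator chain, as in the momentum test, only reports that the condition was false. A small helper (hypothetical, not part of the original code, and assuming f64 weights) can compare slices element-wise with a tolerance and report which element differed:

```rust
/// Hypothetical test helper: panics with a descriptive message if the
/// slices differ in length or any pair of elements differs by more than `tol`.
fn assert_all_close(actual: &[f64], expected: &[f64], tol: f64) {
    assert_eq!(actual.len(), expected.len(), "length mismatch");
    for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
        assert!(
            (a - e).abs() <= tol,
            "element {i}: got {a}, expected {e} (tolerance {tol})"
        );
    }
}
```

Inside the tests module, the second momentum check could then read assert_all_close(&weights, &[0.855, 1.855, 2.855], 1e-6); and a failure would name the offending element and both values.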