Welcome!

Welcome to Rustineers, a dive into the Rust programming language through the lens of applied mathematics and science. There are already several high-quality resources available for learning Rust.

You can find even more learning material at rust-lang.org.

This book is meant to be complementary to those great resources. Our goal is to learn Rust by implementing practical examples drawn from applied mathematics, including:

  • Machine Learning
  • Statistics and Probability
  • Optimization
  • Ordinary Differential Equations (ODEs)
  • Partial Differential Equations (PDEs)
  • And other topics from engineering and physics

Each chapter centers on a specific scientific algorithm or computational problem. We explore how to implement it idiomatically in Rust, sometimes in multiple styles.

Along the way, we aim to cover the core concepts of Rust, namely:

  • Ownership / borrowing
  • Data types
  • Traits
  • Modules
  • Error handling
  • Macros

Most examples begin with Rust's standard library, which provides a solid foundation for learning both the language and its ecosystem.

Difficulty Levels

To help you navigate the material, each chapter is marked with a difficulty level using 🦀 emojis:

  • 🦀 — Beginner-friendly
  • 🦀🦀 — Intermediate
  • 🦀🦀🦀 — Advanced

As this is a work in progress, the difficulty levels might not always be well chosen.

Roadmap

Here's an unordered list of topics that could be added to the book:

  • 1D Ridge regression.
  • Simple first-order gradient descent algorithms.
  • Kernel methods: multivariate kernel Ridge regression.
  • Scientific computing: Solving the 2D Poisson problem with the finite element method.
  • Classification algorithms: logistic regression.
  • Some clustering algorithms: K-means, Gaussian mixtures.
  • Some MCMC algorithms: MH, LMC, MALA.
  • Numerical methods for solving ODEs: Euler, Runge-Kutta.
  • Optimization algorithms: gradient-based, derivative-free.
  • Divergences and distances for probability distributions: KL divergence, total variation, Wasserstein.

Let us know if you have other ideas or if you want to improve any existing chapter.

Cargo 101

This chapter gives you everything you need to compile and run the examples in this book using Cargo, Rust’s official package manager and build tool.

Creating a new project

To create a new Rust library project:

cargo new my_project --lib

To create a binary project (i.e., one with a main.rs):

cargo new my_project

Building your project

Navigate into the project directory:

cd my_project

Then build it:

cargo build

This compiles your code in debug mode (faster builds, less optimization). You’ll find the output in target/debug/.
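If you want an optimized build (slower to compile, but faster at runtime), pass the --release flag; the output then goes to target/release/:

cargo build --release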

Running your code

If it’s a binary crate (with a main.rs), you can run it:

cargo run

This compiles and runs your code in one go.

Testing your code

To run tests in lib.rs or in any #[cfg(test)] module:

cargo test
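If you haven't written any tests yet, here is a minimal sketch of the kind of #[cfg(test)] module that cargo test picks up; the add function and its test are just placeholders:

// src/lib.rs (sketch)
pub fn add(a: f64, b: f64) -> f64 {
    a + b
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add() {
        // `cargo test` compiles this module and runs every #[test] function.
        assert_eq!(add(1.0, 2.0), 3.0);
    }
}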

Cleaning build artifacts

Remove the target/ directory and everything in it:

cargo clean

Checking your code (without building)

cargo check

This quickly verifies your code compiles without generating the binary.

Adding dependencies

To add dependencies, open Cargo.toml and add them under [dependencies]:

[dependencies]
ndarray = "0.15"

Or use the command line:

cargo add ndarray

Ridge regression 1D

Here, we implement one-dimensional Ridge Regression in several styles, using increasing levels of abstraction. It's designed as a learning path for beginners, and focuses on writing idiomatic, clear, and type-safe Rust code. We focus on minimizing this loss function:

$$
\mathcal{L}(\beta) = \frac{1}{n} \sum_{i=1}^{n} (y_i - \beta x_i)^2 + \lambda \beta^2
$$

where $x_i$ is an input covariate, $y_i$ is the associated output, $\beta$ is the Ridge coefficient, and $\lambda$ is the regularization strength.

How this chapter is organized

This chapter introduces several useful concepts for Rust beginners. It is divided into four sections, each solving the same problem (1D Ridge regression) using different tools and with slightly increasing complexity.

  • The first section shows how to use basic functions and the Rust standard library to build a simple library. In particular, it shows how to manipulate vectors (Vec<f64>) and slices (&[f64]).

  • The next section explains how to solve the same problem using structs and traits to make the code more modular and extensible.

  • The third section introduces generics, allowing the code to work with different floating-point types (f32 and f64).

  • Finally, the last section goes further by using ndarray for linear algebra and incorporating additional Rust features such as optional values, pattern matching, and error handling.

If you want to implement and run this example while you read but are not familiar with Cargo yet, have a look at Cargo 101 for how to set up your project.

Functional: introduction

This section focuses on implementing the 1D Ridge problem using functions and the Rust standard library only. It's divided into 5 subsections:

  1. Loss function: Shows how to implement the Ridge loss function in two simple ways.
  2. Closed-form solution: Implements the closed-form solution of the Ridge optimization problem.
  3. Gradient descent: Solves the Ridge problem using gradient descent to illustrate how to write for loops.
  4. Putting things together: Explains how to assemble everything into a simple library.
  5. Exposing API: Explains how to use lib.rs to define what is made available to the user.

What we're building here

The aim of this chapter is to build a small crate with the following layout:

crates/ridge_1d_fn/
├── Cargo.toml
└── src
    ├── estimator.rs           # Closed-form solution of the Ridge estimator
    ├── gradient_descent.rs    # Gradient descent solution
    ├── lib.rs                 # Main entry point for the library
    └── loss_functions.rs      # Loss function implementations

It is made of three modules: estimator.rs, gradient_descent.rs, and loss_functions.rs. At the end of the chapter, we end up with a crate that can be used as follows:

use ridge_1d_fn::{fit, predict};

fn main() {
    let x = vec![1.0, 2.0, 3.0];
    let y = vec![2.0, 4.0, 6.0];

    let beta = fit(&x, &y, 0.1);
    let preds = predict(&x, beta);

    println!("Learned beta: {}", beta);
    println!("Predictions: {:?}", preds);
}

The fit and predict functions are implemented in the library entry point lib.rs.

Note that the gradient_descent.rs and loss_functions.rs modules mostly serve as additional illustrations.

What's next

After this first chapter, we explore how to implement the same things using structs and traits to make our code more modular.

Ridge loss

In this example, we implement one-dimensional Ridge Regression loss using only the Rust standard library, without any external crates. This lets us focus on core Rust features such as slices, iterators, and type safety.

Although the loss function by itself isn't really useful to solve the Ridge problem, implementing it provides a simple and focused introduction to Rust.

Naive implementation

We now present a straightforward implementation of the Ridge regression loss function:

#![allow(unused)]
fn main() {
pub fn loss_function_naive(x: &[f64], y: &[f64], beta: f64, lambda2: f64) -> f64 {
    assert_eq!(x.len(), y.len(), "x and y must have the same length");

    let n: usize = x.len();
    let y_hat: Vec<f64> = mul_scalar_vec(beta, x);
    let residuals: Vec<f64> = subtract_vectors(y, &y_hat);
    let mse: f64 = residuals.iter().map(|x| x * x).sum::<f64>() / (n as f64);
    mse + lambda2 * beta * beta
}
}

In this example, we use two helper functions that we implement ourselves. A helper function for multiplying a vector by a scalar:

#![allow(unused)]
fn main() {
pub fn mul_scalar_vec(scalar: f64, vector: &[f64]) -> Vec<f64> {
    vector.iter().map(|x| x * scalar).collect()
}
}

We also defined a helper that subtracts two slices element-wise:

#![allow(unused)]
fn main() {
pub fn subtract_vectors(a: &[f64], b: &[f64]) -> Vec<f64> {
    assert_eq!(a.len(), b.len(), "Input vectors must have the same length");
    a.iter().zip(b.iter()).map(|(x, y)| x - y).collect()
}
}

Rather than using explicit loops, this implementation uses Rust’s iterator combinators, which the compiler optimizes into efficient code. This zero-cost abstraction keeps the code both readable and fast.
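For comparison, here is a rough sketch of the same scalar multiplication written with an explicit loop; both versions produce the same result, the iterator version just expresses it more declaratively:

#![allow(unused)]
fn main() {
// Loop-based sketch equivalent to `mul_scalar_vec`.
pub fn mul_scalar_vec_loop(scalar: f64, vector: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(vector.len());
    for x in vector {
        out.push(x * scalar);
    }
    out
}
}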

Ownership and borrowing

In Rust, every value has a single owner. When you assign a value to a new variable or pass it to a function by value, ownership is transferred (moved).

Borrowing allows you to use a value without taking ownership of it. Borrowing is done using references:

  • &T is a shared (read-only) reference.
  • &mut T is a mutable reference.

These references allow access to data without moving it.

A function like this:

#![allow(unused)]
fn main() {
fn mul_scalar_vec(scalar: f64, vector: &[f64]) -> Vec<f64> {
    vector.iter().map(|x| x * scalar).collect()
}
}

does not take ownership of the input vector. Instead, it borrows it for the duration of the function call. This makes it easier to reuse the input vector later.

If we instead defined:

#![allow(unused)]
fn main() {
fn mul_scalar_vec(scalar: f64, vector: Vec<f64>) -> Vec<f64> { ... }
}

then passing a vector would move ownership:

#![allow(unused)]
fn main() {
let v = vec![1.0, 2.0, 3.0];
let result = mul_scalar_vec(2.0, v); // v is moved here
let v2 = v; // error: value borrowed after move
}

Why use &[f64] instead of Vec<f64>?

The type &[f64] is commonly used in function signatures because it works with both arrays and vectors.

Finally, note that:

  • Vec<f64> is an owned, growable vector on the heap. The only time we return a Vec<f64> is when we allocate a new output vector, like in mul_scalar_vec.
  • &Vec<f64> is a shared reference to a Vec<f64>.
  • &[f64] is a slice, i.e., a borrowed view into an array or vector.

In this chapter, we will mostly use these types but things can easily get more tricky.
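As a small illustration (a sketch, assuming the mul_scalar_vec function from above is in scope), the same slice-taking function accepts a borrowed array, a borrowed Vec, or a sub-slice:

#![allow(unused)]
fn main() {
let arr: [f64; 3] = [1.0, 2.0, 3.0];
let v: Vec<f64> = vec![4.0, 5.0, 6.0];

// `&[f64]` accepts a borrowed array (unsized coercion), a borrowed Vec
// (deref coercion), or an explicit sub-slice.
let a = mul_scalar_vec(2.0, &arr);
let b = mul_scalar_vec(2.0, &v);
let c = mul_scalar_vec(2.0, &v[1..]);
}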

Inlined iterator-based implementation

Let's implement the loss function in a more compact way. Instead of breaking the computation into multiple intermediate steps—like computing y_hat, residuals, and then squaring each residual—here we inline all computations into a single expression using iterators and closures.

This is ideal for demonstrating the expressive power of Rust's iterator API, especially once you're comfortable with basic slice handling and .map() chaining.

#![allow(unused)]
fn main() {
pub fn loss_function_inline(x: &[f64], y: &[f64], beta: f64, lambda2: f64) -> f64 {
    let n: usize = y.len();
    let factor = n as f64;
    let mean_squared_error = x
        .iter()
        .zip(y.iter())
        .map(|(xi, yi)| {
            let residual = yi - beta * xi;
            residual * residual
        })
        .sum::<f64>()
        / factor;
    mean_squared_error + lambda2 * beta * beta
}
}

This implementation computes the mean squared error in a single iteration, minimizing allocations and abstraction overhead. In particular:

  • We use .iter().zip() to iterate over two slices.
  • We define a full code block inside the .map() closure, which makes it easier to write intermediate expressions like let residual = yi - beta * xi; before returning the squared value.

Closed-form solution

The one-dimensional Ridge regression problem admits a simple closed-form solution.
Given a dataset $(x_i, y_i)$ for $i = 1, \dots, n$, and a regularization parameter $\lambda \geq 0$, the Ridge estimator is:

$$
\hat{\beta} = \frac{\sum_{i=1}^{n} x_i y_i}{\sum_{i=1}^{n} x_i^2 + \lambda n}
$$

This form assumes that the data has no intercept term, i.e., the model is $y_i = \beta x_i$, or equivalently, that the data is centered around zero. In practice, it is common to subtract the means of both $x$ and $y$ before computing the estimator. This removes the intercept and gives:

$$
\hat{\beta} = \frac{\sum_{i=1}^{n} (x_i - \bar{x})(y_i - \bar{y})}{\sum_{i=1}^{n} (x_i - \bar{x})^2 + \lambda n}
$$

We now implement this solution in Rust, using only the standard library.

#![allow(unused)]
fn main() {
pub fn ridge_estimator(x: &[f64], y: &[f64], lambda2: f64) -> f64 {
    let n: usize = x.len();
    assert_eq!(n, y.len(), "x and y must have the same length");

    let x_mean: f64 = x.iter().sum::<f64>() / n as f64;
    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;

    let num: f64 = x
        .iter()
        .zip(y)
        .map(|(xi, yi)| (xi - x_mean) * (yi - y_mean))
        .sum::<f64>();

    let denom: f64 = x.iter().map(|xi| (xi - x_mean).powi(2)).sum::<f64>() + lambda2 * (n as f64);

    num / denom
}
}

You could also express it as a single iterator chain, similar to how we implemented the loss function earlier.
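For instance, here is one possible sketch (the helper name ridge_estimator_fold is ours) that accumulates the numerator and denominator in a single fold over the zipped slices:

#![allow(unused)]
fn main() {
// Fold the centered cross-products and squared deviations in one pass
// (after computing the two means).
pub fn ridge_estimator_fold(x: &[f64], y: &[f64], lambda2: f64) -> f64 {
    let n = x.len();
    assert_eq!(n, y.len(), "x and y must have the same length");

    let x_mean: f64 = x.iter().sum::<f64>() / n as f64;
    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;

    let (num, denom) = x.iter().zip(y).fold((0.0, 0.0), |(num, denom), (xi, yi)| {
        let dx = xi - x_mean;
        (num + dx * (yi - y_mean), denom + dx * dx)
    });

    num / (denom + lambda2 * n as f64)
}
}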

Gradient descent

As an exercise, we now implement the gradient descent algorithm to optimize the Ridge regression loss. The gradient has a closed-form expression and can be efficiently computed. We implement it in two different ways as we did for the loss function.

Gradient descent implementation

Gradient descent iteratively updates the parameter $\beta$ using the gradient of the loss function:

$$
\beta \leftarrow \beta - \eta \, \nabla_\beta \mathcal{L}(\beta)
$$

where $\eta$ is the learning rate, and $\nabla_\beta \mathcal{L}(\beta)$ is the gradient of the loss:

$$
\nabla_\beta \mathcal{L}(\beta) = -\frac{2}{n} \sum_{i=1}^{n} (y_i - \beta x_i)\, x_i + 2 \lambda \beta
$$

We allow flexible experimentation by passing the gradient function as a parameter:

#![allow(unused)]
fn main() {
pub fn ridge_estimator(
    grad_fn: impl Fn(&[f64], &[f64], f64, f64) -> f64,
    x: &[f64],
    y: &[f64],
    lambda2: f64,
    lr: f64,
    n_iters: usize,
    init_beta: f64,
) -> f64 {
    let mut beta = init_beta;

    for _ in 0..n_iters {
        let grad = grad_fn(x, y, beta, lambda2);
        beta -= lr * grad;
    }

    beta
}
}

This version is generic, letting us plug in any valid grad_fn.

Gradient function: naive implementation

This version breaks the computation into three separate steps:

  • Compute the residuals $r_i = y_i - \beta x_i$
  • Compute the dot product between the residuals and the inputs, $\sum_{i=1}^{n} r_i x_i$
  • Assemble these quantities to get the gradient value

We start by implementing our own dot function by relying on iterators, map chaining, and summing the results.

#![allow(unused)]
fn main() {
pub fn dot(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "Input vectors must have the same length");
    a.iter().zip(b.iter()).map(|(xi, yi)| xi * yi).sum()
}

}

Our first implementation takes the following form:

#![allow(unused)]
fn main() {
pub fn grad_loss_function_naive(x: &[f64], y: &[f64], beta: f64, lambda2: f64) -> f64 {
    assert_eq!(x.len(), y.len(), "x and y must have the same length");

    let n: usize = x.len();
    let residuals: Vec<f64> = x
        .iter()
        .zip(y.iter())
        .map(|(xi, yi)| yi - beta * xi)
        .collect();
    let residuals_dot_x = dot(&residuals, x);

    -2.0 * residuals_dot_x / (n as f64) + 2.0 * lambda2 * beta
}
}

Gradient function: inlined iterator-based implementation

In this version, we fuse the residual and gradient computation into a single iterator chain. This avoids intermediate memory allocations and takes full advantage of Rust’s zero-cost abstraction model.

#![allow(unused)]
fn main() {
pub fn grad_loss_function_inline(x: &[f64], y: &[f64], beta: f64, lambda2: f64) -> f64 {
    assert_eq!(x.len(), y.len(), "x and y must have the same length");

    let n: usize = x.len();
    let grad_mse: f64 = x
        .iter()
        .zip(y.iter())
        .map(|(xi, yi)| 2.0 * (yi - beta * xi) * xi)
        .sum::<f64>()
        / (n as f64);

    -grad_mse + 2.0 * lambda2 * beta
}
}

Key differences:

  • The naive version allocates a temporary vector for the residuals before computing the dot product.
  • The inlined version is more idiomatic Rust: it avoids allocation and achieves better performance through iterator fusion.
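To tie these pieces together, here is a small usage sketch (the data values are illustrative) that plugs the inlined gradient into the generic ridge_estimator optimizer defined above:

#![allow(unused)]
fn main() {
let x = vec![1.0, 2.0, 3.0];
let y = vec![2.0, 4.0, 6.0];

// Any function or closure with the signature (&[f64], &[f64], f64, f64) -> f64 works.
let beta = ridge_estimator(grad_loss_function_inline, &x, &y, 0.1, 0.01, 1000, 0.0);
println!("beta after gradient descent: {beta}");
}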

Putting things together

To wrap up our 1D Ridge Regression example, let's see how all the parts fit together into a real Rust crate.

Project layout

Here’s the directory structure for our ridge_1d_fn crate:

crates/ridge_1d_fn/
├── Cargo.toml
└── src
    ├── estimator.rs           # Closed-form solution of the Ridge estimator
    ├── gradient_descent.rs    # Gradient descent solution
    ├── lib.rs                 # Main entry point for the library
    └── loss_functions.rs      # Loss function implementations

All the functions discussed in the previous sections are implemented in estimator.rs, loss_functions.rs, and gradient_descent.rs. You can inspect each of these files below.

Click to view estimator.rs
#![allow(unused)]
fn main() {
/// Computes the one-dimensional Ridge regression estimator using centered data.
///
/// This version centers the input data `x` and `y` before applying the closed-form formula.
///
/// # Arguments
///
/// * `x` - A slice of input features.
/// * `y` - A slice of target values (same length as `x`).
/// * `lambda2` - The regularization parameter.
///
/// # Returns
///
/// * `f64` - The estimated Ridge regression coefficient.
///
/// # Panics
///
/// Panics if `x` and `y` do not have the same length.
// ANCHOR: ridge_estimator
pub fn ridge_estimator(x: &[f64], y: &[f64], lambda2: f64) -> f64 {
    let n: usize = x.len();
    assert_eq!(n, y.len(), "x and y must have the same length");

    let x_mean: f64 = x.iter().sum::<f64>() / n as f64;
    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;

    let num: f64 = x
        .iter()
        .zip(y)
        .map(|(xi, yi)| (xi - x_mean) * (yi - y_mean))
        .sum::<f64>();

    let denom: f64 = x.iter().map(|xi| (xi - x_mean).powi(2)).sum::<f64>() + lambda2 * (n as f64);

    num / denom
}
// ANCHOR_END: ridge_estimator

// ANCHOR: tests
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ridge_estimator() {
        let x: Vec<f64> = vec![1.0, 2.0];
        let y: Vec<f64> = vec![0.1, 0.2];
        let true_beta: f64 = 0.1;
        let lambda2: f64 = 0.0;

        let beta_estimate: f64 = ridge_estimator(&x, &y, lambda2);
        assert!(
            (true_beta - beta_estimate).abs() < 1e-6,
            "Estimate {} not close enough to true solution {}",
            beta_estimate,
            true_beta
        );
    }
}
// ANCHOR_END: tests
}
Click to view gradient_descent.rs
#![allow(unused)]
fn main() {
/// Dot product between two vectors.
///
/// # Arguments
/// * `a` - First input vector
/// * `b` - Second input vector
///
/// # Returns
///
/// The float value of the dot product.
///
/// # Panics
///
/// Panics if `a` and `b` do not have the same length.
// ANCHOR: dot
pub fn dot(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "Input vectors must have the same length");
    a.iter().zip(b.iter()).map(|(xi, yi)| xi * yi).sum()
}

// ANCHOR_END: dot
/// Computes the gradient of the Ridge regression loss function (naive version).
///
/// This implementation first explicitly computes the residuals and then performs
/// a dot product between the residuals and the inputs.
///
/// # Arguments
///
/// * `x` - Slice of input features
/// * `y` - Slice of target outputs
/// * `beta` - Coefficient of the regression model
/// * `lambda2` - L2 regularization strength
///
/// # Returns
///
/// The gradient of the loss with respect to `beta`.
///
/// # Panics
///
/// Panics if `x` and `y` do not have the same length.
// ANCHOR: grad_loss_function_naive
pub fn grad_loss_function_naive(x: &[f64], y: &[f64], beta: f64, lambda2: f64) -> f64 {
    assert_eq!(x.len(), y.len(), "x and y must have the same length");

    let n: usize = x.len();
    let residuals: Vec<f64> = x
        .iter()
        .zip(y.iter())
        .map(|(xi, yi)| yi - beta * xi)
        .collect();
    let residuals_dot_x = dot(&residuals, x);

    -2.0 * residuals_dot_x / (n as f64) + 2.0 * lambda2 * beta
}
// ANCHOR_END: grad_loss_function_naive

/// Computes the gradient of the Ridge regression loss function (inlined version).
///
/// This version fuses the residual and gradient computation into a single pass
/// using iterators, minimizing allocations and improving efficiency.
///
/// # Arguments
///
/// * `x` - Slice of input features
/// * `y` - Slice of target outputs
/// * `beta` - Coefficient of the regression model
/// * `lambda2` - L2 regularization strength
///
/// # Returns
///
/// The gradient of the loss with respect to `beta`.
///
/// # Panics
///
/// Panics if `x` and `y` do not have the same length.
// ANCHOR: grad_loss_function_inline
pub fn grad_loss_function_inline(x: &[f64], y: &[f64], beta: f64, lambda2: f64) -> f64 {
    assert_eq!(x.len(), y.len(), "x and y must have the same length");

    let n: usize = x.len();
    let grad_mse: f64 = x
        .iter()
        .zip(y.iter())
        .map(|(xi, yi)| 2.0 * (yi - beta * xi) * xi)
        .sum::<f64>()
        / (n as f64);

    -grad_mse + 2.0 * lambda2 * beta
}
// ANCHOR_END: grad_loss_function_inline

/// Performs gradient descent to minimize the Ridge regression loss function.
///
/// # Arguments
///
/// * `grad_fn` - A function that computes the gradient of the Ridge loss
/// * `x` - Input features as a slice (`&[f64]`)
/// * `y` - Target values as a slice (`&[f64]`)
/// * `lambda2` - Regularization parameter
/// * `lr` - Learning rate
/// * `n_iters` - Number of gradient descent iterations
/// * `init_beta` - Initial value of the regression coefficient
///
/// # Returns
///
/// The optimized regression coefficient `beta` after `n_iters` updates
// ANCHOR: gradient_descent_estimator
pub fn ridge_estimator(
    grad_fn: impl Fn(&[f64], &[f64], f64, f64) -> f64,
    x: &[f64],
    y: &[f64],
    lambda2: f64,
    lr: f64,
    n_iters: usize,
    init_beta: f64,
) -> f64 {
    let mut beta = init_beta;

    for _ in 0..n_iters {
        let grad = grad_fn(x, y, beta, lambda2);
        beta -= lr * grad;
    }

    beta
}
// ANCHOR_END: gradient_descent_estimator

// ANCHOR: tests
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_grad_naive() {
        let x: Vec<f64> = vec![1.0, 2.0];
        let y: Vec<f64> = vec![0.1, 0.2];
        let beta: f64 = 0.1;
        let lambda2: f64 = 1.0;

        let grad = grad_loss_function_naive(&x, &y, beta, lambda2);
        let expected_grad = 0.2;
        let tol = 1e-6;
        assert!(
            (grad - expected_grad).abs() < tol,
            "Expected {}, got {}",
            expected_grad,
            grad
        );
    }

    #[test]
    fn test_grad_inline() {
        let x: Vec<f64> = vec![1.0, 2.0];
        let y: Vec<f64> = vec![0.1, 0.2];
        let beta: f64 = 0.1;
        let lambda2: f64 = 1.0;

        let grad = grad_loss_function_inline(&x, &y, beta, lambda2);
        let expected_grad = 0.2;
        let tol = 1e-6;
        assert!(
            (grad - expected_grad).abs() < tol,
            "Expected {}, got {}",
            expected_grad,
            grad
        );
    }

    #[test]
    fn test_naive_vs_inline() {
        let x: Vec<f64> = vec![1.0, 2.0];
        let y: Vec<f64> = vec![0.1, 0.2];
        let beta: f64 = 0.1;
        let lambda2: f64 = 1.0;

        let grad1 = grad_loss_function_inline(&x, &y, beta, lambda2);
        let grad2 = grad_loss_function_naive(&x, &y, beta, lambda2);
        assert_eq!(grad1, grad2);
    }
}
// ANCHOR_END: tests
}
Click to view loss_functions.rs
#![allow(unused)]
fn main() {
/// Multiplies a vector by a scalar.
///
/// # Arguments
///
/// * `scalar` - A scalar multiplier
/// * `vector` - A slice of f64 values
///
/// # Returns
///
/// A new vector containing the result of element-wise multiplication
///
/// # Why `&[f64]` instead of `Vec<f64>`?
///
/// We use a slice (`&[f64]`) because:
/// - It's more general: works with both arrays and vectors
/// - It avoids unnecessary ownership
/// - It's idiomatic and Clippy-compliant
// ANCHOR: mul_scalar_vec
pub fn mul_scalar_vec(scalar: f64, vector: &[f64]) -> Vec<f64> {
    vector.iter().map(|x| x * scalar).collect()
}
// ANCHOR_END: mul_scalar_vec

/// Subtracts two vectors element-wise.
///
/// # Arguments
///
/// * `a` - First input slice
/// * `b` - Second input slice
///
/// # Returns
///
/// A new `Vec<f64>` containing the element-wise difference `a[i] - b[i]`.
///
/// # Panics
///
/// Panics if `a` and `b` do not have the same length.
// ANCHOR: subtract_vectors
pub fn subtract_vectors(a: &[f64], b: &[f64]) -> Vec<f64> {
    assert_eq!(a.len(), b.len(), "Input vectors must have the same length");
    a.iter().zip(b.iter()).map(|(x, y)| x - y).collect()
}
// ANCHOR_END: subtract_vectors

/// Computes the loss function for Ridge regression (naive version).
///
/// It implements it in a simple fashion by computing the mean squared error in multiple steps.
///
/// This function calculates the following expression:
///
/// $$
/// \mathcal{L}(\beta) = \frac{1}{n} \sum_i (y_i - \beta x_i)^2 + \lambda \beta^2
/// $$
///
/// where:
/// - `x` and `y` are the input/output observations,
/// - `beta` is the linear coefficient,
/// - `lambda2` is the regularization strength.
///
/// # Arguments
///
/// * `x` - Input features as a slice (`&[f64]`)
/// * `y` - Target values as a slice (`&[f64]`)
/// * `beta` - Coefficient of the regression model
/// * `lambda2` - L2 regularization strength
///
/// # Returns
///
/// The Ridge regression loss value as `f64`.
///
/// # Panics
///
/// Panics if `x` and `y` do not have the same length.
// ANCHOR: loss_function_naive
pub fn loss_function_naive(x: &[f64], y: &[f64], beta: f64, lambda2: f64) -> f64 {
    assert_eq!(x.len(), y.len(), "x and y must have the same length");

    let n: usize = x.len();
    let y_hat: Vec<f64> = mul_scalar_vec(beta, x);
    let residuals: Vec<f64> = subtract_vectors(y, &y_hat);
    let mse: f64 = residuals.iter().map(|x| x * x).sum::<f64>() / (n as f64);
    mse + lambda2 * beta * beta
}
// ANCHOR_END: loss_function_naive

/// Computes the loss function for Ridge regression (inlined version).
///
/// It implements it as a one-liner by computing the mean squared error in a single expression.
///
/// # Arguments
///
/// * `x` - The array of input observations
/// * `y` - The array of output observations
/// * `beta` - The coefficients of the linear regression
/// * `lambda2` - The regularization parameter
///
/// # Returns
///
/// The value of the loss function
// ANCHOR: loss_function_line
pub fn loss_function_inline(x: &[f64], y: &[f64], beta: f64, lambda2: f64) -> f64 {
    let n: usize = y.len();
    let factor = n as f64;
    let mean_squared_error = x
        .iter()
        .zip(y.iter())
        .map(|(xi, yi)| {
            let residual = yi - beta * xi;
            residual * residual
        })
        .sum::<f64>()
        / factor;
    mean_squared_error + lambda2 * beta * beta
}
// ANCHOR_END: loss_function_line

// ANCHOR: tests
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_loss_function_naive() {
        let x: Vec<f64> = vec![1.0, 2.0];
        let y: Vec<f64> = vec![0.1, 0.2];
        let beta: f64 = 0.1;
        let lambda2: f64 = 1.0;

        let val: f64 = loss_function_naive(&x, &y, beta, lambda2);
        assert!(val > 0.0);
    }

    #[test]
    fn test_loss_function_line() {
        let x: Vec<f64> = vec![1.0, 2.0];
        let y: Vec<f64> = vec![0.1, 0.2];
        let beta: f64 = 0.1;
        let lambda2: f64 = 1.0;

        let val: f64 = loss_function_inline(&x, &y, beta, lambda2);
        assert!(val > 0.0);
    }

    #[test]
    fn test_naive_vs_inline() {
        let x: Vec<f64> = vec![1.0, 2.0];
        let y: Vec<f64> = vec![0.1, 0.2];
        let beta: f64 = 0.1;
        let lambda2: f64 = 1.0;

        let val1 = loss_function_naive(&x, &y, beta, lambda2);
        let val2 = loss_function_inline(&x, &y, beta, lambda2);
        assert_eq!(val1, val2);
    }
}
// ANCHOR_END: tests
}
Click to view lib.rs
#![allow(unused)]
fn main() {
// ANCHOR: lib_rs
pub mod estimator;
pub mod gradient_descent;
pub mod loss_functions;

pub use estimator::ridge_estimator;

/// Fits a Ridge regression model.
///
/// # Arguments
///
/// * `x` - Input features (`&[f64]`)
/// * `y` - Target values (`&[f64]`)
/// * `lambda2` - Regularization strength
///
/// # Returns
///
/// The optimized coefficient `beta` as `f64`.
pub fn fit(x: &[f64], y: &[f64], lambda2: f64) -> f64 {
    ridge_estimator(x, y, lambda2)
}

/// Predicts output values using a trained Ridge regression coefficient.
///
/// # Arguments
///
/// * `x` - Input features (`&[f64]`)
/// * `beta` - Trained coefficient
///
/// # Returns
///
/// A `Vec<f64>` with predicted values.
pub fn predict(x: &[f64], beta: f64) -> Vec<f64> {
    x.iter().map(|xi| xi * beta).collect()
}
// ANCHOR_END: lib_rs

// ANCHOR: run_demo
pub fn run_demo() {
    println!("-----------------------------------------------------");
    println!("Running ridge_1d_fn::run_demo");
    let x: Vec<f64> = vec![1.0, 2.0];
    let y: Vec<f64> = vec![0.1, 0.2];
    let lambda2 = 0.001;

    let beta = fit(&x, &y, lambda2);
    let preds = predict(&x, beta);

    println!("Learned beta: {beta}, true solution: 0.1!");
    println!("Predictions: {preds:?}");
    println!("-----------------------------------------------------");
}
// ANCHOR_END: run_demo
}

Note that the layout can become more elaborate by introducing modules and submodules. This will be covered in the next chapter when we implement a structured version of the 1D Ridge regression.

What's lib.rs?

The lib.rs file is the entry point for the crate as a library. This is where we declare which modules (i.e., other .rs files) are exposed to the outside world.

#![allow(unused)]
fn main() {
pub mod estimator;
pub mod gradient_descent;
pub mod loss_functions;

pub use estimator::ridge_estimator;
}

Each line tells Rust:

“There is a file called X.rs that defines a module X. Please include it in the crate.”

By default, items inside a module are private. That’s where pub comes in.

We will dive deeper into lib.rs in the 2.1.5 Exposing API chapter.

Why pub?

If you want to use a function from another module or crate, you must declare it pub (public). For example:

#![allow(unused)]
fn main() {
// In utils.rs
pub fn dot(a: &[f64], b: &[f64]) -> f64 { ... }
}

If dot is not marked as pub, you can’t use it outside utils.rs, even from optimizer.rs.

Importing between modules

Rust requires explicit imports between modules. For example, let's say we want to use the dot function from gradient_descent.rs. We can import it as follows:

#![allow(unused)]
fn main() {
use crate::gradient_descent::dot;
}

Here, crate refers to the root of this library crate, i.e., lib.rs.

Example of usage

Now let's see how you could use the library from a binary crate:

#![allow(unused)]
fn main() {
use ridge_1d_fn::ridge_estimator;

let x: Vec<f64> = vec![1.0, 2.0];
let y: Vec<f64> = vec![0.1, 0.2];
let lambda2 = 0.001;

let beta = ridge_estimator(&x, &y, lambda2);
}

Exposing a clean API

Until now, we've manually chained together the loss, gradient, and optimization steps. This is great for learning, but in real projects, we often want a simplified and reusable API.

Rust gives us a clean way to do this by leveraging the lib.rs file as the public interface to our crate.

lib.rs as a public API

In your crate, lib.rs is responsible for organizing and exposing the components we want users to interact with.

We can re-export key functions and define top-level utilities like fit and predict. The complete lib.rs file now looks like this:

#![allow(unused)]
fn main() {
pub mod estimator;
pub mod gradient_descent;
pub mod loss_functions;

pub use estimator::ridge_estimator;

/// Fits a Ridge regression model.
///
/// # Arguments
///
/// * `x` - Input features (`&[f64]`)
/// * `y` - Target values (`&[f64]`)
/// * `lambda2` - Regularization strength
///
/// # Returns
///
/// The optimized coefficient `beta` as `f64`.
pub fn fit(x: &[f64], y: &[f64], lambda2: f64) -> f64 {
    ridge_estimator(x, y, lambda2)
}

/// Predicts output values using a trained Ridge regression coefficient.
///
/// # Arguments
///
/// * `x` - Input features (`&[f64]`)
/// * `beta` - Trained coefficient
///
/// # Returns
///
/// A `Vec<f64>` with predicted values.
pub fn predict(x: &[f64], beta: f64) -> Vec<f64> {
    x.iter().map(|xi| xi * beta).collect()
}
}

Everything declared pub is available to the user. For simplicity, we decided to only expose the closed-form Ridge estimator.

Example of usage

You can update your binary entry point to try out the public API.

use ridge_1d_fn::{fit, predict};

fn main() {
    let x = vec![1.0, 2.0, 3.0];
    let y = vec![2.0, 4.0, 6.0];

    let beta = fit(&x, &y, 0.1);
    let preds = predict(&x, beta);

    println!("Learned beta: {}", beta);
    println!("Predictions: {:?}", preds);
}

Structured: introduction

This section focuses on implementing the 1D Ridge problem using functions, structs, traits, and the Rust standard library only. It's divided into 3 subsections:

  1. Closed-form solution: Implements the closed-form solution of the Ridge optimization problem using a struct to define a RidgeEstimator type. It shows how to implement a constructor together with fit and predict functions.
  2. Gradient descent: Solves the Ridge problem using gradient descent, also relying on a struct to define a RidgeGradientDescent type.
  3. Trait Ridge model: Explains how to define a trait RidgeModel, which describes the shared behavior of any Ridge estimator like our RidgeEstimator and RidgeGradientDescent.

What we're building here

The aim of this chapter is to build a small crate with the following layout:

crates/ridge_1d_struct/
├── Cargo.toml
└── src
    ├── regressor.rs    # Closed-form solution of the Ridge estimator
    └── lib.rs          # Main entry point for the library

It is made of a single module regressor.rs which implements both the closed-form Ridge estimator and the gradient descent-based estimator using structs, respectively called RidgeEstimator and RidgeGradientDescent.

These types are exposed in the library entry point lib.rs. We end up with the following user interfaces.

Closed-form estimator:

#![allow(unused)]
fn main() {
use ridge_1d_struct::RidgeEstimator;

let mut model: RidgeEstimator = RidgeEstimator::new(0.0);

let x: Vec<f64> = vec![1.0, 2.0];
let y: Vec<f64> = vec![0.1, 0.2];
let lambda2 = 0.001;

model.fit(&x, &y, lambda2);
let preds = model.predict(&x);
}

Gradient descent-based estimator:

#![allow(unused)]
fn main() {
use ridge_1d_struct::RidgeGradientDescent;

let mut model: RidgeGradientDescent = RidgeGradientDescent::new(1000, 1e-2, 0.0);

let x: Vec<f64> = vec![1.0, 2.0];
let y: Vec<f64> = vec![0.1, 0.2];
let lambda2 = 0.001;

model.fit(&x, &y, lambda2);
let preds = model.predict(&x);
}

What's next

Up to this stage, we implemented everything using the f64 precision for all our variables. In the next section, we will see how to make our code independent of the floating-point types by leveraging generics.

Closed-form solution

We now present a structured implementation of the 1D Ridge estimator using a dedicated RidgeEstimator struct. We implement the same methods, i.e.,

  • fit: the method to compute the optimal $\beta$ from the data and the regularization parameter $\lambda$,
  • predict: the method to compute predictions from new data,

but rely on Rust's struct and impl to define a new type. We also add a method new, a constructor to initialize the estimator with an initial value of $\beta$.

Struct definition

This simple struct stores the estimated coefficient as a field.

#![allow(unused)]
fn main() {
pub struct RidgeEstimator {
    beta: f64,
}
}

Constructor and methods

Once the struct is defined, we can implement the constructor new, and the methods fit and predict.

#![allow(unused)]
fn main() {
impl RidgeEstimator {
    pub fn new(init_beta: f64) -> Self {
        Self { beta: init_beta }
    }

    fn fit(&mut self, x: &[f64], y: &[f64], lambda2: f64) {
        let n: usize = x.len();
        assert_eq!(n, y.len(), "x and y must have the same length");

        let x_mean: f64 = x.iter().sum::<f64>() / n as f64;
        let y_mean: f64 = y.iter().sum::<f64>() / n as f64;

        let num: f64 = x
            .iter()
            .zip(y)
            .map(|(xi, yi)| (xi - x_mean) * (yi - y_mean))
            .sum::<f64>();

        let denom: f64 =
            x.iter().map(|xi| (xi - x_mean).powi(2)).sum::<f64>() + lambda2 * (n as f64);

        self.beta = num / denom;
    }

    fn predict(&self, x: &[f64]) -> Vec<f64> {
        x.iter().map(|xi| self.beta * xi).collect()
    }
}
}

Note that we can decompose the implementation into as many blocks as we want:

#![allow(unused)]
fn main() {
impl RidgeEstimator {
    pub fn new(init_beta: f64) -> Self {
        Self { beta: init_beta }
    }
}

impl RidgeEstimator {
    fn fit(&mut self, x: &[f64], y: &[f64], lambda2: f64) {
        ...
    }
}

impl RidgeEstimator {
    fn predict(&self, x: &[f64]) -> Vec<f64> {
        ...
    }
}
}

This can be useful when dealing with complex methods.

Example of usage

Here is how we can use our new Ridge estimator:

fn main() {
    let x = vec![1.0, 2.0, 3.0, 4.0];
    let y = vec![2.1, 4.1, 6.2, 8.3];
    let lambda = 0.1;

    let mut model = RidgeEstimator::new(0.0);
    model.fit(&x, &y, lambda);

    let predictions = model.predict(&x);
    println!("Predictions: {:?}", predictions);
}

Gradient descent

As another illustration of struct and impl, let's tackle the gradient descent method for Ridge regression again. We use the following struct:

#![allow(unused)]
fn main() {
pub struct RidgeGradientDescent {
    beta: f64,
    n_iters: usize,
    lr: f64,
}
}

This struct stores the current coefficient $\beta$, the number of iterations to run, and the learning rate. We can subsequently implement the constructor and all the methods required to perform gradient descent.

#![allow(unused)]
fn main() {
impl RidgeGradientDescent {
    pub fn new(n_iters: usize, lr: f64, init_beta: f64) -> Self {
        Self {
            beta: init_beta,
            n_iters,
            lr,
        }
    }

    fn grad_function(&self, x: &[f64], y: &[f64], lambda2: f64) -> f64 {
        assert_eq!(x.len(), y.len(), "x and y must have the same length");
        let n: usize = x.len();
        let grad_mse: f64 = x
            .iter()
            .zip(y.iter())
            .map(|(xi, yi)| {
                let error = yi - self.beta * xi;
                2.0 * error * xi
            })
            .sum::<f64>()
            / (n as f64);

        -grad_mse + 2.0 * lambda2 * self.beta
    }

    fn fit(&mut self, x: &[f64], y: &[f64], lambda2: f64) {
        for _ in 0..self.n_iters {
            let grad = self.grad_function(x, y, lambda2);
            self.beta -= self.lr * grad;
        }
    }

    fn predict(&self, x: &[f64]) -> Vec<f64> {
        x.iter().map(|xi| self.beta * xi).collect()
    }
}
}

Example of usage

Here is how we can use our new Ridge estimator:

fn main() {
    let x = vec![1.0, 2.0, 3.0, 4.0];
    let y = vec![2.1, 4.1, 6.2, 8.3];

    let mut model = RidgeGradientDescent::new(1000, 0.01, 0.0);
    model.fit(&x, &y, 0.1);

    let predictions = model.predict(&x);
    println!("Predictions: {:?}", predictions);
}

Base Ridge model using traits

We implemented two different Ridge estimators in the previous sections. While these implementations solve the same problem, their fit logic is different. To unify their interface and promote code reuse, we can leverage Rust's trait mechanism.

We define a common trait RidgeModel, which describes the shared behavior of any Ridge estimator:

#![allow(unused)]
fn main() {
pub trait RidgeModel {
    fn fit(&mut self, x: &[f64], y: &[f64], lambda2: f64);
    fn predict(&self, x: &[f64]) -> Vec<f64>;
}
}

Any type that implements this trait must provide a fit method to train the model and a predict method to make predictions.

Both implementations use the same logic to produce predictions from a scalar $\beta$. We can move this logic to a shared helper function:

#![allow(unused)]
fn main() {
fn predict_from_beta(beta: f64, x: &[f64]) -> Vec<f64> {
    x.iter().map(|xi| beta * xi).collect()
}
}

Trait implementation: gradient descent method

Recall that our RidgeGradientDescent type is defined as follows:

#![allow(unused)]
fn main() {
pub struct RidgeGradientDescent {
    beta: f64,
    n_iters: usize,
    lr: f64,
}
}

We still need to define the constructor and the gradient function:

#![allow(unused)]
fn main() {
impl RidgeGradientDescent {
    pub fn new(n_iters: usize, lr: f64, init_beta: f64) -> Self {
        Self {
            beta: init_beta,
            n_iters,
            lr,
        }
    }

    fn grad_function(&self, x: &[f64], y: &[f64], lambda2: f64) -> f64 {
        assert_eq!(x.len(), y.len(), "x and y must have the same length");
        let n: usize = x.len();
        let grad_mse: f64 = x
            .iter()
            .zip(y.iter())
            .map(|(xi, yi)| {
                let error = yi - self.beta * xi;
                2.0 * error * xi
            })
            .sum::<f64>()
            / (n as f64);

        -grad_mse + 2.0 * lambda2 * self.beta
    }
}
}

Once this is done, we need to implement the required methods to be a RidgeModel:

#![allow(unused)]
fn main() {
impl RidgeModel for RidgeGradientDescent {
    fn fit(&mut self, x: &[f64], y: &[f64], lambda2: f64) {
        for _ in 0..self.n_iters {
            let grad = self.grad_function(x, y, lambda2);
            self.beta -= self.lr * grad;
        }
    }

    fn predict(&self, x: &[f64]) -> Vec<f64> {
        predict_from_beta(self.beta, x)
    }
}
}

Trait implementation: closed-form estimator

We do the same for the RidgeEstimator that uses the analytical formula:

#![allow(unused)]
fn main() {
pub struct RidgeEstimator {
    beta: f64,
}

impl RidgeEstimator {
    pub fn new(init_beta: f64) -> Self {
        Self { beta: init_beta }
    }
}

impl RidgeModel for RidgeEstimator {
    fn fit(&mut self, x: &[f64], y: &[f64], lambda2: f64) {
        let n: usize = x.len();
        assert_eq!(n, y.len(), "x and y must have the same length");

        let x_mean: f64 = x.iter().sum::<f64>() / n as f64;
        let y_mean: f64 = y.iter().sum::<f64>() / n as f64;

        let num: f64 = x
            .iter()
            .zip(y)
            .map(|(xi, yi)| (xi - x_mean) * (yi - y_mean))
            .sum::<f64>();

        let denom: f64 =
            x.iter().map(|xi| (xi - x_mean).powi(2)).sum::<f64>() + lambda2 * (n as f64);

        self.beta = num / denom;
    }

    fn predict(&self, x: &[f64]) -> Vec<f64> {
        predict_from_beta(self.beta, x)
    }
}
}

That's it! The usage remains the same, but we slightly refactored our code.
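To see what the trait buys us, here is a small sketch (the helper fit_and_predict is ours, not part of the crate) of a function that is generic over any RidgeModel, so it works with both estimators unchanged:

#![allow(unused)]
fn main() {
// Works with any type implementing `RidgeModel`, closed-form or iterative.
fn fit_and_predict<M: RidgeModel>(model: &mut M, x: &[f64], y: &[f64], lambda2: f64) -> Vec<f64> {
    model.fit(x, y, lambda2);
    model.predict(x)
}

let x = vec![1.0, 2.0, 3.0, 4.0];
let y = vec![2.1, 4.1, 6.2, 8.3];

let mut closed_form = RidgeEstimator::new(0.0);
let mut grad_desc = RidgeGradientDescent::new(1000, 0.01, 0.0);

// The same helper accepts both estimators.
println!("{:?}", fit_and_predict(&mut closed_form, &x, &y, 0.1));
println!("{:?}", fit_and_predict(&mut grad_desc, &x, &y, 0.1));
}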

Generics: introduction

The aim of this section is to generalize our estimators so they work with any numeric type, not just f64. Rust makes this possible through generics and trait bounds. It's divided into 2 subsections:

  1. Generics & trait bounds: Introduces generics and trait bounds. The floating-point type f64 is replaced by a generic type F that can either be f32 or f64. In Rust, generic types have no behavior by default, and we must tell the compiler which traits F should implement.
  2. Closed-form solution: Explains how to implement the closed-form solution with generics and traits.

In the final section, we explore how to use the external crate ndarray for linear algebra, and how to incorporate additional Rust features such as optional values, pattern matching, and error handling.

What we're building here

The aim of this chapter is to build a small crate with the following layout:

crates/ridge_1d_generic/
├── Cargo.toml
└── src
    ├── regressor.rs    # Closed-form solution of the Ridge estimator
    └── lib.rs          # Main entry point for the library

The module regressor.rs implements the closed-form Ridge estimator using a generic type bounded by the Float trait. As usual, the resulting regressor, here called GenRidgeEstimator, is exposed through the library entry point lib.rs.

In contrast to the previous implementations, this can be used with f32 or f64 floating-point types.

Example of usage:

#![allow(unused)]
fn main() {
use ridge_1d_generic::GenRidgeEstimator;

let mut model: GenRidgeEstimator<f32> = GenRidgeEstimator::new(1.0);

let x: Vec<f32> = vec![1.0, 2.0];
let y: Vec<f32> = vec![0.1, 0.2];
let lambda2 = 0.001;

model.fit(&x, &y, lambda2);
let preds: Vec<f32> = model.predict(&x);
}

What are generics?

Generics let you write code that works with many types, not just one.

Instead of writing:

#![allow(unused)]
fn main() {
struct RidgeEstimator {
    beta: f64,
}
}

You can write:

#![allow(unused)]
fn main() {
use num_traits::Float;

struct RidgeEstimator<F> {
    beta: F,
}
}

Here, F is a type parameter — it could be f32, f64, or another type. In Rust, generic types have no behavior by default.

Bug

#![allow(unused)]
fn main() {
fn sum<F>(xs: &[F]) -> F {
    xs.iter().sum() // This will not compile
}
}

The compiler gives an error: "F might not implement Sum, so I don’t know how to .sum() over it."

Trait bounds

To fix that, we must tell the compiler which traits F should implement.

For example:

#![allow(unused)]
fn main() {
use num_traits::Float;
use std::iter::Sum;

impl<F: Float + Sum> RidgeModel<F> for GenRidgeEstimator<F> {
    ...
}
}

This means:

  • F must implement Float (it must behave like a floating point number: support powi, abs, etc.)
  • F must implement Sum (so we can sum an iterator of F)

This allows code like:

#![allow(unused)]
fn main() {
let mean = xs.iter().copied().sum::<F>() / F::from(xs.len()).unwrap();
}

Using generic bounds allows the estimator to work with f32, f64, or any numeric type implementing Float. The compiler can generate specialized code for each concrete type.
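As a quick check that a single definition serves several float types, here is a small standalone sketch (the mean function is ours, not part of the crate):

use num_traits::Float;
use std::iter::Sum;

/// Generic mean: the compiler generates a specialized version for each
/// concrete float type used at the call sites.
fn mean<F: Float + Sum>(xs: &[F]) -> F {
    xs.iter().copied().sum::<F>() / F::from(xs.len()).unwrap()
}

fn main() {
    let a: Vec<f32> = vec![1.0, 2.0, 3.0];
    let b: Vec<f64> = vec![1.0, 2.0, 3.0];
    println!("{} {}", mean(&a), mean(&b)); // prints "2 2"
}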

Generic Ridge estimator

In this section, we implement a Ridge estimator that works with f32 and f64 using generics and trait bounds.

To make this work, we need to import the following:

#![allow(unused)]
fn main() {
use num_traits::Float;
use std::iter::Sum;
}

Ridge model trait

We start by defining a trait, RidgeModel, that describes the core behavior expected from any Ridge regression model. We tell the compiler that F must implement the traits Float and Sum.

#![allow(unused)]
fn main() {
pub trait RidgeModel<F: Float + Sum> {
    /// Fits the model to the given data using Ridge regression.
    fn fit(&mut self, x: &[F], y: &[F], lambda2: F);

    /// Predicts output values for a slice of new input features.
    fn predict(&self, x: &[F]) -> Vec<F>;
}
}

Ridge estimator structure

We next define a Ridge structure as usual but using our generic type F. The model only stores the Ridge coefficient beta.

#![allow(unused)]
fn main() {
pub struct GenRidgeEstimator<F: Float + Sum> {
    pub beta: F,
}
}

We implement the constructor as usual as well.

#![allow(unused)]
fn main() {
impl<F: Float + Sum> GenRidgeEstimator<F> {
    /// Creates a new estimator with the given initial beta coefficient.
    pub fn new(init_beta: F) -> Self {
        Self { beta: init_beta }
    }
}
}

Fit and predict methods

We can finally implement the trait RidgeModel for our GenRidgeEstimator.

#![allow(unused)]
fn main() {
impl<F: Float + Sum> RidgeModel<F> for GenRidgeEstimator<F> {
    /// Fits the Ridge regression model to 1D data using closed-form solution.
    ///
    /// This method computes the regression coefficient `beta` by minimizing
    /// the Ridge-regularized least squares loss.
    ///
    /// # Arguments
    /// - `x`: Input features.
    /// - `y`: Target values.
    /// - `lambda2`: The regularization parameter (λ²).
    fn fit(&mut self, x: &[F], y: &[F], lambda2: F) {
        let n: usize = x.len();
        let n_f: F = F::from(n).unwrap();
        assert_eq!(x.len(), y.len(), "x and y must have the same length");

        let x_mean: F = x.iter().copied().sum::<F>() / n_f;
        let y_mean: F = y.iter().copied().sum::<F>() / n_f;

        let num: F = x
            .iter()
            .zip(y.iter())
            .map(|(xi, yi)| (*xi - x_mean) * (*yi - y_mean))
            .sum::<F>();

        let denom: F = x.iter().map(|xi| (*xi - x_mean).powi(2)).sum::<F>() + lambda2 * n_f;

        self.beta = num / denom;
    }

    /// Applies the trained model to input features to generate predictions.
    ///
    /// # Arguments
    /// - `x`: Input features to predict from.
    ///
    /// # Returns
    /// A vector of predicted values, one for each input in `x`.
    fn predict(&self, x: &[F]) -> Vec<F> {
        x.iter().map(|xi| *xi * self.beta).collect()
    }
}
}

Notice that the trait bounds <F: Float + Sum> are declared right after the impl keyword, or after the name of a trait or struct when defining one.

Why do we need the Sum trait bound

Without Sum, the compiler does not allow .sum() on iterators of F. Try removing Sum from the bound:

#![allow(unused)]
fn main() {
impl<F: Float> RidgeModel<F> for GenRidgeEstimator<F>
}

And keep a call to .sum(). The compiler should complain:

error[E0599]: the method `sum` exists for iterator `std::slice::Iter<'_, F>`,
              but its trait bounds were not satisfied

The copied() method

The copied() method in .iter().copied().sum::<F>() is necessary because x.iter() yields references &F, while sum::<F>() expects owned values of type F.

Since the Float trait already requires Copy, copied() can turn each &F into an F without adding any extra trait bound. We could have used cloned() instead (every Copy type is also Clone), but copied() makes the cheap bitwise copy explicit.

Note that in the predict function, we don't need copied() because we manually dereference each item, *xi, inside the .map(). It would also have been possible to use copied() there and adjust the mapping closure accordingly.

The unwrap() method

The unwrap() call appears in F::from(n).unwrap(). The length n of the slice x is of type usize, and we need n_f to be of type F so that we can perform operations like division with other F-typed values.

The conversion is done using F::from(n), which returns an Option<F>, not a plain F. By calling unwrap(), we assume the conversion always succeeds (and panic otherwise). Since n comes from x.len(), it is easily representable as an f32 or f64, so unwrapping is safe here.

Note that we could have handled the error explicitly:

#![allow(unused)]
fn main() {
let n_f: F = F::from(n).expect("Length too large to convert to float");

}

Using ndarray, Option, and error handling

This section introduces several important features of the language:

  • Using the ndarray crate for numerical arrays
  • Representing optional values with Option, Some, and None
  • Using pattern matching with match
  • Handling errors using Box<dyn Error> and .into()
  • Automatically deriving common trait implementations using #[derive(...)]

Motivation

In previous sections, we worked with Vec<f64> and returned plain values. In practice, we might need:

  • Efficient linear algebra tools, provided by external crates such as ndarray and nalgebra
  • A way to represent "fitted" or "not fitted" states, using Option<f64>
  • A way to return errors when something goes wrong, using Result<_, _>
  • Automatically implementing traits like Debug, Clone, and Default to simplify testing, debugging, and construction

We combine these in the implementation of the analytical RidgeEstimator. You can have a look at the full code below before we go through the main features step by step.

The full code : regressor.rs
#![allow(unused)]
fn main() {
use ndarray::Array1;

/// A Ridge regression estimator using `ndarray` for vectorized operations.
///
/// This version supports fitting and predicting using `Array1<f64>` arrays.
/// The coefficient `beta` is stored as an `Option<f64>`, allowing the model
/// to represent both fitted and unfitted states.
// ANCHOR: struct
#[derive(Debug, Clone, Default)]
pub struct RidgeEstimator {
    pub beta: Option<f64>,
}
// ANCHOR_END: struct

// ANCHOR: ridge_estimator_impl_new_fit
impl RidgeEstimator {
    /// Creates a new, unfitted Ridge estimator.
    ///
    /// # Returns
    /// A `RidgeEstimator` with `beta` set to `None`.
    pub fn new() -> Self {
        Self { beta: None }
    }

    /// Fits the Ridge regression model using 1D input and output arrays.
    ///
    /// This function computes the coefficient `beta` using the closed-form
    /// solution with L2 regularization.
    ///
    /// # Arguments
    /// - `x`: Input features as a 1D `Array1<f64>`.
    /// - `y`: Target values as a 1D `Array1<f64>`.
    /// - `lambda2`: The regularization strength (λ²).
    pub fn fit(&mut self, x: &Array1<f64>, y: &Array1<f64>, lambda2: f64) {
        let n: usize = x.len();
        assert!(n > 0);
        assert_eq!(x.len(), y.len(), "x and y must have the same length");

        // mean returns None if the array is empty, so we need to unwrap it
        let x_mean: f64 = x.mean().unwrap();
        let y_mean: f64 = y.mean().unwrap();

        let num: f64 = (x - x_mean).dot(&(y - y_mean));
        let denom: f64 = (x - x_mean).mapv(|z| z.powi(2)).sum() + lambda2 * (n as f64);

        self.beta = Some(num / denom);
    }
}
// ANCHOR_END: ridge_estimator_impl_new_fit

// ANCHOR: ridge_estimator_impl_predict
impl RidgeEstimator {
    /// Predicts target values given input features.
    ///
    /// # Arguments
    /// - `x`: Input features as a 1D array.
    ///
    /// # Returns
    /// A `Result` containing the predicted values, or an error if the model
    /// has not been fitted.
    pub fn predict(&self, x: &Array1<f64>) -> Result<Array1<f64>, String> {
        match &self.beta {
            Some(beta) => Ok(*beta * x),
            None => Err("Model not fitted".to_string()),
        }
    }
}
// ANCHOR_END: ridge_estimator_impl_predict

// ANCHOR: tests
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_ridge_estimator_constructor() {
        let model = RidgeEstimator::new();
        assert_eq!(model.beta, None, "beta is expected to be None");
    }

    #[test]
    fn test_unfitted_estimator() {
        let model = RidgeEstimator::new();
        let x: Array1<f64> = array![1.0, 2.0];
        let result: Result<Array1<f64>, String> = model.predict(&x);

        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Model not fitted");
    }

    #[test]
    fn test_ridge_estimator_solution() {
        let x: Array1<f64> = array![1.0, 2.0];
        let y: Array1<f64> = array![0.1, 0.2];
        let true_beta: f64 = 0.1;
        let lambda2: f64 = 0.0;

        let mut model = RidgeEstimator::new();
        model.fit(&x, &y, lambda2);

        assert!(model.beta.is_some(), "beta is expected to be Some(f64)");

        assert!(
            (true_beta - model.beta.unwrap()).abs() < 1e-6,
            "Estimate {} not close enough to true solution {}",
            true_beta,
            model.beta.unwrap()
        );
    }
}
// ANCHOR_END: tests
}

What we're building here

The aim of this chapter is to build a small crate with the following layout:

crates/ridge_1d_ndarray/
├── Cargo.toml
└── src
    ├── regressor.rs    # Closed-form solution of the Ridge estimator
    └── lib.rs          # Main entry point for the library

Again, the module regressor.rs implements a RidgeEstimator type. We end up with the following user interface:

#![allow(unused)]
fn main() {
use ndarray::array;
use regressor::RidgeEstimator;

let mut model = RidgeEstimator::new();

let x = array![1.0, 2.0];
let y = array![0.1, 0.2];
let lambda2 = 0.001;

model.fit(&x, &y, lambda2);
let preds = model.predict(&x);

match model.beta {
    Some(beta) => println!("Learned beta: {beta}, true solution: 0.1!"),
    None => println!("Model not fitted!"),
}
}

Optional model state and #[derive(...)] for common traits

In addition to using ndarray instead of Vec<f64>, we also slightly modify the model struct.

To make our RidgeEstimator struct more ergonomic, we derive a few useful traits: Debug, Clone, and Default:

  • The Debug trait allows us to print the struct for inspection using println!("{:?}", ...), which is helpful during development.
  • Clone lets us duplicate the struct, which is often needed in data processing pipelines.
  • Default enables us to create a default value using RidgeEstimator::default(); the derived implementation sets beta to None, which matches what our new() constructor returns.

#![allow(unused)]
fn main() {
#[derive(Debug, Clone, Default)]
pub struct RidgeEstimator {
    pub beta: Option<f64>,
}
}

The line

#![allow(unused)]
fn main() {
beta: Option<f64>
}

means beta can be either:

  • Some(value): if the model is trained
  • None: if the model has not been fitted yet

This way, we can explicitly model the fact that the estimator may not be fitted yet (i.e., no coefficients computed). When we initialize the model, we set beta to None as follows:

#![allow(unused)]
fn main() {
impl RidgeEstimator {
    /// Creates a new, unfitted Ridge estimator.
    ///
    /// # Returns
    /// A `RidgeEstimator` with `beta` set to `None`.
    pub fn new() -> Self {
        Self { beta: None }
    }
}
}

Switching to ndarray

Rust’s standard library does not include built-in numerical computing tools. The ndarray crate provides efficient n-dimensional arrays and vectorized operations.

This example uses:

  • Array1<f64> for 1D arrays
  • .mean(), .dot(), and .mapv() for basic mathematical operations
  • Broadcasting (x * beta) for scalar–array multiplication
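
Here is a tiny standalone illustration of these operations (not part of the crate):

#![allow(unused)]
fn main() {
use ndarray::array;

let x = array![1.0, 2.0, 3.0];
let x_mean = x.mean().unwrap();      // 2.0 (None only for an empty array)
let squares = x.mapv(|v| v.powi(2)); // [1.0, 4.0, 9.0]
let dot = x.dot(&x);                 // 14.0
let scaled = &x * 0.5;               // scalar–array multiplication
println!("{x_mean} {squares} {dot} {scaled}");
}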

Fit function

The impl of the fit function is shown below. As you can see, we essentially replaced Vec<f64> with Array1<f64> here and there. This allows us to rely on .mean(), .dot(), or .mapv() to perform basic linear algebra. Since self.beta is defined as an Option<f64>, we store Some(num / denom), from which Rust infers the contained type f64.

#![allow(unused)]
fn main() {
impl RidgeEstimator {
    /// Creates a new, unfitted Ridge estimator.
    ///
    /// # Returns
    /// A `RidgeEstimator` with `beta` set to `None`.
    pub fn new() -> Self {
        Self { beta: None }
    }

    /// Fits the Ridge regression model using 1D input and output arrays.
    ///
    /// This function computes the coefficient `beta` using the closed-form
    /// solution with L2 regularization.
    ///
    /// # Arguments
    /// - `x`: Input features as a 1D `Array1<f64>`.
    /// - `y`: Target values as a 1D `Array1<f64>`.
    /// - `lambda2`: The regularization strength (λ²).
    pub fn fit(&mut self, x: &Array1<f64>, y: &Array1<f64>, lambda2: f64) {
        let n: usize = x.len();
        assert!(n > 0);
        assert_eq!(x.len(), y.len(), "x and y must have the same length");

        // mean returns None if the array is empty, so we need to unwrap it
        let x_mean: f64 = x.mean().unwrap();
        let y_mean: f64 = y.mean().unwrap();

        let num: f64 = (x - x_mean).dot(&(y - y_mean));
        let denom: f64 = (x - x_mean).mapv(|z| z.powi(2)).sum() + lambda2 * (n as f64);

        self.beta = Some(num / denom);
    }
}
}

Note

ndarray can also handle higher-dimensional vectors and matrices with types like Array2 for 2D arrays. This makes it a powerful choice for implementing linear algebra operations and machine learning models in Rust, where you may generalize from 1D to multi-dimensional data.

Pattern matching

In Rust, the match keyword is used to compare a value against a set of patterns and execute code based on which pattern matches. This is especially useful with enums like Option.

#![allow(unused)]
fn main() {
let maybe_number: Option<i32> = Some(42);

match maybe_number {
    Some(n) => println!("The number is: {}", n),
    None => println!("No number available."),
}
}

This pattern ensures we safely handle both the presence and absence of a value.
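
When only the Some case matters, the shorter if let form can also be used (a brief aside; the crate itself sticks to match):

#![allow(unused)]
fn main() {
let maybe_number: Option<i32> = Some(42);

if let Some(n) = maybe_number {
    println!("The number is: {n}");
}
}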

Predict function

We use the exact same technique in our model to check whether it has been fitted. Since beta is of type Option<f64>, we can match on its value to determine whether a prediction can be made:

#![allow(unused)]
fn main() {
match self.beta {
    Some(beta) => Ok(x * beta),
    None => Err("Model not fitted".to_string()),
}
}

The full function takes this form:

#![allow(unused)]
fn main() {
impl RidgeEstimator {
    /// Predicts target values given input features.
    ///
    /// # Arguments
    /// - `x`: Input features as a 1D array.
    ///
    /// # Returns
    /// A `Result` containing the predicted values, or an error if the model
    /// has not been fitted.
    pub fn predict(&self, x: &Array1<f64>) -> Result<Array1<f64>, String> {
        match &self.beta {
            Some(beta) => Ok(*beta * x),
            None => Err("Model not fitted".to_string()),
        }
    }
}
}

Here, we also decide to explicitly return an error if the model has not been fitted. To do this in a type-safe way, we use Rust’s Result enum, which is commonly used for functions that may fail. A Result is either Ok(value) (indicating success) or Err(error) (indicating failure). More details about error handling are given in the next section.

Error handling with Result

We decided to return an error if the model hasn't been fitted yet, using pattern matching:

#![allow(unused)]
fn main() {
match self.beta {
    Some(beta) => Ok(x * beta),
    None => Err("Model not fitted".to_string()),
}
}

This makes sense because our predict function returns a Result<Array1<f64>, String>. We use the to_string() method to convert the string literal (&str) into an owned String, as required by the return type. In practice, the user can use this function as follows:

#![allow(unused)]
fn main() {
let y_pred = model.predict(&x).unwrap();
}

which will panic if the model is not fitted, or

let y_pred = model.predict(&x).expect("Model not fitted yet");

to add a custom panic message. Both methods make the program panic if the Result is an Err. Another strategy is to handle the Result with a match as well, i.e.,

#![allow(unused)]
fn main() {
match model.predict(&x) {
    Ok(y_pred) => println!("Predicted values: {:?}", y_pred),
    Err(e) => eprintln!("Prediction failed: {}", e),
}
}

In summary, we have two matches that serve different roles:

  • Internal match: is beta available?
  • External match: did predict work?

We could also handle other kinds of errors, such as a dimensionality mismatch. To do so, we can implement our own error types.

More advanced error handling

We could have gone even further by defining a custom ModelError type as follows.

#![allow(unused)]
fn main() {
use thiserror::Error;

#[derive(Debug, Error)]
pub enum ModelError {
    #[error("Model is not fitted yet")]
    NotFitted,
    #[error("Dimension mismatch")]
    DimensionMismatch,
}
}

This approach uses the thiserror crate to simplify the implementation of the standard Error trait.

By deriving #[derive(Debug, Error)] and annotating each variant with #[error("...")], we define the error messages right away.
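
As a small illustration (assuming thiserror is listed in Cargo.toml and ModelError is in scope), the #[error("...")] attribute is what provides the Display implementation used when the error is printed:

#![allow(unused)]
fn main() {
let e = ModelError::NotFitted;
println!("{e}");   // Display (from #[error]): Model is not fitted yet
println!("{e:?}"); // Debug: NotFitted
}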

For a multivariate model, where beta would be stored as an Option<Array1<f64>> and predict would handle a single sample x, the function could be rewritten as:

#![allow(unused)]
fn main() {
pub fn predict(&self, x: &Array1<f64>) -> Result<f64, ModelError> {
    match &self.beta {
        Some(beta) => {
            if beta.len() != x.len() {
                return Err(ModelError::DimensionMismatch);
            }
            Ok(beta.dot(x))
        }
        None => Err(ModelError::NotFitted),
    }
}
}

and could be used as follows:

#![allow(unused)]
fn main() {
match model.predict(&x) {
    Ok(y_pred) => println!("Prediction: {}", y_pred),
    Err(ModelError::NotFitted) => eprintln!("Model is not fitted yet."),
    Err(ModelError::DimensionMismatch) => eprintln!("Input dimension doesn't match."),
}
}

Adding tests

Tests can be included in the same file as the code using the #[cfg(test)] module. Each test function is annotated with #[test]. Inside a test, you can use assert_eq!, assert!, or similar macros to validate expected behavior.

The full test module can be seen below. We go through each of the tests in the rest of this section.

Full test module
#![allow(unused)]
fn main() {
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_ridge_estimator_constructor() {
        let model = RidgeEstimator::new();
        assert_eq!(model.beta, None, "beta is expected to be None");
    }

    #[test]
    fn test_unfitted_estimator() {
        let model = RidgeEstimator::new();
        let x: Array1<f64> = array![1.0, 2.0];
        let result: Result<Array1<f64>, String> = model.predict(&x);

        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Model not fitted");
    }

    #[test]
    fn test_ridge_estimator_solution() {
        let x: Array1<f64> = array![1.0, 2.0];
        let y: Array1<f64> = array![0.1, 0.2];
        let true_beta: f64 = 0.1;
        let lambda2: f64 = 0.0;

        let mut model = RidgeEstimator::new();
        model.fit(&x, &y, lambda2);

        assert!(model.beta.is_some(), "beta is expected to be Some(f64)");

        assert!(
            (true_beta - model.beta.unwrap()).abs() < 1e-6,
            "Estimate {} not close enough to true solution {}",
            true_beta,
            model.beta.unwrap()
        );
    }
}
}

Recall that tests can be executed by running cargo test.

Testing the constructor

As a first simple test, we check that beta of a new RidgeEstimator is None.

#![allow(unused)]
fn main() {
#[test]
fn test_ridge_estimator_constructor() {
    let model = RidgeEstimator::new();
    assert_eq!(model.beta, None, "beta is expected to be None");
}
}

Testing an unfitted model

As a second test, we check that the predict function returns an error if the model is unfitted.

#![allow(unused)]
fn main() {
#[test]
fn test_unfitted_estimator() {
    let model = RidgeEstimator::new();
    let x: Array1<f64> = array![1.0, 2.0];
    let result: Result<Array1<f64>, String> = model.predict(&x);

    assert!(result.is_err());
    assert_eq!(result.unwrap_err(), "Model not fitted");
}
}

Testing a fitted model

Finally, we check that a fitted model returns a Some(f64) and that the solution is close to the known value.

#![allow(unused)]
fn main() {
#[test]
fn test_ridge_estimator_solution() {
    let x: Array1<f64> = array![1.0, 2.0];
    let y: Array1<f64> = array![0.1, 0.2];
    let true_beta: f64 = 0.1;
    let lambda2: f64 = 0.0;

    let mut model = RidgeEstimator::new();
    model.fit(&x, &y, lambda2);

    assert!(model.beta.is_some(), "beta is expected to be Some(f64)");

    assert!(
        (true_beta - model.beta.unwrap()).abs() < 1e-6,
        "Estimate {} not close enough to true solution {}",
        true_beta,
        model.beta.unwrap()
    );
}
}

Simple optimizers

This chapter explores how to implement a small module of optimization algorithms in Rust. It is divided into three sections:

  • In the first section, we begin by defining a common interface for optimizers and show how different strategies like gradient descent and momentum-based methods can be implemented using Rust's trait system.
  • In the second section, we explore an alternative design using enums, which can be helpful when working with simpler control flow or dynamic dispatch.
  • In the last section, we demonstrate how to replace Vec<f64> with ndarray structures, which allows for more expressive and efficient numerical code, especially for larger-scale or matrix-based computations.

The goal is to gradually expose the design space for writing numerical algorithms idiomatically in Rust.

In each section, we implement a small crate with the following layout:

├── Cargo.toml
└── src
    ├── optimizers.rs
    └── lib.rs

The module optimizers.rs implements classical gradient descent with and without momentum, and, in the final ndarray-based section, a Nesterov accelerated gradient descent.

Optimizers using traits

This chapter illustrates how to use traits for implementing a module of optimizers. This approach is useful when you want polymorphism or when each optimizer requires its own state and logic.

It's similar to what you might do in other languages such as Python or C++, and it's a good fit for applications that involve multiple algorithm variants.

Trait definition

We define a common trait Optimizer, which describes the shared behavior of any optimizer. Let's assume that our optimizers only need a step function.

#![allow(unused)]
fn main() {
pub trait Optimizer {
    /// Performs a single optimization step.
    ///
    /// # Arguments
    /// - `weights`: Mutable slice of parameters to be updated.
    /// - `grads`: Slice of gradients corresponding to the weights.
    fn step(&mut self, weights: &mut [f64], grads: &[f64]);
}
}

Any type that implements this trait must provide a step method. Note that the receiver is taken by mutable reference (&mut self) so that an optimizer can update its internal state. The weights are also passed as a mutable slice, which means the step function updates them in place instead of allocating a new buffer.

Let's illustrate how to use this by implementing two optimizers: gradient descent with and without momentum.

Gradient descent

Recall that the gradient descent algorithm is given by:

$$ w_{k+1} = w_k - \eta \, \nabla f(w_k) $$

where $\eta$ denotes the step size, and $f$ is the objective function to minimize. We first define the structure for the gradient descent algorithm. It only stores the learning rate as an f64.

#![allow(unused)]
fn main() {
pub struct GradientDescent {
    pub learning_rate: f64,
}
}

We then implement a constructor. In this case, it simply consists of choosing the learning rate.

#![allow(unused)]
fn main() {
impl GradientDescent {
    /// Creates a new gradient descent optimizer.
    ///
    /// # Arguments
    /// - `learning_rate`: Step size used to update weights.
    pub fn new(learning_rate: f64) -> Self {
        Self { learning_rate }
    }
}
}

Next, we implement the step method required by the Optimizer trait:

#![allow(unused)]
fn main() {
impl Optimizer for GradientDescent {
    /// Applies the gradient descent step to each weight.
    ///
    /// Each weight is updated as: `w ← w - learning_rate * grad`
    fn step(&mut self, weights: &mut [f64], grads: &[f64]) {
        for (w, g) in weights.iter_mut().zip(grads.iter()) {
            *w -= self.learning_rate * g;
        }
    }
}
}

This function updates each entry of weights by looping over the elements and applying the gradient descent update. The weight w inside the loop must be dereferenced as it is passed as a mutable reference.

We use elementwise operations because Vec doesn't provide built-in arithmetic methods. External crates such as ndarray or nalgebra could help write this more expressively.
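
For comparison, here is a sketch of the same update written with ndarray (only an illustration; this section's crate sticks to plain slices):

#![allow(unused)]
fn main() {
use ndarray::Array1;

// In-place, element-wise update: w ← w - learning_rate * g
fn gd_step(weights: &mut Array1<f64>, grads: &Array1<f64>, learning_rate: f64) {
    weights.zip_mut_with(grads, |w, &g| *w -= learning_rate * g);
}
}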

Gradient descent with momentum

Recall that the gradient descent algorithm with momentum is given by:

$$ v_{k+1} = \mu \, v_k + \eta \, \nabla f(w_k), \qquad w_{k+1} = w_k - v_{k+1} $$

where $v$, $\mu$, and $\eta$ denote the velocity, momentum, and step size, respectively. The structure we define stores the learning rate, the momentum factor, and an internal velocity buffer:

#![allow(unused)]
fn main() {
pub struct Momentum {
    pub learning_rate: f64,
    pub momentum: f64,
    pub velocity: Vec<f64>,
}
}

We define the constructor by taking the required parameters, and we initialize the velocity to a zero vector:

#![allow(unused)]
fn main() {
impl Momentum {
    /// Creates a new momentum optimizer.
    ///
    /// # Arguments
    /// - `learning_rate`: Step size used to update weights.
    /// - `momentum`: Momentum coefficient (typically between 0.8 and 0.99).
    /// - `dim`: Dimension of the parameter vector, used to initialize velocity.
    pub fn new(learning_rate: f64, momentum: f64, dim: usize) -> Self {
        Self {
            learning_rate,
            momentum,
            velocity: vec![0.0; dim],
        }
    }
}
}

The step function is slightly more complex, as it performs elementwise operations over the weights, velocity, and gradients:

#![allow(unused)]
fn main() {
impl Optimizer for Momentum {
    /// Applies the momentum update step.
    ///
    /// Each step uses the update rule:
    /// ```text
    /// v ← momentum * v + learning_rate * grad
    /// w ← w - v
    /// ```
    fn step(&mut self, weights: &mut [f64], grads: &[f64]) {
        for ((w, g), v) in weights
            .iter_mut()
            .zip(grads.iter())
            .zip(self.velocity.iter_mut())
        {
            *v = self.momentum * *v + self.learning_rate * *g;
            *w -= *v;
        }
    }
}
}

The internal velocity state is updated as well, which is possible because the method takes &mut self. At this point, we've defined two optimizers using structs and a shared trait. To complete the module, we define a training loop that uses any optimizer implementing the trait.

API and usage

We expose the training loop in lib.rs as the public API. The function run_optimization takes a generic Optimizer, a gradient function, an initial weight vector, and a maximum number of iterations.

#![allow(unused)]
fn main() {
pub mod optimizers;
use optimizers::Optimizer;

pub fn run_optimization<Opt: Optimizer>(
    optimizer: &mut Opt,
    weights: &mut [f64],
    grad_fn: impl Fn(&[f64]) -> Vec<f64>,
    num_steps: usize,
) {
    for _ in 0..num_steps {
        let grads = grad_fn(weights);
        optimizer.step(weights, &grads);
    }
}
}

Here, the trait bound <Opt: Optimizer> tells the compiler that the type of the given optimizer must implement the trait Optimizer. This ensures that optimizer has the required step method.
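
As an aside, the same loop could also be written with a trait object if dynamic dispatch is preferred over generics (a sketch, not part of the crate's public API):

#![allow(unused)]
fn main() {
use optimizers::Optimizer;

pub fn run_optimization_dyn(
    optimizer: &mut dyn Optimizer,
    weights: &mut [f64],
    grad_fn: &dyn Fn(&[f64]) -> Vec<f64>,
    num_steps: usize,
) {
    for _ in 0..num_steps {
        let grads = grad_fn(weights);
        optimizer.step(weights, &grads);
    }
}
}

With &mut dyn Optimizer the concrete optimizer is chosen at runtime through a virtual call, whereas the generic version is monomorphized at compile time.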

Example of usage

Here’s a simple example where we minimize the function $f(w) = \lVert w \rVert_2^2$ in $\mathbb{R}^{10}$:

#![allow(unused)]
fn main() {
use traits_based::optimizers::Momentum;

fn grad_fn(w: &[f64]) -> Vec<f64> {
    w.iter().map(|wi| 2.0 * wi).collect()
}

let n: usize = 10;
let mut weights = vec![1.0; n];
let mut optimizer = Momentum::new(0.01, 0.9, n);

run_optimization(&mut optimizer, &mut weights, grad_fn, 100);

// Final weights after optimization
println!("{:?}", weights); 
}

Some final reminders:

  • Both weights and optimizer must be mutable, because we perform in-place updates.
  • We pass mutable references into run_optimization, matching its function signature.
  • The example uses a closure-based gradient function, which you can easily replace.

Adding tests

To test our optimizers, let's have a look at how to implement tests and how to run them.

How to write tests in Rust

Tests can be included in the same file as the code using the #[cfg(test)] module. Each test function is annotated with #[test]. Inside a test, you can use assert_eq!, assert!, or similar macros to validate expected behavior.

What we test

We implemented a few tests to check:

  • That the constructors return the expected variant with the correct parameters
  • That the step method modifies weights as expected
  • That repeated calls to step update the internal state correctly (e.g., momentum's velocity)
#![allow(unused)]
fn main() {
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gradient_descent_constructor() {
        let optimizer = GradientDescent::new(1e-3);
        assert_eq!(1e-3, optimizer.learning_rate);
    }

    #[test]
    fn test_step_gradient_descent() {
        let mut opt = GradientDescent::new(0.1);
        let mut weights = vec![1.0, 2.0, 3.0];
        let grads = vec![0.5, 0.5, 0.5];

        opt.step(&mut weights, &grads);

        assert_eq!(weights, vec![0.95, 1.95, 2.95])
    }

    #[test]
    fn test_momentum_constructor() {
        let opt = Momentum::new(0.01, 0.9, 10);
        match opt {
            Momentum {
                learning_rate,
                momentum,
                velocity,
            } => {
                assert_eq!(learning_rate, 0.01);
                assert_eq!(momentum, 0.9);
                assert_eq!(velocity.len(), 10);
            }
        }
    }

    #[test]
    fn test_step_momentum() {
        let mut opt = Momentum::new(0.1, 0.9, 3);
        let mut weights = vec![1.0, 2.0, 3.0];
        let grads = vec![0.5, 0.5, 0.5];

        opt.step(&mut weights, &grads);
        assert_eq!(weights, vec![0.95, 1.95, 2.95]);

        opt.step(&mut weights, &grads);
        assert!(
            weights
                .iter()
                .zip(vec![0.855, 1.855, 2.855])
                .all(|(a, b)| (*a - b).abs() < 1e-6)
        );
    }
}
}

Some notes:

  • This module is added in the same file where the optimizers are implemented.
  • The line use super::*; imports everything defined in the parent module (here, the optimizer structs and the Optimizer trait).

How to run the tests

To run the tests from the command line, use:

cargo test

This will automatically find and execute all test functions in the project. You should see output like:

running 4 tests
test tests::test_gradient_descent_constructor ... ok
test tests::test_momentum_constructor ... ok
test tests::test_step_gradient_descent ... ok
test tests::test_step_momentum ... ok

If any test fails, Cargo will show which assertion failed and why.

Optimizers as enums with internal state and methods

This chapter explores the enum-based optimizer design, as an alternative to the previous trait-based approach. Each variant carries its own internal state, and behavior is encapsulated using methods. This pattern is useful when you want enum-based control flow with encapsulated logic.

Defining the optimizer enum

Each optimizer variant includes its own parameters and, when needed, its internal state.

#![allow(unused)]
fn main() {
#[derive(Debug, Clone)]
pub enum Optimizer {
    /// Gradient Descent optimizer with a fixed learning rate.
    GradientDescent { learning_rate: f64 },
    /// Momentum-based optimizer with velocity tracking.
    Momentum {
        learning_rate: f64,
        momentum: f64,
        velocity: Vec<f64>,
    },
}
}

Here, GradientDescent stores only the learning rate, while Momentum additionally stores its velocity vector.

Instantiation

To instantiate a GradientDescent, the user has to write:

#![allow(unused)]
fn main() {
let optimizer = Optimizer::GradientDescent {
    learning_rate: 0.1,
};
}

For the momentum-based gradient descent, the instantiation becomes more cumbersome:

#![allow(unused)]
fn main() {
let optimizer = Optimizer::Momentum {
    learning_rate: 0.1,
    momentum: 0.9,
    velocity: vec![0.0; 3],
};
}

To make this more user-friendly, we can define more convenient constructors such as:

#![allow(unused)]
fn main() {
let optimizer = Optimizer::gradient_descent(0.1);
}

and

#![allow(unused)]
fn main() {
let optimizer = Optimizer::momentum(0.1, 0.9, 3);
}

This can be achieved by adding two implementations:

#![allow(unused)]
fn main() {
impl Optimizer {
    /// Creates a new Gradient Descent optimizer.
    ///
    /// # Arguments
    /// - `learning_rate`: Step size for the parameter updates.
    pub fn gradient_descent(learning_rate: f64) -> Self {
        Self::GradientDescent { learning_rate }
    }

    /// Creates a new Momentum optimizer.
    ///
    /// # Arguments
    /// - `learning_rate`: Step size for the updates.
    /// - `momentum`: Momentum coefficient.
    /// - `dim`: Number of parameters (used to initialize velocity vector).
    pub fn momentum(learning_rate: f64, momentum: f64, dim: usize) -> Self {
        Self::Momentum {
            learning_rate,
            momentum,
            velocity: vec![0.0; dim],
        }
    }
}
}

This is optional but it helps create optimizers easily.

Implementing the step method

The step method is responsible for updating the model parameters (weights) according to the specific optimization strategy in use. It uses pattern matching to dispatch the correct behavior depending on whether the optimizer is a GradientDescent or a Momentum variant.

#![allow(unused)]
fn main() {
impl Optimizer {
    /// Applies a single optimization step, depending on the variant.
    ///
    /// # Arguments
    /// - `weights`: Mutable slice of model parameters to be updated.
    /// - `grads`: Gradient slice with same shape as `weights`.
    pub fn step(&mut self, weights: &mut [f64], grads: &[f64]) {
        match self {
            Optimizer::GradientDescent { learning_rate } => {
                for (w, g) in weights.iter_mut().zip(grads.iter()) {
                    *w -= *learning_rate * *g;
                }
            }
            Optimizer::Momentum {
                learning_rate,
                momentum,
                velocity,
            } => {
                for ((w, g), v) in weights
                    .iter_mut()
                    .zip(grads.iter())
                    .zip(velocity.iter_mut())
                {
                    *v = *momentum * *v + *learning_rate * *g;
                    *w -= *v;
                }
            }
        }
    }
}
}

The match expression identifies which optimizer variant is being used. This pattern can be a clean alternative to trait-based designs when you want:

  • A small number of well-known variants
  • Built-in state encapsulation
  • Exhaustive handling via pattern matching

It keeps related logic grouped under one type and can be extended easily with new optimizers.
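
For example (a small sketch assuming the Optimizer enum above is in scope), exhaustive matching lets the compiler guarantee that every variant is handled; adding a new variant later turns this into a compile error until a corresponding arm is written:

#![allow(unused)]
fn main() {
fn describe(optimizer: &Optimizer) -> String {
    match optimizer {
        Optimizer::GradientDescent { learning_rate } => {
            format!("gradient descent with learning rate {learning_rate}")
        }
        Optimizer::Momentum { learning_rate, momentum, .. } => {
            format!("momentum {momentum} with learning rate {learning_rate}")
        }
    }
}

let opt = Optimizer::momentum(0.1, 0.9, 3);
println!("{}", describe(&opt));
}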

API and usage

Once you have implemented the optimizers and the step logic, it's time to expose a public API to run optimization from your crate's lib.rs. This typically involves defining a helper function like run_optimization.

Define run_optimization in lib.rs

Similarly to the trait-based implementation, we can define a function run_optimization that performs the optimization. However, here Optimizer is an enum instead of a trait, so we can't introduce a generic type parameter and write <Opt: Optimizer> (see the trait-based run_optimization function if you need a reminder). Instead, we simply pass a mutable reference to the concrete type.

#![allow(unused)]
fn main() {
pub mod optimizers;

use optimizers::Optimizer;

/// Runs an optimization loop over multiple steps.
///
/// # Arguments
/// - `optimizer`: The optimizer to use (e.g., GradientDescent, Momentum).
/// - `weights`: Mutable reference to the weights vector to be optimized.
/// - `grad_fn`: A closure that computes the gradient given the current weights.
/// - `num_steps`: Number of iterations to run.
pub fn run_optimization(
    optimizer: &mut Optimizer,
    weights: &mut [f64],
    grad_fn: impl Fn(&[f64]) -> Vec<f64>,
    num_steps: usize,
) {
    for _ in 0..num_steps {
        let grads = grad_fn(weights);
        optimizer.step(weights, &grads);
    }
}
}

Example of usage

Here's a basic example of using run_optimization to minimize a simple quadratic loss.

fn main() {
    let grad_fn = |w: &[f64]| vec![2.0 * (w[0] - 3.0)];
    
    let mut weights = vec![0.0];
    let mut optimizer = Optimizer::gradient_descent(0.1);

    run_optimization(&mut optimizer, &mut weights, grad_fn, 100);

    println!("Optimized weight: {:?}", weights[0]);
}

Adding tests

We can easily adapt the tests we implemented for the trait-based version of the optimizers. Here, we rely on pattern matching to check the constructors.

#![allow(unused)]
fn main() {
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gradient_descent_constructor() {
        let opt = Optimizer::gradient_descent(0.01);
        match opt {
            Optimizer::GradientDescent { learning_rate } => {
                assert_eq!(learning_rate, 0.01);
            }
            _ => panic!("Expected GradientDescent optimizer"),
        }
    }

    #[test]
    fn test_momentum_constructor() {
        let opt = Optimizer::momentum(0.01, 0.9, 10);
        match opt {
            Optimizer::Momentum {
                learning_rate,
                momentum,
                velocity,
            } => {
                assert_eq!(learning_rate, 0.01);
                assert_eq!(momentum, 0.9);
                assert_eq!(velocity.len(), 10);
            }
            _ => panic!("Expected Momentum optimizer"),
        }
    }

    #[test]
    fn test_step_gradient_descent() {
        let mut opt = Optimizer::gradient_descent(0.1);
        let mut weights = vec![1.0, 2.0, 3.0];
        let grads = vec![0.5, 0.5, 0.5];

        opt.step(&mut weights, &grads);

        assert_eq!(weights, vec![0.95, 1.95, 2.95])
    }

    #[test]
    fn test_step_momentum() {
        let mut opt = Optimizer::momentum(0.1, 0.9, 3);
        let mut weights = vec![1.0, 2.0, 3.0];
        let grads = vec![0.5, 0.5, 0.5];

        opt.step(&mut weights, &grads);
        assert_eq!(weights, vec![0.95, 1.95, 2.95]);

        opt.step(&mut weights, &grads);
        assert!(
            weights
                .iter()
                .zip(vec![0.855, 1.855, 2.855])
                .all(|(a, b)| (*a - b).abs() < 1e-6)
        );
    }
}
}

Optimizers using ndarray

This section introduces a modular and idiomatic way to implement optimization algorithms in Rust using ndarray and traits. It is intended for readers who are already comfortable with basic Rust syntax and want to learn how to build reusable, extensible components in numerical computing.

You can inspect the full module that we're about to break down below:

Click to view optimizers.rs
#![allow(unused)]
fn main() {
use ndarray::Array;
use ndarray::Array1;
use ndarray::Zip;

/// Trait for optimizers that update parameters using gradients.
///
/// Implementors must define a `run` method that takes mutable weights,
/// a gradient function, and the number of iterations to run.
// ANCHOR: trait
pub trait Optimizer {
    fn run(
        &self,
        weights: &mut Array1<f64>,
        grad_fn: impl Fn(&Array1<f64>) -> Array1<f64>,
        n_steps: usize,
    );
}
// ANCHOR_END: trait

/// Basic Gradient Descent (GD) optimizer.
///
/// Updates parameters in the direction of the negative gradient scaled
/// by a fixed step size.
// ANCHOR: struct_gd
pub struct GD {
    step_size: f64,
}
// ANCHOR_END: struct_gd

/// Create a new gradient descent optimizer.
///
/// # Arguments
/// - `step_size`: the learning rate.
// ANCHOR: impl_gd_new
impl GD {
    pub fn new(step_size: f64) -> Self {
        Self { step_size }
    }
}
// ANCHOR_END: impl_gd_new

/// Run the gradient descent optimizer.
///
/// For each step: `w ← w - step_size * grad(w)`
// ANCHOR: impl_gd_run
impl Optimizer for GD {
    fn run(
        &self,
        weights: &mut Array1<f64>,
        grad_fn: impl Fn(&Array1<f64>) -> Array1<f64>,
        n_steps: usize,
    ) {
        for _ in 0..n_steps {
            let grads = grad_fn(weights);
            weights.zip_mut_with(&grads, |w, &g| {
                *w -= self.step_size * g;
            });
        }
    }
}
// ANCHOR_END: impl_gd_run

/// Gradient descent with classical momentum.
///
/// Combines the previous velocity with the current gradient
/// to speed up convergence in convex problems.
// ANCHOR: struct_agd
pub struct Momentum {
    step_size: f64,
    momentum: f64,
}
// ANCHOR_END: struct_agd

/// Create a new Momentum optimizer.
///
/// # Arguments
/// - `step_size`: the learning rate.
/// - `momentum`: the momentum coefficient (typically between 0.8 and 0.99).
// ANCHOR: impl_agd_new
impl Momentum {
    pub fn new(step_size: f64, momentum: f64) -> Self {
        Self {
            step_size,
            momentum,
        }
    }
}
// ANCHOR_END: impl_agd_new

/// Run AGD with momentum.
///
/// For each step:
/// ```text
/// v ← momentum * v - step_size * grad(w)
/// w ← w + v
/// ```
// ANCHOR: impl_agd_run
impl Optimizer for Momentum {
    fn run(
        &self,
        weights: &mut Array1<f64>,
        grad_fn: impl Fn(&Array1<f64>) -> Array1<f64>,
        n_steps: usize,
    ) {
        let n: usize = weights.len();
        let mut velocity: Array1<f64> = Array::zeros(n);

        for _ in 0..n_steps {
            let grads = grad_fn(weights);
            for ((w, g), v) in weights
                .iter_mut()
                .zip(grads.iter())
                .zip(velocity.iter_mut())
            {
                *v = self.momentum * *v - self.step_size * g;
                *w += *v;
            }
        }
    }
}
// ANCHOR_END: impl_agd_run

/// Adaptive Accelerated Gradient Descent using Nesterov's method.
///
/// This optimizer implements the variant from smooth convex optimization literature,
/// where extrapolation is based on the difference between consecutive y iterates.
///
/// References:
/// - Beck & Teboulle (2009), FISTA (but without proximal operator)
/// - Nesterov's accelerated gradient (original formulation)
// ANCHOR: NAG_struct
pub struct NAG {
    step_size: f64,
}
// ANCHOR_END: NAG_struct

// ANCHOR: NAG_impl_new
impl NAG {
    /// Create a new instance of NAG with a given step size.
    ///
    /// The step size should be 1 / L, where L is the Lipschitz constant
    /// of the gradient of the objective function.
    pub fn new(step_size: f64) -> Self {
        Self { step_size }
    }
}
// ANCHOR_END: NAG_impl_new

/// Run the optimizer for `n_steps` iterations.
///
/// # Arguments
/// - `weights`: mutable reference to the parameter vector (x₀), will be updated in-place.
/// - `grad_fn`: a function that computes ∇f(x) for a given x.
/// - `n_steps`: number of optimization steps to perform.
///
/// This implementation follows:
///
///
/// y_{k+1} = x_k - α ∇f(x_k)
/// t_{k+1} = (1 + sqrt(1 + 4 t_k²)) / 2
/// x_{k+1} = y_{k+1} + ((t_k - 1)/t_{k+1}) * (y_{k+1} - y_k)
///
// ANCHOR: NAG_impl_run
impl Optimizer for NAG {
    fn run(
        &self,
        weights: &mut Array1<f64>,
        grad_fn: impl Fn(&Array1<f64>) -> Array1<f64>,
        n_steps: usize,
    ) {
        let mut t_k: f64 = 1.0;
        let mut y_k = weights.clone();

        for _ in 0..n_steps {
            let grad = grad_fn(weights);
            let mut y_next = weights.clone();
            Zip::from(&mut y_next).and(&grad).for_each(|y, &g| {
                *y -= self.step_size * g;
            });

            let t_next = 0.5 * (1.0 + (1.0 + 4.0 * t_k * t_k).sqrt());

            Zip::from(&mut *weights)
                .and(&y_next)
                .and(&y_k)
                .for_each(|x, &y1, &y0| {
                    *x = y1 + ((t_k - 1.0) / t_next) * (y1 - y0);
                });

            y_k = y_next;
            t_k = t_next;
        }
    }
}
// ANCHOR_END: NAG_impl_run

// ANCHOR: tests
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_gradient_descent_constructor() {
        let optimizer = GD::new(1e-3);
        assert_eq!(1e-3, optimizer.step_size);
    }

    #[test]
    fn test_step_gradient_descent() {
        let opt = GD::new(0.1);
        let mut weights = array![1.0, 2.0, 3.0];
        let grad_fn = |_w: &Array1<f64>| array![0.5, 0.5, 0.5];
        opt.run(&mut weights, grad_fn, 1);

        assert_eq!(weights, array![0.95, 1.95, 2.95])
    }

    #[test]
    fn test_momentum_constructor() {
        let opt = Momentum::new(0.01, 0.9);
        assert_eq!(
            opt.step_size, 0.01,
            "Expected step size to be 0.01 but got {}",
            opt.step_size
        );
        assert_eq!(
            opt.momentum, 0.9,
            "Expected momentum to be 0.9 but got {}",
            opt.momentum
        );
    }

    #[test]
    fn test_step_momentum() {
        let opt = Momentum::new(0.1, 0.9);
        let mut weights = array![1.0, 2.0, 3.0];
        let grad_fn = |_w: &Array1<f64>| array![0.5, 0.5, 0.5];

        opt.run(&mut weights, grad_fn, 2);
        assert!(
            weights
                .iter()
                .zip(array![0.855, 1.855, 2.855])
                .all(|(a, b)| (*a - b).abs() < 1e-6)
        );
    }
}
// ANCHOR_END: tests
}

In this chapter we also implement the Nesterov Accelerated Gradient method.

Required imports

In this example, we need to import the following types and traits:

#![allow(unused)]
fn main() {
use ndarray::Array;
use ndarray::Array1;
use ndarray::Zip;
}

The Array type is a general-purpose n-dimensional array used for numerical computing. It provides a wide range of methods for array creation (zeros, ones, from_vec, etc.), manipulation, and broadcasting. Here, we primarily use it to initialize zero vectors for optimizer internals like velocity buffers.

We also import Array1, a type alias for one-dimensional arrays (Array<f64, Ix1>), since we're working with flat vectors of parameters or gradients.

Zip is a utility that enables element-wise operations across one or more arrays that we use for in-place updates.
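
Here is a tiny standalone illustration of these three imports (not part of the optimizer module):

#![allow(unused)]
fn main() {
use ndarray::{array, Array, Array1, Zip};

let mut v: Array1<f64> = Array::zeros(3); // zero-initialized buffer
let g = array![1.0, 2.0, 3.0];

// Element-wise, in-place update of `v` using `g`.
Zip::from(&mut v).and(&g).for_each(|v, &g| *v += g);
println!("{v}");
}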

Trait-based design

We define a trait called Optimizer to represent any optimizer that can update model weights based on gradients. In contrast to the previous sections, where we mostly implemented step functions, here the trait requires implementors to define a run method with the following signature:

#![allow(unused)]
fn main() {
pub trait Optimizer {
    fn run(
        &self,
        weights: &mut Array1<f64>,
        grad_fn: impl Fn(&Array1<f64>) -> Array1<f64>,
        n_steps: usize,
    );
}
}

This method takes:

  • A mutable reference to a vector of weights (Array1<f64>).
  • A function that computes the gradient of the loss with respect to the weights. This grad_fn function takes itself a borrowed reference to the weights &Array1<f64> and outputs a new array Array1<f64>.
  • The number of iterations to perform.

This run method therefore defines the whole optimization loop, not just a single step.

Gradient descent

The GD struct implements basic gradient descent with a fixed step size:

#![allow(unused)]
fn main() {
pub struct GD {
    step_size: f64,
}
}

It has a constructor:

#![allow(unused)]
fn main() {
impl GD {
    pub fn new(step_size: f64) -> Self {
        Self { step_size }
    }
}
}

And implements Optimizer by subtracting the gradient scaled by the step size from the weights at each iteration.

#![allow(unused)]
fn main() {
impl Optimizer for GD {
    fn run(
        &self,
        weights: &mut Array1<f64>,
        grad_fn: impl Fn(&Array1<f64>) -> Array1<f64>,
        n_steps: usize,
    ) {
        for _ in 0..n_steps {
            let grads = grad_fn(weights);
            weights.zip_mut_with(&grads, |w, &g| {
                *w -= self.step_size * g;
            });
        }
    }
}
}

Some notes:

  • At each iteration, we compute the gradient with let grads = grad_fn(weights), which is fine but reallocates a new vector at each call. If we wanted to optimize the gradient computation, we could pre-allocate a buffer outside the loop and pass a mutable reference into the gradient function to avoid repeated allocations. This would require changing the signature of grad_fn; a sketch of this variant is given after these notes.

  • weights.zip_mut_with(&grads, |w, &g| { ... }): This is a mutable zip operation from the ndarray crate. It walks over weights and grads, applying the closure to each pair.

  • zip_mut_with is an inherent method of ArrayBase, the type behind Array1. That’s why we can call it directly on weights.

  • In the closure statement we wrote: |w, &g| *w -= self.step_size * g;. Here, w is a mutable reference to each weight element, so we dereference it using *w to update its value. The &g pattern destructures the &f64 reference, so g binds directly to the f64 value without an explicit dereference.
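
Here is a sketch of the buffer-reusing variant mentioned in the first note (an illustration only; the crate keeps the allocating grad_fn signature). The gradient function writes into a pre-allocated buffer instead of returning a new array:

#![allow(unused)]
fn main() {
use ndarray::{Array, Array1};

// Hypothetical in-place variant: `grad_fn` fills `grads` instead of allocating.
fn run_gd_in_place(
    weights: &mut Array1<f64>,
    mut grad_fn: impl FnMut(&Array1<f64>, &mut Array1<f64>),
    step_size: f64,
    n_steps: usize,
) {
    let mut grads: Array1<f64> = Array::zeros(weights.len());
    for _ in 0..n_steps {
        grad_fn(weights, &mut grads); // fill the pre-allocated buffer
        weights.zip_mut_with(&grads, |w, &g| *w -= step_size * g);
    }
}
}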

Gradient descent with momentum

The Momentum struct extends gradient descent with a classical momentum term:

#![allow(unused)]
fn main() {
pub struct Momentum {
    step_size: f64,
    momentum: f64,
}
}

It has a constructor:

#![allow(unused)]
fn main() {
impl Momentum {
    pub fn new(step_size: f64, momentum: f64) -> Self {
        Self {
            step_size,
            momentum,
        }
    }
}
}

This algorithm adds momentum to classical gradient descent. Instead of updating weights using just the current gradient, it maintains a velocity vector that accumulates the influence of past gradients. This helps smooth the trajectory and accelerates convergence on convex problems.

#![allow(unused)]
fn main() {
impl Optimizer for Momentum {
    fn run(
        &self,
        weights: &mut Array1<f64>,
        grad_fn: impl Fn(&Array1<f64>) -> Array1<f64>,
        n_steps: usize,
    ) {
        let n: usize = weights.len();
        let mut velocity: Array1<f64> = Array::zeros(n);

        for _ in 0..n_steps {
            let grads = grad_fn(weights);
            for ((w, g), v) in weights
                .iter_mut()
                .zip(grads.iter())
                .zip(velocity.iter_mut())
            {
                *v = self.momentum * *v - self.step_size * g;
                *w += *v;
            }
        }
    }
}
}

Some notes:

  • We initialize a vector of zeros to track the momentum (velocity) across steps. It has the same length as the weights. This is achieved with: let mut velocity: Array1<f64> = Array::zeros(n). Note that we could have defined the velocity as an internal state variable within the struct definition.

  • We use a triple nested zip to unpack the values of the weights, gradients, and velocity: for ((w, g), v) in weights.iter_mut().zip(grads.iter()).zip(velocity.iter_mut()). Here,

    • weights.iter_mut() gives a mutable reference to each weight,
    • grads.iter() provides read-only access to each gradient,
    • velocity.iter_mut() allows in-place updates of the velocity vector.

    This pattern allows us to update everything in one pass, element-wise.

  • Within the nested zip closure, we update the velocity using the momentum term and current gradient: *v = self.momentum * *v - self.step_size * g;

  • The weight is updated using the new velocity: *w += *v;. Again, we dereference w because it's a mutable reference.

Nesterov Accelerated Gradient

This algorithm implements an accelerated method inspired by Nesterov’s momentum and the FISTA algorithm. The struct only stores the chosen step size.

#![allow(unused)]
fn main() {
pub struct NAG {
    step_size: f64,
}
}

It has a constructor:

#![allow(unused)]
fn main() {
impl NAG {
    /// Create a new instance of NAG with a given step size.
    ///
    /// The step size should be 1 / L, where L is the Lipschitz constant
    /// of the gradient of the objective function.
    pub fn new(step_size: f64) -> Self {
        Self { step_size }
    }
}
}

The key idea is to introduce an extrapolation step between iterates, controlled by a sequence t_k. This helps the optimizer "look ahead" and converge faster in smooth convex problems.

Update steps:

  • Compute a temporary point y_{k+1} by taking a gradient step from x_k.
  • Update the extrapolation coefficient t_{k+1}.
  • Combine y_{k+1} and y_k using a weighted average to get the new iterate x_{k+1}.
#![allow(unused)]
fn main() {
impl Optimizer for NAG {
    fn run(
        &self,
        weights: &mut Array1<f64>,
        grad_fn: impl Fn(&Array1<f64>) -> Array1<f64>,
        n_steps: usize,
    ) {
        let mut t_k: f64 = 1.0;
        let mut y_k = weights.clone();

        for _ in 0..n_steps {
            let grad = grad_fn(weights);
            let mut y_next = weights.clone();
            Zip::from(&mut y_next).and(&grad).for_each(|y, &g| {
                *y -= self.step_size * g;
            });

            let t_next = 0.5 * (1.0 + (1.0 + 4.0 * t_k * t_k).sqrt());

            Zip::from(&mut *weights)
                .and(&y_next)
                .and(&y_k)
                .for_each(|x, &y1, &y0| {
                    *x = y1 + ((t_k - 1.0) / t_next) * (y1 - y0);
                });

            y_k = y_next;
            t_k = t_next;
        }
    }
}
}

Some notes:

  • We deliberately re-allocate multiple variables within the for loop (grad, y_next, t_next) but we could have pre-allocated buffers before the for loop.

  • The algorithm keeps track of two sequences: the main iterate (weights) and the extrapolated one (y_k). Before starting the for loop, we initialize y_k by cloning the weights: let mut y_k = weights.clone();.

  • The gradient is evaluated at the current weights, as in standard gradient descent: let grad = grad_fn(weights);. Since weights is a mutable reference, we can pass it straightaway to our grad_fn.

  • A temporary variable stores the new extrapolated point; this is again a full allocation and clone, kept for clarity: let mut y_next = weights.clone();.

  • We next compute y_{k+1} = x_k - α ∇f(x_k) using an element-wise operation: Zip::from(&mut y_next).and(&grad).for_each(|y, &g| { *y -= self.step_size * g; });. This time, we rely on ndarray’s Zip, which iterates over several arrays in lock step.

  • The new weights are obtained by combining y_{k+1} and y_k. The triple zip walks over the current weights and both extrapolation points: Zip::from(&mut *weights)....

This optimizer is more involved than basic gradient descent but still relies on the same functional building blocks: closures, element-wise iteration, and vector arithmetic with ndarray.

API and usage

Here’s how you can use each of the three optimizers GD, Momentum, and NAG to minimize a simple quadratic function.

We'll try to minimize the function:

$$ f(w) = \frac{1}{2} \lVert w - 3 \rVert_2^2 $$

Its gradient is:

$$ \nabla f(w) = w - 3 $$

We expect convergence toward the vector [3.0, 3.0, 3.0].

Using gradient descent

use optimizers::GD;
use ndarray::{array, Array1};

fn main() {
    let mut weights = array![0.0, 0.0, 0.0];
    let grad_fn = |w: &Array1<f64>| w - 3.0;

    let gd = GD::new(0.1);
    gd.run(&mut weights, grad_fn, 100);

    println!("GD result: {:?}", weights);
}

Using momentum

use optimizers::Momentum;
use ndarray::{array, Array1};

fn main() {
    let mut weights = array![0.0, 0.0, 0.0];
    let grad_fn = |w: &Array1<f64>| w - 3.0;

    let momentum = Momentum::new(0.1, 0.9);
    momentum.run(&mut weights, grad_fn, 100);

    println!("Momentum result: {:?}", weights);
}

Using Nesterov’s Accelerated Gradient (NAG)

use optimizers::NAG;
use ndarray::{array, Array1};

fn main() {
    let mut weights = array![0.0, 0.0, 0.0];
    let grad_fn = |w: &Array1<f64>| w - 3.0;

    let nag = NAG::new(0.1);
    nag.run(&mut weights, grad_fn, 100);

    println!("NAG result: {:?}", weights);
}

Summary

This design demonstrates a few Rust programming techniques:

  • Traits for abstraction and polymorphism
  • Structs to encapsulate algorithm-specific state
  • Use of the ndarray crate for numerical data
  • Generic functions using closures for computing gradients

Adding tests

Finally, we can add tests to check our ndarray-based implementation.

#![allow(unused)]
fn main() {
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_gradient_descent_constructor() {
        let optimizer = GD::new(1e-3);
        assert_eq!(1e-3, optimizer.step_size);
    }

    #[test]
    fn test_step_gradient_descent() {
        let opt = GD::new(0.1);
        let mut weights = array![1.0, 2.0, 3.0];
        let grad_fn = |_w: &Array1<f64>| array![0.5, 0.5, 0.5];
        opt.run(&mut weights, grad_fn, 1);

        assert_eq!(weights, array![0.95, 1.95, 2.95])
    }

    #[test]
    fn test_momentum_constructor() {
        let opt = Momentum::new(0.01, 0.9);
        assert_eq!(
            opt.step_size, 0.01,
            "Expected step size to be 0.01 but got {}",
            opt.step_size
        );
        assert_eq!(
            opt.momentum, 0.9,
            "Expected momentum to be 0.9 but got {}",
            opt.momentum
        );
    }

    #[test]
    fn test_step_momentum() {
        let opt = Momentum::new(0.1, 0.9);
        let mut weights = array![1.0, 2.0, 3.0];
        let grad_fn = |_w: &Array1<f64>| array![0.5, 0.5, 0.5];

        opt.run(&mut weights, grad_fn, 2);
        assert!(
            weights
                .iter()
                .zip(array![0.855, 1.855, 2.855])
                .all(|(a, b)| (*a - b).abs() < 1e-6)
        );
    }
}
}

Kernel Ridge regression

In this chapter, we implement Kernel Ridge Regression (KRR) in Rust using the ndarray and ndarray-linalg crates. The implementation is broken into the following sections:

  • In the Kernel module section, we define a Kernel trait and implement the radial basis function (RBF) kernel.

  • In the Gram matrix section, we construct the symmetric Gram matrix needed to solve the KRR problem using Array2 and ArrayView1.

  • In the KRR model section, we define the KRRModel struct and its constructor, making the model generic over any type that implements the Kernel trait.

  • In the fit function section, we implement the logic for training the model, including matrix assembly, regularization, and linear system solving. We introduce a custom error enum KRRFitError to manage common issues.

  • In the predict function section, we implement inference for new samples and introduce the KRRPredictError enum to handle the unfitted model case.

  • In the hyperparameter tuning section, we implement leave-one-out cross-validation (LOOCV) to select a good value for the kernel’s lengthscale.

At the end of the chapter, we obtain a small standalone crate with the following layout:

├── Cargo.toml
└── src
    ├── errors.rs
    ├── kernel.rs
    ├── lib.rs
    └── model.rs

where the Cargo.toml configuration file is given by:

[package]
name = "krr_ndarray"
version = "0.1.0"
edition = "2024"

[dependencies]
rustineers = { path = "../../" }
ndarray = "0.15.2"
ndarray-linalg = { version = "0.16", features = ["openblas-static"] }
thiserror = "1.0"

We enable the openblas-static feature to ensure OpenBLAS is built within the crate, avoiding reliance on system-wide BLAS libraries. The thiserror crate is used to define ergonomic and readable custom error types.

The kernel module

The kernel.rs module defines the core abstraction used to compute similarity between data points in Kernel Ridge Regression (KRR). This abstraction is formalized as a trait, and a specific instance of this trait is implemented using the radial basis function (RBF) kernel, a popular choice in kernel methods. The module also includes unit tests to validate correctness.

Click here to view the full module: kernel.rs. We break it down in the rest of this section.
#![allow(unused)]
fn main() {
use ndarray::ArrayView1;

pub trait Kernel {
    fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64;
}

#[derive(Clone)]
pub struct RBFKernel {
    pub lengthscale: f64,
}

impl RBFKernel {
    pub fn new(lengthscale: f64) -> Self {
        assert!(lengthscale > 0.0, "Lengthscale must be positive");
        Self { lengthscale }
    }
}

impl Kernel for RBFKernel {
    fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
        let diff = &x - &y;
        (-diff.dot(&diff) / (2.0 * self.lengthscale.powi(2))).exp()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_rbf_kernel_xx() {
        let kernel = RBFKernel::new(1.0);
        let x = array![1.0, 2.0, 3.0];
        let kxx = kernel.compute(x.view(), x.view());
        assert_eq!(kxx, 1.0, "Expected k(x, x) to be equal to 1.0, got {}", kxx);
    }

    #[test]
    fn test_fb_kernel_xy() {
        let kernel = RBFKernel::new(1.0);
        let x = array![1.0, 2.0, 3.0];
        let y = array![4.0, 5.0, 6.0];
        let kxy = kernel.compute(x.view(), y.view());
        assert!(kxy < 1.0, "Expected k(x, y) < 1.0, got {}", kxy);
    }
}
}

The Kernel trait

We first define a Kernel trait:

#![allow(unused)]
fn main() {
pub trait Kernel {
    fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64;
}
}

This trait represents a generic kernel function. It requires a single method, compute, which takes two inputs x and y as one-dimensional views (ArrayView1<f64>) and returns a scalar similarity score of type f64. By using views instead of owned arrays, this interface avoids unnecessary data copying and supports efficient evaluation.

This trait enables polymorphism: any kernel function that implements Kernel can be used within the rest of the KRR pipeline.
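
To see this polymorphism in action, here is a sketch of a second kernel (a plain linear kernel, not part of the crate), assuming the Kernel trait above is in scope:

#![allow(unused)]
fn main() {
use ndarray::ArrayView1;

/// Hypothetical linear kernel: k(x, y) = x · y.
pub struct LinearKernel;

impl Kernel for LinearKernel {
    fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
        x.dot(&y)
    }
}
}

Any code written against the Kernel trait, such as the KRR model defined later in this chapter, would accept this kernel without modification.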

The RBFKernel struct

To provide a concrete implementation of the Kernel trait, the module defines the RBFKernel struct:

#![allow(unused)]
fn main() {
#[derive(Clone)]
pub struct RBFKernel {
    pub lengthscale: f64,
}
}

We will need the Clone trait later on for our cross validation technique.

The lengthscale parameter controls how quickly the similarity between two points decays with distance. A smaller lengthscale produces more localized kernels, while a larger one results in smoother, more global effects.

The constructor new is implemented as:

#![allow(unused)]
fn main() {
impl RBFKernel {
    pub fn new(lengthscale: f64) -> Self {
        assert!(lengthscale > 0.0, "Lengthscale must be positive");
        Self { lengthscale }
    }
}
}

This method ensures that the lengthscale is strictly positive, preventing ill-posed kernel evaluations.

Kernel evaluation

The Kernel trait is implemented for RBFKernel as follows:

#![allow(unused)]
fn main() {
impl Kernel for RBFKernel {
    fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
        let diff = &x - &y;
        (-diff.dot(&diff) / (2.0 * self.lengthscale.powi(2))).exp()
    }
}
}

This implementation computes the squared Euclidean distance between x and y, scales it by the squared lengthscale, and applies the exponential function. The result is the value of the Gaussian kernel:

$$ k(x, y) = \exp\!\left( -\frac{\lVert x - y \rVert_2^2}{2 \ell^2} \right) $$

where $\ell$ denotes the lengthscale. This function satisfies the requirements of a positive definite kernel and is commonly used in many kernel-based algorithms.

Unit tests

The module includes two unit tests that validate the behavior of the RBF kernel:

#![allow(unused)]
fn main() {
#[test]
fn test_rbf_kernel_xx() {
    let kernel = RBFKernel::new(1.0);
    let x = array![1.0, 2.0, 3.0];
    let kxx = kernel.compute(x.view(), x.view());
    assert_eq!(kxx, 1.0, "Expected k(x, x) to be equal to 1.0, got {}", kxx);
}
}

This test checks that the kernel evaluated at the same point yields 1.0, as expected from the RBF formula.

#![allow(unused)]
fn main() {
#[test]
fn test_fb_kernel_xy() {
    let kernel = RBFKernel::new(1.0);
    let x = array![1.0, 2.0, 3.0];
    let y = array![4.0, 5.0, 6.0];
    let kxy = kernel.compute(x.view(), y.view());
    assert!(kxy < 1.0, "Expected k(x, y) < 1.0, got {}", kxy);
}
}

This test confirms that the similarity between two distinct vectors is strictly less than 1.0, reflecting the decay property of the RBF kernel.

Summary

The kernel.rs module introduces a reusable kernel interface and demonstrates a concrete implementation using the RBF kernel. It serves as a foundation for computing Gram matrices and enables modularity in the design of the KRR model. The use of traits and parametric polymorphism makes it easy to experiment with other kernel functions in future extensions.

Gram matrix

Once the Kernel trait is defined and implemented, it can be used to construct the Gram matrix required for kernel ridge regression. The Gram matrix contains all pairwise kernel evaluations between training inputs. It is symmetric by definition, since $k(x_i, x_j) = k(x_j, x_i)$ for common kernels such as the RBF.

In the KRR implementation, the Gram matrix is computed as follows:

#![allow(unused)]
fn main() {
use crate::kernel::RBFKernel;
use ndarray::{Array, Array2};

let n: usize = y_train.len();
let mut k_train: Array2<f64> = Array::zeros((n, n));
let kernel: RBFKernel = RBFKernel::new(1.0);

for i in 0..n {
    for j in 0..=i {
        let kxy: f64 = kernel.compute(x_train.row(i), x_train.row(j));
        k_train[(i, j)] = kxy;
        k_train[(j, i)] = kxy;
    }
}
}

Here, x_train.row(i) returns a value of type ArrayView1<f64>, which is exactly the type expected by the Kernel::compute method. The loop only computes the lower triangle of the matrix and mirrors it to the upper triangle to avoid redundant computation, exploiting the symmetry of the kernel. This approach is efficient and idiomatic in Rust using ndarray.

We will use this piece of code in our fit function.

KRR Model

This section describes the definition of the KRRModel struct and its constructor.

Click here to view the full model: model.rs.
#![allow(unused)]
fn main() {
use crate::errors::{KRRFitError, KRRPredictError};
use crate::kernel::Kernel;
use ndarray::{Array, Array1, Array2};
use ndarray_linalg::Solve;

//ANCHOR: KRRModel_struct
pub struct KRRModel<K: Kernel> {
    pub kernel: K,
    pub lambda: f64,
    pub x_train: Option<Array2<f64>>,
    pub alpha: Option<Array1<f64>>,
}
//ANCHOR_END: KRRModel_struct

//ANCHOR: KRRModel_new
impl<K: Kernel> KRRModel<K> {
    pub fn new(kernel: K, lambda: f64) -> Self {
        Self {
            kernel,
            lambda,
            x_train: None,
            alpha: None,
        }
    }
}
//ANCHOR_END: KRRModel_new

//ANCHOR: fit_function
impl<K: Kernel> KRRModel<K> {
    fn _fit(&mut self, x_train: Array2<f64>, y_train: Array1<f64>) -> Result<(), KRRFitError> {
        let n: usize = y_train.len();
        let mut k_train: Array2<f64> = Array::zeros((n, n));
        for i in 0..n {
            for j in 0..=i {
                let kxy = self.kernel.compute(x_train.row(i), x_train.row(j));
                k_train[(i, j)] = kxy;
                k_train[(j, i)] = kxy;
            }
        }

        let identity_n = Array2::eye(n);
        let a: Array2<f64> = k_train + self.lambda * identity_n;
        let alpha = a
            .solve_into(y_train)
            .map_err(|e| KRRFitError::LinAlgError(e.to_string()))?;

        self.x_train = Some(x_train);
        self.alpha = Some(alpha);

        Ok(())
    }

    pub fn fit(&mut self, x_train: Array2<f64>, y_train: Array1<f64>) -> Result<(), KRRFitError> {
        let n: usize = x_train.nrows();
        let m: usize = y_train.len();

        if n != m {
            eprintln!("[KRR::fit] Shape mismatch: x_train has {n} rows, y_train has {m} elments");
            return Err(KRRFitError::ShapeMismatch { x_n: n, y_n: m });
        }

        match self._fit(x_train, y_train) {
            Ok(_) => {
                eprintln!("[KRR::fit] Model successfully fitted.");
                Ok(())
            }
            Err(e) => {
                eprintln!("[KRR::fit] Fitting failed: {e}");
                Err(e)
            }
        }
    }
}
//ANCHOR_END: fit_function

//ANCHOR: predict_function
impl<K: Kernel> KRRModel<K> {
    pub fn predict(&self, x_test: &Array2<f64>) -> Result<Array1<f64>, KRRPredictError> {
        let alpha = self.alpha.as_ref().ok_or(KRRPredictError::NotFitted)?;
        let x_train = self.x_train.as_ref().ok_or(KRRPredictError::NotFitted)?;

        let n_train: usize = x_train.nrows();
        let n_test: usize = x_test.nrows();
        let mut y_pred: Array1<f64> = Array::zeros(n_test);

        for i in 0..n_test {
            for j in 0..n_train {
                let k_val = self.kernel.compute(x_train.row(j), x_test.row(i));
                y_pred[i] += alpha[j] * k_val;
            }
        }
        Ok(y_pred)
    }
}
//ANCHOR_END: predict_function

#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::RBFKernel;
    use ndarray::array;

    #[test]
    fn test_krr_constructor() {
        let kernel = RBFKernel::new(1.0);
        let model: KRRModel<RBFKernel> = KRRModel::new(kernel, 1.0);

        assert_eq!(
            model.lambda, 1.0,
            "Expected lambda equal to 1.0, got {}",
            model.lambda
        );

        assert_eq!(
            model.kernel.lengthscale, 1.0,
            "Expected kernel lengthscale to be 1.0, got {}",
            model.kernel.lengthscale
        );
    }

    #[test]
    fn test_ok_fit_and_predict() {
        let kernel = RBFKernel::new(1.0);
        let mut model: KRRModel<RBFKernel> = KRRModel::new(kernel, 1.0);
        let x_train: Array2<f64> = array![[1.0, 2.0, 3.0], [0.1, 0.2, 0.3]];
        let y_train: Array1<f64> = array![0.9, 0.6];

        let res = model.fit(x_train, y_train);
        assert!(res.is_ok());

        let x_test: Array2<f64> = array![[1.0, 2.0, 3.0], [0.1, 0.2, 0.3]];
        let y_pred = model.predict(&x_test);
        assert!(y_pred.is_ok());
    }

    #[test]
    fn test_dim_mismatch() {
        let kernel = RBFKernel::new(1.0);
        let mut model: KRRModel<RBFKernel> = KRRModel::new(kernel, 1.0);
        let x_train: Array2<f64> = array![[1.0, 2.0, 3.0], [0.1, 0.2, 0.3]];
        let y_train: Array1<f64> = array![0.9, 0.6, 0.9];

        let res = model.fit(x_train, y_train);
        assert!(res.is_err());
    }

    #[test]
    fn test_unfitted_predict_error_type() {
        use crate::errors::KRRPredictError;

        let kernel = RBFKernel::new(1.0);
        let model: KRRModel<RBFKernel> = KRRModel::new(kernel, 1.0);
        let x_test: Array2<f64> = array![[1.0, 2.0, 3.0]];

        let result = model.predict(&x_test);
        assert!(matches!(result, Err(KRRPredictError::NotFitted)));
    }
}
}

The KRRModel struct

#![allow(unused)]
fn main() {
pub struct KRRModel<K: Kernel> {
    pub kernel: K,
    pub lambda: f64,
    pub x_train: Option<Array2<f64>>,
    pub alpha: Option<Array1<f64>>,
}
}

The KRRModel struct represents a kernel ridge regression model parameterized by a kernel type K that implements the Kernel trait. It includes the following fields:

  • kernel: an instance of the kernel function to be used (e.g., RBF kernel).
  • lambda: the regularization parameter.
  • x_train: optional training inputs stored after fitting.
  • alpha: optional dual coefficients computed during fitting.

All fields are marked pub here so that they can be inspected directly by the user, for example in the unit tests shown below.

The new method

The new method is a constructor for creating a new instance of the model. It takes a kernel instance and a regularization parameter as arguments, and initializes an unfitted model:

#![allow(unused)]
fn main() {
impl<K: Kernel> KRRModel<K> {
    pub fn new(kernel: K, lambda: f64) -> Self {
        Self {
            kernel,
            lambda,
            x_train: None,
            alpha: None,
        }
    }
}
}

Unit test

The test_krr_constructor unit test validates that the constructor sets the lambda and kernel lengthscale fields correctly:

#![allow(unused)]
fn main() {
#[test]
fn test_krr_constructor() {
    let kernel = RBFKernel::new(1.0);
    let model: KRRModel<RBFKernel> = KRRModel::new(kernel, 1.0);

    assert_eq!(model.lambda, 1.0);
    assert_eq!(model.kernel.lengthscale, 1.0);
}
}

Fit function

This section describes the implementation of the fit function, which prepares the model for prediction by solving the kernel ridge regression problem.

The full code is given below and we break it down in the sequel of this section.

#![allow(unused)]
fn main() {
impl<K: Kernel> KRRModel<K> {
    fn _fit(&mut self, x_train: Array2<f64>, y_train: Array1<f64>) -> Result<(), KRRFitError> {
        let n: usize = y_train.len();
        let mut k_train: Array2<f64> = Array::zeros((n, n));
        for i in 0..n {
            for j in 0..=i {
                let kxy = self.kernel.compute(x_train.row(i), x_train.row(j));
                k_train[(i, j)] = kxy;
                k_train[(j, i)] = kxy;
            }
        }

        let identity_n = Array2::eye(n);
        let a: Array2<f64> = k_train + self.lambda * identity_n;
        let alpha = a
            .solve_into(y_train)
            .map_err(|e| KRRFitError::LinAlgError(e.to_string()))?;

        self.x_train = Some(x_train);
        self.alpha = Some(alpha);

        Ok(())
    }

    pub fn fit(&mut self, x_train: Array2<f64>, y_train: Array1<f64>) -> Result<(), KRRFitError> {
        let n: usize = x_train.nrows();
        let m: usize = y_train.len();

        if n != m {
            eprintln!("[KRR::fit] Shape mismatch: x_train has {n} rows, y_train has {m} elments");
            return Err(KRRFitError::ShapeMismatch { x_n: n, y_n: m });
        }

        match self._fit(x_train, y_train) {
            Ok(_) => {
                eprintln!("[KRR::fit] Model successfully fitted.");
                Ok(())
            }
            Err(e) => {
                eprintln!("[KRR::fit] Fitting failed: {e}");
                Err(e)
            }
        }
    }
}
}

_fit and fit methods

The _fit method performs the main computation:

  1. Computes the Gram matrix using the kernel.
  2. Adds regularization to the diagonal.
  3. Solves the linear system for the dual coefficients.

The public fit method wraps _fit and performs input validation. It checks that the dimensions of x_train and y_train match, and logs messages about success or failure.

The signature of the fit function is given by:

#![allow(unused)]
fn main() {
pub fn fit(&mut self, x_train: Array2<f64>, y_train: Array1<f64>) -> Result<(), KRRFitError>
}

Before looking at fit and _fit, we need to define the enum KRRFitError.

KRRFitError enum

The KRRFitError enum defines two error types:

  • ShapeMismatch: occurs when the number of samples in x_train and y_train do not match.
  • LinAlgError: returned if solving the linear system fails.

This enum is used to cleanly propagate and format error messages via Result. It is implemented with the thiserror crate as follows:

#![allow(unused)]
fn main() {
use thiserror::Error;

#[derive(Debug, Error)]
pub enum KRRFitError {
    #[error("Shape mismatch: x has {x_n} rows but y has {y_n} elements")]
    ShapeMismatch { x_n: usize, y_n: usize },

    #[error("Solving the linear system failed")]
    LinAlgError(String),
}
}

The _fit function

The _fit function computes the dual coefficients \(\alpha\) by solving the linear system:

\[
(K + \lambda I_n)\,\alpha = y.
\]

Here, \(X\) and \(y\) represent the training data, \(K\) is the Gram matrix computed from the kernel, \(\lambda\) is the regularization parameter, and \(I_n\) is the identity matrix of size \(n\).

The function proceeds as follows:

  • It first computes the symmetric Gram matrix \(K\) and stores it in the variable k_train.
  • It constructs the left-hand side matrix \(A = K + \lambda I_n\).
  • It solves the resulting linear system for \(\alpha\) using the solve_into() method.
  • Finally, it stores x_train and the computed alpha inside the model. Keeping x_train is essential for future predictions on new inputs.

The full function is shown below. You can try to spot interesting stuff that we haven't mentioned yet. We make a few additional comments afterwards.

#![allow(unused)]
fn main() {
fn _fit(&mut self, x_train: Array2<f64>, y_train: Array1<f64>) -> Result<(), KRRFitError> {
    let n: usize = y_train.len();
    let mut k_train: Array2<f64> = Array::zeros((n, n));
    for i in 0..n {
        for j in 0..=i {
            let kxy = self.kernel.compute(x_train.row(i), x_train.row(j));
            k_train[(i, j)] = kxy;
            k_train[(j, i)] = kxy;
        }
    }

    let identity_n = Array2::eye(n);
    let a: Array2<f64> = k_train + self.lambda * identity_n;
    let alpha = a
        .solve_into(y_train)
        .map_err(|e| KRRFitError::LinAlgError(e.to_string()))?;

    self.x_train = Some(x_train);
    self.alpha = Some(alpha);

    Ok(())
}
}

Additional notes:

  • x_train and y_train are not passed by reference and are therefore moved into the _fit function. This is fine because we do not need them afterward, and x_train is stored into self.x_train at the end of the function.

  • The method x_train.row(i) extracts the i-th row of the training matrix as an ArrayView1<f64>, which is exactly the input type expected by our kernel.compute method.

  • The line that computes alpha ends with a question mark ?, which is Rust syntax for propagating errors. If solve_into() fails (for instance, due to an ill-conditioned matrix), the function returns early with a KRRFitError::LinAlgError. If it succeeds, the result is assigned to alpha, and we continue toward returning Ok(()), consistent with the declared return type Result<(), KRRFitError>.

  • x_train and alpha are wrapped in Some(...) because the fields in the KRRModel struct are declared as Option.

The fit function

The fit function serves as the public interface for training the model. It takes ownership of the training data and performs validation before delegating the actual computation to the private _fit method. Its signature is:

#![allow(unused)]
fn main() {
pub fn fit(&mut self, x_train: Array2<f64>, y_train: Array1<f64>) -> Result<(), KRRFitError> {
    let n: usize = x_train.nrows();
    let m: usize = y_train.len();

    if n != m {
        eprintln!("[KRR::fit] Shape mismatch: x_train has {n} rows, y_train has {m} elements");
        return Err(KRRFitError::ShapeMismatch { x_n: n, y_n: m });
    }

    match self._fit(x_train, y_train) {
        Ok(_) => {
            eprintln!("[KRR::fit] Model successfully fitted.");
            Ok(())
        }
        Err(e) => {
            eprintln!("[KRR::fit] Fitting failed: {e}");
            Err(e)
        }
    }
}
}

Here's how it works step-by-step:

  • It extracts the number of training samples in x_train (n) and compares it to the number of targets in y_train (m).

  • If these sizes do not match, it logs a message and returns a KRRFitError::ShapeMismatch variant. This early return prevents proceeding with inconsistent inputs.

  • If the shapes are consistent, the function calls the _fit method to perform the actual kernel ridge regression fitting.

  • It logs whether the fitting was successful or not, and returns a Result accordingly.

This design separates concerns:

  • fit is responsible for input checking and logging,
  • _fit performs the mathematical computations.

This modular approach makes it easier to write clean tests, and to report errors in a structured and maintainable way.
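
As a hypothetical call site (reusing the RBFKernel, KRRModel, and KRRFitError types defined above, assumed to be in scope), a caller can match on the returned Result to distinguish the two failure modes:

use ndarray::array;

fn main() {
    // RBFKernel, KRRModel, and KRRFitError come from the crate's modules shown above.
    let kernel = RBFKernel::new(0.5);
    let mut model = KRRModel::new(kernel, 1e-3);

    let x_train = array![[0.0, 1.0], [1.0, 2.0], [2.0, 3.0]];
    let y_train = array![0.0, 1.0, 2.0];

    match model.fit(x_train, y_train) {
        Ok(()) => println!("model fitted"),
        Err(KRRFitError::ShapeMismatch { x_n, y_n }) => {
            eprintln!("got {x_n} samples but {y_n} targets")
        }
        Err(KRRFitError::LinAlgError(msg)) => eprintln!("linear solve failed: {msg}"),
    }
}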

Unit tests

The test_ok_fit_and_predict test verifies that a valid fit and prediction workflow runs without errors.

#![allow(unused)]
fn main() {
#[test]
fn test_ok_fit_and_predict() {
    let kernel = RBFKernel::new(1.0);
    let mut model: KRRModel<RBFKernel> = KRRModel::new(kernel, 1.0);
    let x_train: Array2<f64> = array![[1.0, 2.0, 3.0], [0.1, 0.2, 0.3]];
    let y_train: Array1<f64> = array![0.9, 0.6];

    let res = model.fit(x_train, y_train);
    assert!(res.is_ok());

    let x_test: Array2<f64> = array![[1.0, 2.0, 3.0], [0.1, 0.2, 0.3]];
    let y_pred = model.predict(&x_test);
    assert!(y_pred.is_ok());
}
}

The test_dim_mismatch test confirms that the model returns an appropriate error when input shapes are inconsistent:

#![allow(unused)]
fn main() {
#[test]
fn test_dim_mismatch() {
    let kernel = RBFKernel::new(1.0);
    let mut model: KRRModel<RBFKernel> = KRRModel::new(kernel, 1.0);
    // Two samples in x_train but three targets in y_train.
    let x_train: Array2<f64> = array![[1.0, 2.0, 3.0], [0.1, 0.2, 0.3]];
    let y_train: Array1<f64> = array![0.9, 0.6, 0.9];

    let res = model.fit(x_train, y_train);
    assert!(res.is_err());
}
}

Predict function

This section describes the implementation of the predict method, which uses the trained model to make predictions on new inputs. The full code is given below, and we break it down in the sequel of this section.

#![allow(unused)]
fn main() {
impl<K: Kernel> KRRModel<K> {
    pub fn predict(&self, x_test: &Array2<f64>) -> Result<Array1<f64>, KRRPredictError> {
        let alpha = self.alpha.as_ref().ok_or(KRRPredictError::NotFitted)?;
        let x_train = self.x_train.as_ref().ok_or(KRRPredictError::NotFitted)?;

        let n_train: usize = x_train.nrows();
        let n_test: usize = x_test.nrows();
        let mut y_pred: Array1<f64> = Array::zeros(n_test);

        for i in 0..n_test {
            for j in 0..n_train {
                let k_val = self.kernel.compute(x_train.row(j), x_test.row(i));
                y_pred[i] += alpha[j] * k_val;
            }
        }
        Ok(y_pred)
    }
}
}

It takes a reference to a two-dimensional test array and returns either the predicted values as a one-dimensional array or an error if the model is not yet fitted.

Logic of the predict method

The function first checks whether the model has been fitted. This involves verifying that the fields self.alpha and self.x_train are set (i.e., not None). If either is missing, the function returns a KRRPredictError::NotFitted variant.

#![allow(unused)]
fn main() {
let alpha = self.alpha.as_ref().ok_or(KRRPredictError::NotFitted)?;
let x_train = self.x_train.as_ref().ok_or(KRRPredictError::NotFitted)?;
}

The actual prediction then proceeds by computing the kernel value between each training point and each test point:

#![allow(unused)]
fn main() {
for i in 0..n_test {
    for j in 0..n_train {
        let k_val = self.kernel.compute(x_train.row(j), x_test.row(i));
        y_pred[i] += alpha[j] * k_val;
    }
}
}

This implements the inference equation:

\[
\hat{y}(x) = \sum_{j=1}^{n} \alpha_j \, k(x_j, x),
\]

where \(x_j\) are the training samples, \(\alpha_j\) are the learned dual coefficients, and \(k\) is the kernel function.

KRRPredictError

The KRRPredictError is an enum used to indicate that the model has not been fitted yet. It is defined in the errors.rs module using the thiserror crate:

#![allow(unused)]
fn main() {
#[derive(Debug, Error)]
pub enum KRRPredictError {
    #[error("Model not fitted")]
    NotFitted,
}
}

This enum allows the predict function to return a Result type, making error propagation idiomatic and clean.

Use of ? operator

The function uses the ? operator to simplify error handling. For example:

#![allow(unused)]
fn main() {
let alpha = self.alpha.as_ref().ok_or(KRRPredictError::NotFitted)?;
}

This line either extracts the alpha reference if it exists or returns early with an error. This pattern keeps the code concise and expressive.
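
For readers new to the ? operator, the line above is conceptually equivalent to the following explicit match (shown as a fragment of predict, in the document's usual snippet style):

#![allow(unused)]
fn main() {
let alpha = match self.alpha.as_ref() {
    // Fitted: borrow the dual coefficients and continue.
    Some(alpha) => alpha,
    // Not fitted: return early with the error produced by ok_or.
    None => return Err(KRRPredictError::NotFitted),
};
}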

Unit tests

The test_unfitted_predict_error_type unit test checks that the correct error is returned when attempting to call predict before fitting the model:

#![allow(unused)]
fn main() {
#[test]
fn test_unfitted_predict_error_type() {
    use crate::errors::KRRPredictError;

    let kernel = RBFKernel::new(1.0);
    let model: KRRModel<RBFKernel> = KRRModel::new(kernel, 1.0);
    let x_test: Array2<f64> = array![[1.0, 2.0, 3.0]];

    let result = model.predict(&x_test);
    assert!(matches!(result, Err(KRRPredictError::NotFitted)));
}
}

Hyperparameter tuning with LOO-CV

This section focuses on hyperparameter selection for tuning the kernel lengthscale using Leave-One-Out Cross-Validation (LOO-CV). We implement two key functions in the model_selection.rs module:

  • loo_cv_error: computes the LOO-CV error for a given model and training dataset.
  • tune_lengthscale: evaluates multiple candidate lengthscales and returns the one with the lowest LOO-CV error.

Leave-One-Out Cross-Validation (LOO-CV)

LOO-CV is a classical cross-validation strategy in which each data point is left out once as a test sample while the model is trained on the remaining samples. The final error is the average squared difference between the predicted and true values.

In KRR, thanks to the closed-form solution, the LOO-CV error can be computed efficiently without retraining the model \(n\) times. The formula is:

\[
\mathrm{LOO}(\lambda) = \frac{1}{n} \sum_{i=1}^{n} \left( \frac{y_i - \hat{y}_i}{1 - H_{ii}} \right)^2,
\]

Where:

  • \(H = K (K + \lambda I_n)^{-1}\) is the hat matrix,
  • \(\hat{y} = H y\) is the prediction on the training data,
  • \(H_{ii}\) is the i-th diagonal element of the hat matrix.

The loo_cv_error function implements this logic:

#![allow(unused)]
fn main() {
pub fn loo_cv_error<K: Kernel>(model: &KRRModel<K>) -> Result<f64, KRRPredictError> {
    let alpha = model.alpha.as_ref().ok_or(KRRPredictError::NotFitted)?;
    let x_train = model.x_train.as_ref().ok_or(KRRPredictError::NotFitted)?;

    let n = x_train.nrows();
    let mut k_train = Array2::zeros((n, n));

    for i in 0..n {
        for j in 0..=i {
            let kxy = model.kernel.compute(x_train.row(i), x_train.row(j));
            k_train[(i, j)] = kxy;
            k_train[(j, i)] = kxy;
        }
    }

    let identity_n = Array2::eye(n);
    let a = k_train + model.lambda * identity_n;
    let a_inv = a.inv().expect("Inversion failed");

    let mut loo_error = 0.0;
    for i in 0..n {
        let ai = alpha[i];
        let di = a_inv[(i, i)];
        let res = ai / di;
        loo_error += res.powi(2);
    }

    Ok(loo_error / (n as f64))
}
}

It returns the mean squared error over the training set based on the LOO-CV formula.
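
To see why the loop above matches the LOO-CV formula, write \(A = K + \lambda I_n\) and recall that \(\alpha = A^{-1} y\) and \(\hat{y} = K \alpha\). Then

\[
y - \hat{y} = (K + \lambda I_n)\,\alpha - K\alpha = \lambda\,\alpha,
\qquad
I_n - H = I_n - K A^{-1} = (A - K)\,A^{-1} = \lambda\,A^{-1},
\]

so that

\[
\frac{y_i - \hat{y}_i}{1 - H_{ii}}
= \frac{\lambda\,\alpha_i}{\lambda\,[A^{-1}]_{ii}}
= \frac{\alpha_i}{[A^{-1}]_{ii}},
\]

which is exactly the quantity alpha[i] / a_inv[(i, i)] accumulated in the loop.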

Note that we got a bit lazy here:

  • We use the .expect() method, which panics if the inversion fails. This makes the code crash instead of returning an error, as we did with our KRRFitError and KRRPredictError enums.
  • We re-compute the Gram matrix, even though it could be cached inside the KRRModel, as we did for alpha and x_train.

Tuning the lengthscale

We now want to search for the optimal lengthscale of the RBF kernel that minimizes the LOO-CV error. The function:

#![allow(unused)]
fn main() {
pub fn tune_lengthscale<K: Kernel + Clone>(
    x_train: Array2<f64>,
    y_train: Array1<f64>,
    lambda: f64,
    lengthscales: &[f64],
    kernel_builder: impl Fn(f64) -> K,
) -> Result<(K, f64), String> {
    let mut best_error = f64::INFINITY;
    let mut best_kernel = None;

    for &l in lengthscales {
        let kernel = kernel_builder(l);
        let mut model = KRRModel::new(kernel.clone(), lambda);

        if model.fit(x_train.clone(), y_train.clone()).is_err() {
            continue;
        }

        if let Ok(err) = loo_cv_error(&model)
            && err < best_error
        {
            best_error = err;
            best_kernel = Some(kernel);
        }
    }

    best_kernel
        .map(|k| (k, best_error))
        .ok_or("Tuning failed".to_string())
}
}

takes in a list of candidate lengthscales, fits a model for each, and selects the one with the lowest LOO-CV error. Internally, it uses:

  1. The RBFKernel struct to instantiate kernels with varying lengthscales.
  2. The KRRModel::fit method to train each model.
  3. The loo_cv_error function to evaluate them.

Unit test

The test_tune_lengthscale test verifies that the tuning function works correctly:

#![allow(unused)]
fn main() {
#[test]
fn test_tune_lengthscale() {
    let x_train = array![[0.0], [1.0], [2.0]];
    let y_train = array![0.0, 1.0, 2.0];
    let candidates = vec![0.01, 0.1, 1.0, 10.0];

    // RBFKernel::new builds one kernel per candidate lengthscale.
    let (best_kernel, _best_error) =
        tune_lengthscale(x_train, y_train, 1.0, &candidates, RBFKernel::new)
            .expect("tuning should succeed");
    assert!(candidates.contains(&best_kernel.lengthscale));
}
}

This test confirms that the lengthscale of the selected kernel is indeed one of the candidate values.

Poisson 2D solver in Rust

This crate implements a simple 2D finite element solver for the Poisson equation:

\[
-\Delta u = f \quad \text{in } \Omega, \qquad u = g \quad \text{on } \partial\Omega,
\]

where:

  • \(\Omega\) is a 2D domain discretized into finite elements (triangles or quadrangles),
  • \(f\) is a given source function,
  • \(g\) is a Dirichlet boundary condition.

Weak formulation

To solve the PDE using the finite element method, we first write its weak form:

Find \(u\) such that:

\[
\int_\Omega \nabla u \cdot \nabla v \, dx = \int_\Omega f\, v \, dx \qquad \text{for all } v \in H_0^1(\Omega),
\]

where \(H_0^1(\Omega)\), the closure of \(C_c^\infty(\Omega)\) in the usual Sobolev space \(H^1(\Omega)\), is the space in which we seek our solution. We then discretize the problem using finite element basis functions, leading to a linear system:

\[
A u = b
\]

with:

  • A: global stiffness matrix (assembled from element contributions),
  • u: vector of nodal values,
  • b: load vector from the source term.

The following features are considered:

  • Support for different element types (P1, Q1)
  • Dense and sparse matrix assembly
  • Dirichlet boundary condition handling

We rely on nalgebra for linear algebra as it provides good support for dense and sparse matrices.

Code Structure

At the end of the chapter, we obtain a small standalone crate with the following layout:

├── Cargo.toml
└── src
    ├── element.rs
    ├── lib.rs
    ├── mesh.rs
    ├── quadrature.rs
    └── solver.rs

The crate is split into the following modules:

  • element.rs: Defines finite element types and related data structures (e.g., connectivity, local stiffness).

  • mesh.rs: Defines the Mesh2d structure, storing:

    • Vertex coordinates
    • Element connectivity
    • Element type

    Also provides accessors and utility methods for FEM assembly.

  • quadrature.rs: Implements quadrature (numerical integration) rules for computing element matrices.

  • solver.rs: Core numerical routines:

    • System assembly (dense & sparse versions)
    • Dirichlet boundary condition application
    • Linear system solver
  • lib.rs: Crate root where we re-export the main types and functions for easier use.

Example

use poisson_2d::{solve_poisson_2d, Mesh2d, SolverType};
use poisson_2d::element::{Element, ElementType};
use nalgebra::{Point2, DVector};

fn main() {
    // Build a tiny unit-square mesh with one Q1 quad (4 nodes, 1 element)
    let vertices = vec![
        Point2::new(0.0, 0.0), // 0
        Point2::new(1.0, 0.0), // 1
        Point2::new(1.0, 1.0), // 2
        Point2::new(0.0, 1.0), // 3
    ];
    // Only a single quad element
    let elements = vec![
        Element { indices: vec![0, 1, 2, 3] }
    ];
    let mesh = Mesh2d::new(vertices, elements, ElementType::Q1);

    // Define boundary nodes and functions
    let boundary_nodes = vec![0, 1];

    // Dirichlet boundary g(x, y) = 0 and source term f(x, y) = x + y.
    // Both are annotated as plain `fn` pointers so that they share the same type,
    // since `solve_poisson_2d` uses a single generic parameter `F` for both.
    let g: fn(f64, f64) -> f64 = |_, _| 0.0;
    let f: fn(f64, f64) -> f64 = |x, y| x + y;

    // Solve (choose Dense or Sparse)
    let u_dense: DVector<f64> =
        solve_poisson_2d(&mesh, &boundary_nodes, &g, &f, SolverType::Dense);

    let u_sparse: DVector<f64> =
        solve_poisson_2d(&mesh, &boundary_nodes, &g, &f, SolverType::Sparse);
}

User interface

In this chapter, we start the other way around and first have a look at the proposed user interface. The full file lib.rs is provided below. It mostly contains an enum and a helper function for solving the 2D Poisson problem.

Click here to view: lib.rs.
#![allow(unused)]
fn main() {
//! This file is part of the Poisson 2D crate.
//! It provides functionality for solving the 2D Poisson equation using finite element methods (FEM).
//!
//! The crate includes modules for elements, mesh, quadrature rules, and solvers.

pub mod element;
pub mod mesh;
pub mod quadrature;
pub mod solver;

pub use solver::{assemble_and_solve_dense, assemble_and_solve_sparse};

pub use mesh::Mesh2d;
pub use nalgebra::DVector;

/// Enum representing the type of solver to use
// ANCHOR: solver_type
pub enum SolverType {
    Dense,
    Sparse,
}
// ANCHOR_END: solver_type

/// Helper function for solving the 2D Poisson problem
///
/// This function takes a mesh, boundary nodes, boundary function, source function, and solver type.
/// Arguments:
/// - `mesh`: The mesh representing the domain.
/// - `boundary_nodes`: Indices of the nodes on the boundary.
/// - `boundary_fn`: Function defining the boundary condition.
/// - `source_fn`: Function defining the source term.
/// - `solver_type`: Type of solver to use (Dense or Sparse).
///
/// Returns:
/// - A vector containing the solution at the mesh nodes.
// ANCHOR: solve_poisson_2d
pub fn solve_poisson_2d<F>(
    mesh: &Mesh2d,
    boundary_nodes: &[usize],
    boundary_fn: &F,
    source_fn: &F,
    solver_type: SolverType,
) -> DVector<f64>
where
    F: Fn(f64, f64) -> f64,
{
    match solver_type {
        SolverType::Dense => assemble_and_solve_dense(mesh, boundary_nodes, boundary_fn, source_fn),
        SolverType::Sparse => {
            assemble_and_solve_sparse(mesh, boundary_nodes, boundary_fn, source_fn)
        }
    }
}
// ANCHOR_END: solve_poisson_2d
}

Solver type

In this example, we consider two types of finite element solvers:

  • A dense solver which assembles dense matrices and uses a dense solver for the linear system.
  • A sparse solver which assembles sparse matrices and uses a sparse solver for the linear system.

This is encoded by defining the following enum:

#![allow(unused)]
fn main() {
pub enum SolverType {
    Dense,
    Sparse,
}
}

It can be used to pick the solver type by passing SolverType::Dense or SolverType::Sparse to the helper function discussed below.

Helper function

The helper function that the user should call is shown below. It takes the following input arguments:

  • A mesh, given as a Mesh2d struct implemented in the mesh.rs module.
  • The boundary nodes and the boundary function \(g\) for applying Dirichlet boundary conditions.
  • The source function \(f\).
  • And the solver type (SolverType::Dense or SolverType::Sparse).

Based on the chosen solver type, the function dispatches to either the dense or the sparse routine through pattern matching.

#![allow(unused)]
fn main() {
pub fn solve_poisson_2d<F>(
    mesh: &Mesh2d,
    boundary_nodes: &[usize],
    boundary_fn: &F,
    source_fn: &F,
    solver_type: SolverType,
) -> DVector<f64>
where
    F: Fn(f64, f64) -> f64,
{
    match solver_type {
        SolverType::Dense => assemble_and_solve_dense(mesh, boundary_nodes, boundary_fn, source_fn),
        SolverType::Sparse => {
            assemble_and_solve_sparse(mesh, boundary_nodes, boundary_fn, source_fn)
        }
    }
}
}

Mesh module

The mesh module defines the mesh data structure used throughout the Poisson 2D solver.
It encapsulates the geometrical discretization of the domain, i.e., the set of vertices (points in space) and the list of finite elements that connect those vertices.
Any details related to the definition of an element, such as connectivity, shape functions, or element types, are implemented in element.rs and simply referenced here.

Mesh struct

The main struct in this module is Mesh2d. It holds:

  • vertices: the coordinates of all mesh nodes, stored as a Vec<Point2<f64>>: a Vec (from the std lib) of Point2<f64> points (from nalgebra),
  • elements: the list of finite elements as Vec<Element>,
  • element_type: an ElementType enum indicating the type of all elements in the mesh (e.g., P1, Q1).

The Element struct and ElementType enum are defined in the element.rs module.

#![allow(unused)]
fn main() {
#[derive(Clone, Debug)]
pub struct Mesh2d {
    vertices: Vec<Point2<f64>>,
    elements: Vec<Element>,
    element_type: ElementType,
}
}

This struct is marked with #[derive(Clone, Debug)] to allow duplication and debug printing, which are useful for testing and inspecting the mesh.

Mesh implementations

The implementation block provides:

  • new(...): a constructor that takes ownership of the vertex list, element list, and element type.
  • vertices(&self): returns an immutable slice of the mesh's vertices.
  • elements(&self): returns an immutable slice of the mesh's elements.
  • element_type(&self): returns a reference to the mesh's ElementType.
#![allow(unused)]
fn main() {
impl Mesh2d {
    pub fn new(
        vertices: Vec<Point2<f64>>,
        elements: Vec<Element>,
        element_type: ElementType,
    ) -> Self {
        Self {
            vertices,
            elements,
            element_type,
        }
    }
    pub fn vertices(&self) -> &[Point2<f64>] {
        &self.vertices
    }

    pub fn elements(&self) -> &[Element] {
        &self.elements
    }

    pub fn element_type(&self) -> &ElementType {
        &self.element_type
    }
}
}

These accessor methods are intentionally read-only, ensuring the internal structure of the mesh cannot be mutated from outside without explicit intent.
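
As a small illustration (assuming the crate is used as a dependency, with Mesh2d re-exported at the crate root and Element and ElementType imported from poisson_2d::element), here is how a two-triangle mesh of the unit square could be built and queried:

use nalgebra::Point2;
use poisson_2d::element::{Element, ElementType};
use poisson_2d::Mesh2d;

fn main() {
    // The unit square split into two P1 triangles.
    let vertices = vec![
        Point2::new(0.0, 0.0),
        Point2::new(1.0, 0.0),
        Point2::new(1.0, 1.0),
        Point2::new(0.0, 1.0),
    ];
    let elements = vec![
        Element { indices: vec![0, 1, 2] },
        Element { indices: vec![0, 2, 3] },
    ];
    let mesh = Mesh2d::new(vertices, elements, ElementType::P1);

    // Read-only accessors.
    assert_eq!(mesh.vertices().len(), 4);
    assert_eq!(mesh.elements().len(), 2);
    assert_eq!(*mesh.element_type(), ElementType::P1);
}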

A simple unit test

The module includes a basic unit test to verify that:

  • The mesh stores the correct number of vertices and elements,
  • The element_type is stored and accessible correctly.

The test builds a simple unit square mesh with four vertices and one Q1 element, then asserts the expected sizes and type.

#![allow(unused)]
fn main() {
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mesh2d() {
        let vertices = vec![
            Point2::new(0.0, 0.0),
            Point2::new(1.0, 0.0),
            Point2::new(1.0, 1.0),
            Point2::new(0.0, 1.0),
        ];
        let elements = vec![Element {
            indices: vec![0, 1, 2, 3],
        }];
        let mesh = Mesh2d {
            vertices,
            elements,
            element_type: ElementType::Q1,
        };

        assert_eq!(mesh.vertices().len(), 4);
        assert_eq!(mesh.elements().len(), 1);
        assert_eq!(*mesh.element_type(), ElementType::Q1);
    }
}
}

Element module

The element module defines the finite element types and the associated reference elements used by the solver. It is responsible for everything specific to elements: types, connectivity, shape functions, gradients, and the Jacobian needed to map between reference and physical coordinates.

This module is consumed by higher-level parts of the crate (mesh, quadrature, solver).

Element struct and types

The crate currently supports two classical 2D elements:

  • P1: 3-node linear triangle (Tri3),
  • Q1: 4-node bilinear quadrilateral (Quad4).

ElementType encodes which type is used, and Element stores the connectivity via global node indices.

#![allow(unused)]
fn main() {
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum ElementType {
    /// 3-node triangle
    P1,
    /// 4-node quadrangle
    Q1,
}

/// An element stores a vector containing its global indices.
#[derive(Clone, Debug)]
pub struct Element {
    pub indices: Vec<usize>,
}
}

Notes

  • ElementType derives Clone, PartialEq, Eq, and Debug, which makes pattern matching and testing straightforward.
  • Element is a tiny container that holds the indices of the mesh vertices forming each element: geometry (actual coordinates) lives in the mesh.

Reference element enum

The reference element encodes the canonical (parameter-space) version of each element type:

  • Tri3: the unit reference triangle with vertices \((0,0)\), \((1,0)\), \((0,1)\),
  • Quad4: the reference square \([-1, 1]^2\).

The method num_nodes() returns the number of nodes for each reference element.

#![allow(unused)]
fn main() {
#[derive(Debug, Clone)]
pub enum ReferenceElement {
    /// 3-node reference triangle
    Tri3,
    /// 4-node reference quadrangle
    Quad4,
}

impl ReferenceElement {
    pub fn num_nodes(&self) -> usize {
        match self {
            ReferenceElement::Tri3 => 3,
            ReferenceElement::Quad4 => 4,
        }
    }
}
}

This abstraction separates element formulas (defined on reference coordinates) from the actual geometry (physical coordinates from the mesh).

Reference element implementations

This block implements the core finite element kinematics on the reference element:

  • shape_functions(&Point2<f64>) -> Vec<f64>
    Returns the values of the shape functions \(N_i\) at given local coordinates \((\xi, \eta)\).
    These are used to interpolate fields (e.g., \(u_h\)) inside an element:
    \[ u_h(\xi, \eta) = \sum_i N_i(\xi, \eta)\, u_i. \]

  • shape_gradients(&Point2<f64>) -> Vec<Vector2<f64>>
    Returns the gradients on the reference element, i.e., \(\hat{\nabla} N_i = \left( \partial N_i / \partial \xi,\ \partial N_i / \partial \eta \right)\).
    They are combined with the inverse Jacobian to obtain physical gradients during assembly.

  • jacobian(vertices_coordinates, local_coordinates) -> Matrix2<f64>
    Computes the Jacobian of the mapping from reference to physical coordinates.
    For Tri3 this reduces to a constant matrix built from vertex differences; for Quad4 it is obtained by summing contributions of the shape function gradients weighted by vertex coordinates.

#![allow(unused)]
fn main() {
impl ReferenceElement {
    pub fn shape_functions(&self, local_coordinates: &Point2<f64>) -> Vec<f64> {
        match self {
            ReferenceElement::Tri3 => {
                let xi = local_coordinates.x;
                let eta = local_coordinates.y;
                vec![1.0 - xi - eta, xi, eta]
            }
            ReferenceElement::Quad4 => {
                let xi = local_coordinates.x;
                let eta = local_coordinates.y;
                let n1 = 0.25 * (1.0 - xi) * (1.0 - eta);
                let n2 = 0.25 * (1.0 + xi) * (1.0 - eta);
                let n3 = 0.25 * (1.0 + xi) * (1.0 + eta);
                let n4 = 0.25 * (1.0 - xi) * (1.0 + eta);
                vec![n1, n2, n3, n4]
            }
        }
    }

    pub fn shape_gradients(&self, local_coordinates: &Point2<f64>) -> Vec<Vector2<f64>> {
        match self {
            ReferenceElement::Tri3 => {
                // Constant gradients of N1 = 1 - xi - eta, N2 = xi, N3 = eta.
                vec![
                    Vector2::new(-1.0, -1.0),
                    Vector2::new(1.0, 0.0),
                    Vector2::new(0.0, 1.0),
                ]
            }
            ReferenceElement::Quad4 => {
                let xi = local_coordinates.x;
                let eta = local_coordinates.y;
                let dn1_dxi = -0.25 * (1.0 - eta);
                let dn1_deta = -0.25 * (1.0 - xi);
                let dn2_dxi = 0.25 * (1.0 - eta);
                let dn2_deta = -0.25 * (1.0 + xi);
                let dn3_dxi = 0.25 * (1.0 + eta);
                let dn3_deta = 0.25 * (1.0 + xi);
                let dn4_dxi = -0.25 * (1.0 + eta);
                let dn4_deta = 0.25 * (1.0 - xi);
                vec![
                    Vector2::new(dn1_dxi, dn1_deta),
                    Vector2::new(dn2_dxi, dn2_deta),
                    Vector2::new(dn3_dxi, dn3_deta),
                    Vector2::new(dn4_dxi, dn4_deta),
                ]
            }
        }
    }

    pub fn jacobian(
        &self,
        vertices_coordinates: &[Point2<f64>],
        local_coordinates: &Point2<f64>,
    ) -> Matrix2<f64> {
        match self {
            ReferenceElement::Tri3 => {
                let v0 = vertices_coordinates[0];
                let v1 = vertices_coordinates[1];
                let v2 = vertices_coordinates[2];
                let dx_dxi = v1.x - v0.x;
                let dy_dxi = v1.y - v0.y;
                let dx_deta = v2.x - v0.x;
                let dy_deta = v2.y - v0.y;
                Matrix2::new(dx_dxi, dx_deta, dy_dxi, dy_deta)
            }
            ReferenceElement::Quad4 => {
                let grads = self.shape_gradients(local_coordinates);
                let mut jac = Matrix2::zeros();
                for (grad, vertex) in grads.iter().zip(vertices_coordinates.iter()) {
                    jac[(0, 0)] += grad.x * vertex.x;
                    jac[(1, 0)] += grad.x * vertex.y;
                    jac[(0, 1)] += grad.y * vertex.x;
                    jac[(1, 1)] += grad.y * vertex.y;
                }
                jac
            }
        }
    }
}
}

How it fits into assembly (a short sketch follows the list):

  1. Evaluate shape functions (and gradients) at quadrature points in reference space.
  2. Build the Jacobian \(J\) and its determinant \(\det J\).
  3. Map reference gradients to physical gradients via \(\nabla_x N_i = J^{-\top} \hat{\nabla} N_i\).
  4. Accumulate local stiffness and load contributions.
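
Here is a minimal sketch of steps 1–3 for a single P1 triangle, assuming the ReferenceElement type above is in scope (e.g., via use poisson_2d::element::ReferenceElement):

use nalgebra::{Point2, Vector2};
use poisson_2d::element::ReferenceElement;

fn main() {
    // A physical triangle with vertices (0,0), (2,0), (0,1).
    let nodes = [
        Point2::new(0.0, 0.0),
        Point2::new(2.0, 0.0),
        Point2::new(0.0, 1.0),
    ];
    let tri = ReferenceElement::Tri3;
    let q = Point2::new(1.0 / 3.0, 1.0 / 3.0); // centroid quadrature point

    // 1. Shape function gradients on the reference element.
    let grads_ref = tri.shape_gradients(&q);
    // 2. Jacobian of the reference-to-physical map and its determinant.
    let jac = tri.jacobian(&nodes, &q);
    let det_jac = jac.determinant();
    // 3. Physical gradients: grad_x N_i = J^{-T} * grad_xi N_i.
    let jac_inv_t = jac.try_inverse().expect("degenerate element").transpose();
    let grads_phys: Vec<Vector2<f64>> = grads_ref.iter().map(|g| jac_inv_t * *g).collect();

    assert!(det_jac > 0.0);
    assert_eq!(grads_phys.len(), 3);
}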

Simple unit test

This smoke test checks that:

  • the number of nodes reported by each reference element is correct,
  • the shape function vectors have matching sizes.
#![allow(unused)]
fn main() {
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reference_element() {
        let tri3 = ReferenceElement::Tri3;
        let quad4 = ReferenceElement::Quad4;

        assert_eq!(tri3.num_nodes(), 3);
        assert_eq!(quad4.num_nodes(), 4);

        let local_coords = Point2::new(0.5, 0.5);
        let tri_shape_funcs = tri3.shape_functions(&local_coords);
        let quad_shape_funcs = quad4.shape_functions(&local_coords);

        assert_eq!(tri_shape_funcs.len(), 3);
        assert_eq!(quad_shape_funcs.len(), 4);
    }
}
}

Quadrature module

The quadrature module defines how we perform numerical integration over finite elements.
In the FEM assembly, element integrals like \(\int_{E} \nabla \varphi_i \cdot \nabla \varphi_j \, dx\) and \(\int_{E} f\, \varphi_i \, dx\) over an element \(E\) are approximated with quadrature rules (a set of points and weights) on the reference element. The mapping to physical space is handled elsewhere via the Jacobian.

Quadrature struct

A quadrature rule is defined by:

  • points: the evaluation points in reference coordinates, and
  • weights: the corresponding weights.
#![allow(unused)]
fn main() {
#[derive(Clone, Debug)]
pub struct QuadRule {
    pub points: Vec<Point2<f64>>,
    pub weights: Vec<f64>,
}
}

This lightweight container is used by the solver during element-level integration. For correctness, points.len() must equal weights.len().

Quadrature implementations

Two families of rules are provided:

  • triangle(order): simple rules on the reference triangle.

    • order = 1: 1-point rule (centroid) with total weight \(1/2\), which matches the reference triangle area.
    • order = 2: 3-point rule, exact for polynomials up to degree two.
  • quadrilateral(n): tensor-product Gauss rules on the reference square \([-1, 1]^2\).

    • n = 1: 1-point Gauss rule at the center with weight \(4\).
    • n = 2: \(2 \times 2\) Gauss rule; points at \((\pm 1/\sqrt{3}, \pm 1/\sqrt{3})\), each with weight \(1\).
#![allow(unused)]
fn main() {
impl QuadRule {
    pub fn triangle(order: usize) -> Self {
        match order {
            1 => QuadRule {
                points: vec![Point2::new(1.0 / 3.0, 1.0 / 3.0)],
                weights: vec![0.5],
            },
            2 => QuadRule {
                points: vec![
                    Point2::new(1.0 / 6.0, 1.0 / 6.0),
                    Point2::new(2.0 / 3.0, 1.0 / 6.0),
                    Point2::new(1.0 / 6.0, 2.0 / 3.0),
                ],
                weights: vec![1.0 / 6.0; 3],
            },
            _ => panic!("triangle quadratule of order > 2 not implemented"),
        }
    }

    pub fn quadrilateral(n: usize) -> Self {
        match n {
            1 => QuadRule {
                points: vec![Point2::new(0.0, 0.0)],
                weights: vec![4.0],
            },
            2 => {
                let a = 1.0 / 3.0f64.sqrt();
                let pts = [-a, a];
                let mut points = Vec::with_capacity(4);
                let mut weights = Vec::with_capacity(4);
                for xi in pts {
                    for eta in pts {
                        points.push(Point2::new(xi, eta));
                        weights.push(1.0);
                    }
                }
                QuadRule { points, weights }
            }
            _ => panic!("quadrilateral quadrature with n > 2 points not implemented"),
        }
    }
}
}

Notes

  • For triangles, the implementation panics for order > 2; similarly, quadrilaterals panic for n > 2. Extend these rules if you need higher-order elements or more accurate integration.
  • The quadrilateral rule builds the 2D grid from the 1D Gauss abscissae when n = 2.
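
A quick sanity check of these rules (assuming QuadRule as defined above, accessible as poisson_2d::quadrature::QuadRule) is that the weights sum to the area of the reference element:

use poisson_2d::quadrature::QuadRule;

fn main() {
    // 1/2 for the reference triangle, 4 for the reference square [-1, 1]^2.
    let tri = QuadRule::triangle(2);
    let quad = QuadRule::quadrilateral(2);

    let tri_total: f64 = tri.weights.iter().sum();
    let quad_total: f64 = quad.weights.iter().sum();

    assert!((tri_total - 0.5).abs() < 1e-12);
    assert!((quad_total - 4.0).abs() < 1e-12);
}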

Simple tests

The tests check that the rule sizes match expectations for the provided configurations.

#![allow(unused)]
fn main() {
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_triangle_quadrature() {
        let rule = QuadRule::triangle(2);
        assert_eq!(rule.points.len(), 3);
        assert_eq!(rule.weights.len(), 3);
    }

    #[test]
    fn test_quadrilateral_quadrature() {
        let rule = QuadRule::quadrilateral(2);
        assert_eq!(rule.points.len(), 4);
        assert_eq!(rule.weights.len(), 4);
    }
}
}

Solver module

This module contains all the core logic required to assemble and solve the finite element system for the 2D Poisson equation.
It supports both dense and sparse matrix formulations and includes routines for:

  • Assembling the stiffness matrix and load vector from a given mesh and source function.
  • Applying Dirichlet boundary conditions to either dense or sparse systems.
  • Solving the resulting linear system using either a Cholesky decomposition (dense) or the Conjugate Gradient method (sparse).
  • High-level “assemble-and-solve” functions for convenience.

We proceed in the same order as the implementation, explaining each part in detail.

Mathematical background

We consider the Poisson problem in its weak form: find \(u \in V_g\) such that

\[
\int_\Omega \nabla u \cdot \nabla v \, dx = \int_\Omega f\, v \, dx \qquad \text{for all } v \in V_0,
\]

where:

  • \(\Omega\) is the computational domain.
  • \(V_g\) is the trial space satisfying the Dirichlet boundary conditions.
  • \(V_0\) is the test space with homogeneous Dirichlet conditions.
  • \(f\) is the source term.

In the finite element method, the solution is expanded in terms of basis functions \(\varphi_j\), \(u_h = \sum_j u_j \varphi_j\), and we obtain the linear system:

\[
A u = b, \qquad A_{ij} = \int_\Omega \nabla \varphi_i \cdot \nabla \varphi_j \, dx, \qquad b_i = \int_\Omega f\, \varphi_i \, dx.
\]

The stiffness matrix \(A\) and load vector \(b\) are computed by assembling contributions from each element in the mesh. Numerical integration is carried out using quadrature rules on a reference element, with transformations to the physical element through the Jacobian.

Dense system assembly

The first function, assemble_system_dense, constructs the stiffness matrix and right-hand side vector for the given mesh and source term using a dense matrix representation.
It:

  1. Selects the appropriate reference element (Tri3 or Quad4) based on the mesh element type.
  2. Chooses a second-order quadrature rule to integrate element-level matrices.
  3. Loops over each element, computes local stiffness ke and local load fe, and assembles them into the global system.
#![allow(unused)]
fn main() {
pub fn assemble_system_dense<F>(mesh: &Mesh2d, source_fn: &F) -> (DMatrix<f64>, DVector<f64>)
where
    F: Fn(f64, f64) -> f64,
{
    let num_vertices = mesh.vertices().len();
    let mut a = DMatrix::zeros(num_vertices, num_vertices);
    let mut b = DVector::zeros(num_vertices);

    // Pick the right reference element based on the element type in the mesh.
    let ref_element = match mesh.element_type() {
        ElementType::P1 => ReferenceElement::Tri3,
        ElementType::Q1 => ReferenceElement::Quad4,
    };

    // Pick the right quadrature rule based on the element type in the mesh.
    // We use second-order quadrature rules by default.
    let quad_rule = match mesh.element_type() {
        ElementType::P1 => QuadRule::triangle(2),
        ElementType::Q1 => QuadRule::quadrilateral(2),
    };

    let n: usize = ref_element.num_nodes();
    for element in mesh.elements() {
        // Get the coordinates of the element nodes
        let mut nodes: Vec<Point2<f64>> = Vec::with_capacity(n);
        for vid in &element.indices {
            let vertex = mesh.vertices()[*vid];
            nodes.push(vertex);
        }

        // Compute the local stiff and load vectors
        let mut ke = vec![vec![0.0; n]; n];
        let mut fe = vec![0.0; n];
        for (quad_points, quad_weights) in quad_rule.points.iter().zip(quad_rule.weights.iter()) {
            // Compute local quantities in the reference element
            let grads_ref = ref_element.shape_gradients(quad_points);
            let jac_ref = ref_element.jacobian(&nodes, quad_points);
            let det_jac_ref = jac_ref.determinant();
            let jac_inv_t = jac_ref.try_inverse().unwrap().transpose();

            // Compute gradient in the physical space
            let mut grads_global: Vec<Vector2<f64>> = Vec::with_capacity(n);
            for grad_ref in grads_ref {
                let grad = jac_inv_t * grad_ref;
                grads_global.push(grad);
            }

            // Evaluate physical coordinates of quadrature point
            let shape_vals = ref_element.shape_functions(quad_points);
            let mut x = 0.0;
            let mut y = 0.0;
            for (val, vtx) in shape_vals.iter().zip(&nodes) {
                x += val * vtx.x;
                y += val * vtx.y;
            }

            // Fill ke and fe
            let f_val = source_fn(x, y);
            let weight = quad_weights * det_jac_ref.abs();
            for i in 0..n {
                for j in 0..n {
                    ke[i][j] += grads_global[i].dot(&grads_global[j]) * weight;
                }
                fe[i] += shape_vals[i] * f_val * weight;
            }
        }

        // Assemble into global matrix/vector
        for (local_i, &global_i) in element.indices.iter().enumerate() {
            b[global_i] += fe[local_i];
            for (local_j, &global_j) in element.indices.iter().enumerate() {
                a[(global_i, global_j)] += ke[local_i][local_j];
            }
        }
    }

    (a, b)
}
}

Sparse system assembly

The assemble_system_sparse function is analogous to the dense version, but uses a COO format during assembly and converts to CSR format for efficiency.
This is better suited for large problems where most of the global matrix entries are zero.

#![allow(unused)]
fn main() {
pub fn assemble_system_sparse<F>(mesh: &Mesh2d, source_fn: &F) -> (CsrMatrix<f64>, DVector<f64>)
where
    F: Fn(f64, f64) -> f64,
{
    let num_vertices = mesh.vertices().len();
    let mut coo = CooMatrix::new(num_vertices, num_vertices);
    let mut b = DVector::zeros(num_vertices);

    // Pick the right reference element based on the element type in the mesh.
    let ref_element = match mesh.element_type() {
        ElementType::P1 => ReferenceElement::Tri3,
        ElementType::Q1 => ReferenceElement::Quad4,
    };

    // Pick the right quadrature rule based on the element type in the mesh.
    // We use second-order quadrature rules by default.
    let quad_rule = match mesh.element_type() {
        ElementType::P1 => QuadRule::triangle(2),
        ElementType::Q1 => QuadRule::quadrilateral(2),
    };

    let n: usize = ref_element.num_nodes();
    for element in mesh.elements() {
        // Get the coordinates of the element nodes
        let mut nodes: Vec<Point2<f64>> = Vec::with_capacity(n);
        for vid in &element.indices {
            let vertex = mesh.vertices()[*vid];
            nodes.push(vertex);
        }

        // Compute the local stiff and load vectors
        let mut ke = vec![vec![0.0; n]; n];
        let mut fe = vec![0.0; n];
        for (quad_points, quad_weights) in quad_rule.points.iter().zip(quad_rule.weights.iter()) {
            // Compute local quantities in the reference element
            let grads_ref = ref_element.shape_gradients(quad_points);
            let jac_ref = ref_element.jacobian(&nodes, quad_points);
            let det_jac_ref = jac_ref.determinant();
            let jac_inv_t = jac_ref.try_inverse().unwrap().transpose();

            // Compute gradient in the physical space
            let mut grads_global: Vec<Vector2<f64>> = Vec::with_capacity(n);
            for grad_ref in grads_ref {
                let grad = jac_inv_t * grad_ref;
                grads_global.push(grad);
            }

            // Evaluate physical coordinates of quadrature point
            let shape_vals = ref_element.shape_functions(quad_points);
            let mut x = 0.0;
            let mut y = 0.0;
            for (val, vtx) in shape_vals.iter().zip(&nodes) {
                x += val * vtx.x;
                y += val * vtx.y;
            }

            // Fill ke and fe
            let f_val = source_fn(x, y);
            let weight = quad_weights * det_jac_ref.abs();
            for i in 0..n {
                for j in 0..n {
                    ke[i][j] += grads_global[i].dot(&grads_global[j]) * weight;
                }
                fe[i] += shape_vals[i] * f_val * weight;
            }
        }

        for (local_i, &global_i) in element.indices.iter().enumerate() {
            b[global_i] += fe[local_i];
            for (local_j, &global_j) in element.indices.iter().enumerate() {
                let stiffness: f64 = ke[local_i][local_j];
                coo.push(global_i, global_j, stiffness);
            }
        }
    }

    let a = CsrMatrix::from(&coo);
    (a, b)
}
}

Applying Dirichlet boundary conditions (dense)

Dirichlet boundary conditions are enforced by:

  1. Modifying the right-hand side vector to account for known values at boundary nodes.
  2. Zeroing out the corresponding rows and columns in the stiffness matrix.
  3. Setting diagonal entries to 1 and RHS entries to the Dirichlet values.
#![allow(unused)]
fn main() {
pub fn apply_dirichlet_dense<G>(
    a: &mut DMatrix<f64>,
    b: &mut DVector<f64>,
    boundary_nodes: &[usize],
    mesh: &Mesh2d,
    g: G,
) where
    G: Fn(f64, f64) -> f64,
{
    // Compute the boundary conditions values at each boundary node
    let mut values = Vec::with_capacity(boundary_nodes.len());
    for &i in boundary_nodes {
        let v = &mesh.vertices()[i];
        values.push((i, g(v.x, v.y)));
    }
    let n = a.nrows();

    // For each boundary node j, update rhs: b_i -= a_ij * g_j for all i
    for &(j, g_j) in &values {
        for i in 0..n {
            b[i] -= a[(i, j)] * g_j;
        }
    }

    // Zero out rows and columns and set diagonal
    for &(j, g_j) in &values {
        for k in 0..n {
            a[(j, k)] = 0.0;
            a[(k, j)] = 0.0;
        }
        a[(j, j)] = 1.0;
        b[j] = g_j;
    }
}
}

Applying Dirichlet boundary conditions (sparse)

The sparse version performs similar operations, but with care to work directly with the CSR matrix structure.
We iterate over rows and selectively zero out entries, preserving the sparse layout.

#![allow(unused)]
fn main() {
pub fn apply_dirichlet_sparse<G>(
    a: &mut CsrMatrix<f64>,
    b: &mut DVector<f64>,
    boundary_nodes: &[usize],
    mesh: &Mesh2d,
    g: G,
) where
    G: Fn(f64, f64) -> f64,
{
    // Compute the boundary conditions values at each boundary node
    let mut bc_vals = Vec::with_capacity(boundary_nodes.len());
    for &j in boundary_nodes {
        let v = &mesh.vertices()[j];
        bc_vals.push((j, g(v.x, v.y)));
    }

    let n = a.nrows();

    for &(j, g_j) in &bc_vals {
        // For each boundary node j, update rhs: b_i -= a_ij * g_j for all i

        for i in 0..n {
            let row = a.row(i);
            let cols = row.col_indices();
            let vals = row.values();
            if let Some(pos) = cols.iter().position(|&c| c == j) {
                b[i] -= vals[pos] * g_j;
            }
        }

        // Zero out row j
        for v in a.row_mut(j).values_mut() {
            *v = 0.0;
        }

        // Zero out column j to preserve symmetry
        // We first collect the positions
        let mut to_zero: Vec<(usize, usize)> = Vec::new();
        for i in 0..n {
            let row_i = a.row(i);
            let cols = row_i.col_indices();
            if let Some(pos) = cols.iter().position(|&c| c == j) {
                to_zero.push((i, pos));
            }
        }
        // Zero out the collected positions
        for (i, pos) in to_zero {
            a.row_mut(i).values_mut()[pos] = 0.0;
        }

        // Set diagonal to 1.0
        if let Some(pos) = a.row(j).col_indices().iter().position(|&c| c == j) {
            a.row_mut(j).values_mut()[pos] = 1.0;
        }

        b[j] = g_j;
    }
}
}

Dense solver

The dense solver uses a Cholesky factorization of the symmetric positive-definite stiffness matrix to compute the solution efficiently.

#![allow(unused)]
fn main() {
pub fn dense_solver(a: &DMatrix<f64>, b: &DVector<f64>) -> Option<DVector<f64>> {
    let chol = a.clone().cholesky()?;
    Some(chol.solve(b))
}
}

Sparse solver

The sparse solver uses an iterative Conjugate Gradient (CG) method to solve the system, which is memory-efficient and scales better for large meshes.

#![allow(unused)]
fn main() {
pub fn sparse_solver(a: &CsrMatrix<f64>, b: &DVector<f64>) -> Option<DVector<f64>> {
    conjugate_gradient::solve(a, b, 1000, 1e-10)
}
}
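
The conjugate_gradient module itself is not shown in this chapter. As a rough idea only, here is a minimal, unpreconditioned CG sketch with the same signature as the call above; the csr_matvec helper is introduced purely for illustration and is not part of the crate's API:

use nalgebra::DVector;
use nalgebra_sparse::CsrMatrix;

/// Sparse matrix-vector product written against the public CSR triplet iterator.
fn csr_matvec(a: &CsrMatrix<f64>, x: &DVector<f64>) -> DVector<f64> {
    let mut y = DVector::zeros(a.nrows());
    for (i, j, v) in a.triplet_iter() {
        y[i] += *v * x[j];
    }
    y
}

/// Unpreconditioned conjugate gradient for a symmetric positive-definite system.
/// Returns None if the residual norm does not drop below `tol` within `max_iter` iterations.
pub fn solve(
    a: &CsrMatrix<f64>,
    b: &DVector<f64>,
    max_iter: usize,
    tol: f64,
) -> Option<DVector<f64>> {
    let mut x = DVector::zeros(b.len());
    let mut r = b - csr_matvec(a, &x);
    let mut p = r.clone();
    let mut rs_old = r.dot(&r);

    for _ in 0..max_iter {
        if rs_old.sqrt() < tol {
            return Some(x);
        }
        let ap = csr_matvec(a, &p);
        let alpha = rs_old / p.dot(&ap);
        x += &p * alpha;
        r -= &ap * alpha;
        let rs_new = r.dot(&r);
        p = &r + &p * (rs_new / rs_old);
        rs_old = rs_new;
    }
    // Did not converge within max_iter iterations.
    None
}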

High-level assemble-and-solve (dense)

This function combines the assembly, boundary condition application, and solve phases into a single call for dense systems.

#![allow(unused)]
fn main() {
pub fn assemble_and_solve_dense<F>(
    mesh: &Mesh2d,
    boundary_nodes: &[usize],
    boundary_fn: F,
    source_fn: F,
) -> DVector<f64>
where
    F: Fn(f64, f64) -> f64,
{
    // Assemble dense system
    let (mut a, mut b) = assemble_system_dense(mesh, &source_fn);

    // Apply BCs
    apply_dirichlet_dense(&mut a, &mut b, boundary_nodes, mesh, boundary_fn);

    // Solve linear system
    dense_solver(&a, &b).expect("failed to solve")
}
}

High-level assemble-and-solve (sparse)

Similarly, this high-level function handles all the steps for sparse systems in one call.

#![allow(unused)]
fn main() {
pub fn assemble_and_solve_sparse<F>(
    mesh: &Mesh2d,
    boundary_nodes: &[usize],
    boundary_fn: F,
    source_fn: F,
) -> DVector<f64>
where
    F: Fn(f64, f64) -> f64,
{
    // Assemble sparse system
    let (mut a, mut b) = assemble_system_sparse(mesh, &source_fn);

    // Apply BCs
    apply_dirichlet_sparse(&mut a, &mut b, boundary_nodes, mesh, boundary_fn);

    // Solve linear system
    sparse_solver(&a, &b).expect("failed to solve")
}
}

Unit tests

The tests check that both the dense and sparse assembly functions produce systems of the expected size for a simple single-element square mesh with four nodes.

#![allow(unused)]
fn main() {
#[cfg(test)]
mod tests {
    use super::*;
    use crate::element::Element;

    #[test]
    fn test_assemble_system_dense() {
        let vertices = vec![
            Point2::new(0.0, 0.0),
            Point2::new(1.0, 0.0),
            Point2::new(1.0, 1.0),
            Point2::new(0.0, 1.0),
        ];
        let elements = vec![Element {
            indices: vec![0, 1, 2, 3],
        }];
        let mesh = Mesh2d::new(vertices, elements, ElementType::Q1);

        let source_fn = |x: f64, y: f64| x + y;
        let (a, b) = assemble_system_dense(&mesh, &source_fn);

        assert_eq!(a.nrows(), 4);
        assert_eq!(b.len(), 4);
    }

    #[test]
    fn test_assemble_system_sparse() {
        let vertices = vec![
            Point2::new(0.0, 0.0),
            Point2::new(1.0, 0.0),
            Point2::new(1.0, 1.0),
            Point2::new(0.0, 1.0),
        ];
        let elements = vec![Element {
            indices: vec![0, 1, 2, 3],
        }];
        let mesh = Mesh2d::new(vertices, elements, ElementType::Q1);

        let source_fn = |x: f64, y: f64| x + y;
        let (a, b) = assemble_system_sparse(&mesh, &source_fn);

        assert_eq!(a.nrows(), 4);
        assert_eq!(b.len(), 4);
    }
}
}