Closed-form solution with ndarray, Option, and error handling

This section introduces several important features of the language:

  • Using the ndarray crate for numerical arrays
  • Representing optional values with Option, Some, and None
  • Using pattern matching with match
  • Handling errors using Box<dyn Error> and .into()
  • Automatically deriving common trait implementations using #[derive(...)]

Motivation

In previous sections, we worked with Vec<f64> and returned plain values. In practice, we might need:

  • Efficient linear algebra tools, provided by external crates such as ndarray and nalgebra
  • A way to represent "fitted" or "not fitted" states, using Option<f64>
  • A way to return errors when something goes wrong, using Result<_, Box<dyn Error>>
  • Automatically implementing traits like Debug, Clone, and Default to simplify testing, debugging, and construction

We combine these in the implementation of the analytical RidgeEstimator.

The full code

use ndarray::Array1;
use std::error::Error;

/// A Ridge regression estimator using `ndarray` for vectorized operations.
///
/// This version supports fitting and predicting using `Array1<f64>` arrays.
/// The coefficient `beta` is stored as an `Option<f64>`, allowing the model
/// to represent both fitted and unfitted states.
#[derive(Debug, Clone, Default)]
pub struct RidgeEstimator {
    beta: Option<f64>,
}

impl RidgeEstimator {
    /// Creates a new, unfitted Ridge estimator.
    ///
    /// # Returns
    /// A `RidgeEstimator` with `beta` set to `None`.
    pub fn new() -> Self {
        Self { beta: None }
    }

    /// Fits the Ridge regression model using 1D input and output arrays.
    ///
    /// This function computes the coefficient `beta` using the closed-form
    /// solution with L2 regularization.
    ///
    /// # Arguments
    /// - `x`: Input features as a 1D `Array1<f64>`.
    /// - `y`: Target values as a 1D `Array1<f64>`.
    /// - `lambda2`: The regularization strength (λ²).
    ///
    /// # Panics
    /// Panics if `x` is empty or if `x` and `y` have different lengths.
    pub fn fit(&mut self, x: &Array1<f64>, y: &Array1<f64>, lambda2: f64) {
        let n: usize = x.len();
        assert!(n > 0, "x and y must not be empty");
        assert_eq!(x.len(), y.len(), "x and y must have the same length");

        // `mean()` returns `None` for an empty array; the assertions above
        // guarantee that both arrays are non-empty, so `unwrap()` cannot panic here.
        let x_mean: f64 = x.mean().unwrap();
        let y_mean: f64 = y.mean().unwrap();

        // Centered cross-product, and centered sum of squares plus the penalty term.
        let num: f64 = (x - x_mean).dot(&(y - y_mean));
        let denom: f64 = (x - x_mean).mapv(|z| z.powi(2)).sum() + lambda2 * (n as f64);

        self.beta = Some(num / denom);
    }
}

impl RidgeEstimator {
    /// Predicts target values given input features.
    ///
    /// # Arguments
    /// - `x`: Input features as a 1D array.
    ///
    /// # Returns
    /// A `Result` containing the predicted values, or an error if the model
    /// has not been fitted.
    pub fn predict(&self, x: &Array1<f64>) -> Result<Array1<f64>, Box<dyn Error>> {
        match &self.beta {
            Some(beta) => Ok(*beta * x),
            None => Err("Model not fitted".into()),
        }
    }
}
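To check the arithmetic of fit without pulling in ndarray, the same closed-form expression can be sketched on plain slices. The function name fit_beta is hypothetical (it is not part of the estimator above); it computes the centered cross-product divided by the centered sum of squares plus lambda2 * n, as fit does, but returns None instead of panicking on bad input.

```rust
/// Hypothetical helper (not part of `RidgeEstimator`): closed-form ridge
/// coefficient for 1D data on plain slices, mirroring `fit`:
/// beta = sum((x_i - x_mean) * (y_i - y_mean)) / (sum((x_i - x_mean)^2) + lambda2 * n).
/// Returns `None` for empty input or mismatched lengths instead of panicking.
fn fit_beta(x: &[f64], y: &[f64], lambda2: f64) -> Option<f64> {
    let n = x.len();
    if n == 0 || n != y.len() {
        return None;
    }
    let x_mean = x.iter().sum::<f64>() / n as f64;
    let y_mean = y.iter().sum::<f64>() / n as f64;

    // Centered cross-product: sum of (x_i - x_mean) * (y_i - y_mean).
    let num: f64 = x
        .iter()
        .zip(y)
        .map(|(xi, yi)| (xi - x_mean) * (yi - y_mean))
        .sum();

    // Centered sum of squares plus the penalty term lambda2 * n.
    let denom: f64 = x.iter().map(|xi| (xi - x_mean).powi(2)).sum::<f64>() + lambda2 * n as f64;

    Some(num / denom)
}
```

For example, on x = [1, 2, 3] and y = [2, 4, 6], lambda2 = 0.0 gives 2.0 (the ordinary least-squares slope), while lambda2 = 1.0 gives 4 / (2 + 3) = 0.8, showing the shrinkage toward zero. A positive lambda2 also keeps the denominator strictly positive even when all x values are equal.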

1. ndarray instead of Vec<f64>

Rust’s standard library does not provide n-dimensional arrays or linear algebra routines. The ndarray crate provides efficient n-dimensional arrays and vectorized operations.

This example uses:

  • Array1<f64> for 1D arrays
  • .mean(), .dot(), and .mapv() for basic mathematical operations (.mean() returns an Option<f64>, which is None for an empty array)
  • Scalar–array arithmetic (beta * x for prediction, x - x_mean for centering), which applies the scalar to every element
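To make the list concrete, here is roughly what each of those calls computes, written by hand over plain slices. The helper names mean, dot, mapv, and scale are hypothetical stand-ins; ndarray's own methods operate on Array1 and spare us this boilerplate. Like ndarray's .mean(), the mean helper returns an Option, so the empty case has to be handled.

```rust
/// Hypothetical stand-in for ndarray's `.mean()`: `None` for an empty slice.
fn mean(v: &[f64]) -> Option<f64> {
    if v.is_empty() {
        None
    } else {
        Some(v.iter().sum::<f64>() / v.len() as f64)
    }
}

/// Hypothetical stand-in for `.dot()` on 1D data: sum of element-wise products.
/// Panics on a length mismatch, as ndarray's `dot` does for incompatible shapes.
fn dot(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "length mismatch");
    a.iter().zip(b).map(|(ai, bi)| ai * bi).sum()
}

/// Hypothetical stand-in for `.mapv(f)`: applies `f` to every element and
/// returns a new vector.
fn mapv(v: &[f64], f: impl Fn(f64) -> f64) -> Vec<f64> {
    v.iter().map(|&z| f(z)).collect()
}

/// Hypothetical stand-in for `beta * x`: multiplies every element by a scalar.
fn scale(beta: f64, v: &[f64]) -> Vec<f64> {
    mapv(v, |z| beta * z)
}
```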

2. Representing model state with Option<f64>

The model's coefficient beta is only available after fitting. To represent this, we use:

beta: Option<f64>

This means beta can be either:

  • Some(value): if the model is trained
  • None: if the model has not been fitted yet

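This eliminates the possibility of using an uninitialized value: to read the coefficient, the code must first deal with the None case, and the compiler enforces this.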
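A minimal sketch of the two states, using a hypothetical Model type that has the same beta: Option<f64> field but no ndarray dependency. The coefficient starts as None, a stand-in for fit (the hypothetical set_beta) stores Some(...), and reading the value requires handling both cases.

```rust
/// Hypothetical minimal model with the same `Option<f64>` field as `RidgeEstimator`.
struct Model {
    beta: Option<f64>,
}

impl Model {
    fn new() -> Self {
        Model { beta: None } // unfitted: no coefficient yet
    }

    /// Hypothetical stand-in for `fit`: stores an already-computed coefficient.
    fn set_beta(&mut self, beta: f64) {
        self.beta = Some(beta); // fitted: the coefficient now exists
    }

    fn is_fitted(&self) -> bool {
        self.beta.is_some()
    }

    /// `if let` extracts the value in the `Some` case; `else` covers `None`.
    fn describe(&self) -> String {
        if let Some(b) = self.beta {
            format!("fitted, beta = {b}")
        } else {
            String::from("not fitted")
        }
    }
}
```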

3. Pattern matching with match

To check whether the model has been fitted, we use pattern matching:

match self.beta {
    Some(beta) => Ok(x * beta),
    None => Err("Model not fitted".into()),
}

Pattern matching ensures that all possible cases of the Option type are handled explicitly: the compiler rejects a match that omits a variant. In this case, the prediction is computed only if beta is Some, and an error is returned otherwise.
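The same logic can be sketched on a plain slice instead of Array1, which also makes it easy to compare match with the equivalent Option combinators. The names predict_match and predict_combinators are hypothetical, and a String stands in for the boxed error to keep the focus on Option.

```rust
/// Hypothetical helper: explicit `match`. Both `Some` and `None` must be
/// handled, or the code does not compile.
fn predict_match(beta: Option<f64>, x: &[f64]) -> Result<Vec<f64>, String> {
    match beta {
        Some(b) => Ok(x.iter().map(|xi| b * xi).collect()),
        None => Err(String::from("Model not fitted")),
    }
}

/// Hypothetical helper: equivalent combinator form. `ok_or_else` turns `None`
/// into `Err`, and `map` transforms the value inside `Ok`.
fn predict_combinators(beta: Option<f64>, x: &[f64]) -> Result<Vec<f64>, String> {
    beta.ok_or_else(|| String::from("Model not fitted"))
        .map(|b| x.iter().map(|xi| b * xi).collect())
}
```

Both forms are equivalent here; match spells out each case, while the combinators are more compact once the pattern is familiar.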

The error handling is explained below.

4. Error handling with Box<dyn Error> and .into()

A function returning Result<T, E> must commit to a single error type E. In practice, this can be achieved in several ways, for example with a custom error enum or with a boxed trait object. Here we use a trait object:

Result<Array1<f64>, Box<dyn Error>>

If the function succeeds, it returns Ok wrapping an Array1<f64>.

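If it does not succeed, the function may return any error type that implements the Error trait, stored behind a Box. The .into() call converts the string literal into a Box<dyn Error> through the standard library's From<&str> implementation for Box<dyn Error>. Conceptually, Rust converts:

"Model not fitted"

into:

Box::<dyn Error>::from(String::from("Model not fitted"))

String does not itself implement Error, so the standard library wraps the String in a private type that does before boxing it. Writing Box::new(String::from("Model not fitted")) would therefore not produce a Box<dyn Error>.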

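Note that Box<dyn Error> means the error is heap-allocated and its concrete type is erased: callers can use the methods of the Error trait (including its Display message), but must downcast to recover the original type.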
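A small sketch of why a boxed trait object is convenient: with Box<dyn Error> as the error type, both a string message (via .into()) and a standard-library error (here the ParseFloatError returned by str::parse::<f64>, converted by the ? operator) fit the same return type. The function parse_and_scale is hypothetical.

```rust
use std::error::Error;

/// Hypothetical function with two distinct failure sources, both returned
/// as `Box<dyn Error>`.
fn parse_and_scale(beta: Option<f64>, input: &str) -> Result<f64, Box<dyn Error>> {
    let b = match beta {
        Some(b) => b,
        // `.into()` converts the &str into a `Box<dyn Error>`.
        None => return Err("Model not fitted".into()),
    };
    // `?` converts `std::num::ParseFloatError` into `Box<dyn Error>` via `From`.
    let x: f64 = input.trim().parse()?;
    Ok(b * x)
}
```

The caller sees a single error type; calling to_string() on it yields the message, and downcast_ref::<std::num::ParseFloatError>() recovers the concrete parse error when needed.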

5. Using #[derive(...)] for common traits

Rust allows us to automatically implement certain traits using the #[derive(...)] attribute. In this example, we write:

#[derive(Debug, Clone, Default)]
pub struct RidgeEstimator {
    beta: Option<f64>,
}

This provides the following implementations:

  • Debug: Enables printing the struct with {:?}, useful for debugging.
  • Clone: Allows duplicating the struct with .clone().
  • Default: Provides a default constructor (RidgeEstimator::default()), which sets beta to None.

By deriving these traits, we avoid writing repetitive code and make the model follow common Rust conventions, such as default initialization through Default and explicit duplication through .clone().
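A quick check of what the three derives provide, reusing the struct definition from this section. The methods with_beta, beta, and reset are hypothetical helpers added only so the behavior can be inspected; they are not part of the original estimator.

```rust
#[derive(Debug, Clone, Default)]
pub struct RidgeEstimator {
    beta: Option<f64>,
}

impl RidgeEstimator {
    /// Hypothetical helper: builds an already-fitted model.
    pub fn with_beta(beta: f64) -> Self {
        Self { beta: Some(beta) }
    }

    /// Hypothetical accessor used to inspect the coefficient.
    pub fn beta(&self) -> Option<f64> {
        self.beta
    }

    /// Hypothetical helper: returns the model to the unfitted state.
    pub fn reset(&mut self) {
        self.beta = None;
    }
}
```

With these derives, format!("{:?}", RidgeEstimator::default()) produces RidgeEstimator { beta: None } (because Option's default is None), and a clone can be reset without affecting the original.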

Advanced error handling

We could have gone even further by defining a custom ModelError type as follows.

use thiserror::Error;

#[derive(Debug, Error)]
pub enum ModelError {
    #[error("Model is not fitted yet")]
    NotFitted,
    #[error("Dimension mismatch")]
    DimensionMismatch,
}

This approach uses the thiserror crate to simplify the implementation of the standard Error trait.

By writing #[derive(Debug, Error)] and annotating each variant with #[error("...")], we define each variant's error message right where the variant is declared.
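To see what the derive saves, here is a hand-written, std-only version of roughly the same thing: a Display implementation carrying each variant's message, plus an empty impl of the Error trait, whose methods all have default implementations. This is an approximation for illustration, not the exact code thiserror generates; require_fitted is a hypothetical caller.

```rust
use std::error::Error;
use std::fmt;

// Hand-written approximation of the thiserror-based `ModelError`.
#[derive(Debug)]
pub enum ModelError {
    NotFitted,
    DimensionMismatch,
}

// The messages that `#[error("...")]` attaches to each variant.
impl fmt::Display for ModelError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ModelError::NotFitted => write!(f, "Model is not fitted yet"),
            ModelError::DimensionMismatch => write!(f, "Dimension mismatch"),
        }
    }
}

// `Error` requires `Debug + Display`; its methods have default implementations.
impl Error for ModelError {}

/// Hypothetical caller: a `ModelError` still converts into `Box<dyn Error>` via `?`.
fn require_fitted(beta: Option<f64>) -> Result<f64, Box<dyn Error>> {
    let b = beta.ok_or(ModelError::NotFitted)?;
    Ok(b)
}
```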

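The predict function would be rewritten as follows. Note that this version presupposes a multivariate model whose coefficients are stored as beta: Option<Array1<f64>>, with x holding a single observation: with the scalar beta: Option<f64> used earlier, beta.len() and beta.dot(x) would not compile.

pub fn predict(&self, x: &Array1<f64>) -> Result<f64, ModelError> {
    match &self.beta {
        Some(beta) => {
            // The coefficient vector and the observation must have the same length.
            if beta.len() != x.len() {
                return Err(ModelError::DimensionMismatch);
            }
            // Prediction for one observation: dot product of coefficients and features.
            Ok(beta.dot(x))
        }
        None => Err(ModelError::NotFitted),
    }
}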
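The rewritten predict compares beta.len() with x.len(), so it needs one coefficient per feature. Here is a std-only sketch of the same control flow, with Vec<f64> in place of Array1<f64>. The Linear type is hypothetical, and ModelError is redeclared with only Debug and PartialEq (instead of the thiserror derive) so that variants can be compared with ==.

```rust
/// Redeclared for a self-contained example; `PartialEq` allows comparing variants.
#[derive(Debug, PartialEq)]
pub enum ModelError {
    NotFitted,
    DimensionMismatch,
}

/// Hypothetical multivariate model: one coefficient per feature.
pub struct Linear {
    beta: Option<Vec<f64>>,
}

impl Linear {
    /// Predicts a single observation `x` (one value per feature).
    pub fn predict(&self, x: &[f64]) -> Result<f64, ModelError> {
        match &self.beta {
            Some(beta) => {
                // Reject observations whose length differs from the coefficient vector.
                if beta.len() != x.len() {
                    return Err(ModelError::DimensionMismatch);
                }
                // Dot product of coefficients and features.
                Ok(beta.iter().zip(x).map(|(b, xi)| b * xi).sum::<f64>())
            }
            None => Err(ModelError::NotFitted),
        }
    }
}
```

For instance, with coefficients [1.0, 2.0], the observation [3.0, 4.0] predicts 1 × 3 + 2 × 4 = 11, while [3.0] yields DimensionMismatch.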