Closed-form solution with ndarray, Option, and error handling
This section introduces several important features of the language:
- Using the ndarray crate for numerical arrays
- Representing optional values with Option, Some, and None
- Using pattern matching with match
- Handling errors using Box<dyn Error> and .into()
- Automatically deriving common trait implementations using #[derive(...)]
Motivation
In previous sections, we worked with Vec<f64> and returned plain values. In practice, we might need:
- Efficient linear algebra tools, provided by external crates such as ndarray and nalgebra
- A way to represent "fitted" or "not fitted" states, using Option<f64>
- A way to return errors when something goes wrong, using Result<_, Box<dyn Error>>
- Automatically implementing traits like Debug, Clone, and Default to simplify testing, debugging, and construction
We combine these in the implementation of the analytical RidgeEstimator.
The full code
```rust
use ndarray::Array1;
use std::error::Error;

/// A Ridge regression estimator using `ndarray` for vectorized operations.
///
/// This version supports fitting and predicting using `Array1<f64>` arrays.
/// The coefficient `beta` is stored as an `Option<f64>`, allowing the model
/// to represent both fitted and unfitted states.
#[derive(Debug, Clone, Default)]
pub struct RidgeEstimator {
    beta: Option<f64>,
}

impl RidgeEstimator {
    /// Creates a new, unfitted Ridge estimator.
    ///
    /// # Returns
    /// A `RidgeEstimator` with `beta` set to `None`.
    pub fn new() -> Self {
        Self { beta: None }
    }

    /// Fits the Ridge regression model using 1D input and output arrays.
    ///
    /// This function computes the coefficient `beta` using the closed-form
    /// solution with L2 regularization.
    ///
    /// # Arguments
    /// - `x`: Input features as a 1D `Array1<f64>`.
    /// - `y`: Target values as a 1D `Array1<f64>`.
    /// - `lambda2`: The regularization strength (λ²).
    pub fn fit(&mut self, x: &Array1<f64>, y: &Array1<f64>, lambda2: f64) {
        let n: usize = x.len();
        assert!(n > 0);
        assert_eq!(x.len(), y.len(), "x and y must have the same length");

        // mean returns None if the array is empty, so we need to unwrap it
        let x_mean: f64 = x.mean().unwrap();
        let y_mean: f64 = y.mean().unwrap();

        let num: f64 = (x - x_mean).dot(&(y - y_mean));
        let denom: f64 = (x - x_mean).mapv(|z| z.powi(2)).sum() + lambda2 * (n as f64);

        self.beta = Some(num / denom);
    }
}

impl RidgeEstimator {
    /// Predicts target values given input features.
    ///
    /// # Arguments
    /// - `x`: Input features as a 1D array.
    ///
    /// # Returns
    /// A `Result` containing the predicted values, or an error if the model
    /// has not been fitted.
    pub fn predict(&self, x: &Array1<f64>) -> Result<Array1<f64>, Box<dyn Error>> {
        match &self.beta {
            Some(beta) => Ok(*beta * x),
            None => Err("Model not fitted".into()),
        }
    }
}
```
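For reference, the coefficient computed in the body of fit can be written as a single formula. This is read directly off the code; here $\lambda_2$ stands for the lambda2 argument, $n$ for the number of samples, and $\bar{x}$, $\bar{y}$ for the sample means:

```latex
\hat{\beta}
  = \frac{\sum_{i=1}^{n} (x_i - \bar{x})(y_i - \bar{y})}
         {\sum_{i=1}^{n} (x_i - \bar{x})^2 + n\,\lambda_2}
```

The numerator is the value stored in num and the denominator is the value stored in denom. Setting lambda2 to 0 recovers the ordinary least-squares slope.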
1. ndarray instead of Vec<f64>
Rust’s standard library does not include built-in numerical computing tools. The ndarray crate provides efficient n-dimensional arrays and vectorized operations.
This example uses:
- Array1<f64> for 1D arrays
- .mean(), .dot(), and .mapv() for basic mathematical operations
- Broadcasting (x * beta) for scalar–array multiplication
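To see what the vectorized calls replace, here is a minimal sketch of the same closed-form computation using only the standard library. The function name ridge_beta is hypothetical and is not part of the original code; it only mirrors the arithmetic of fit on plain slices.

```rust
/// Hypothetical helper (not part of the original code): computes the same
/// coefficient as `RidgeEstimator::fit`, using slices and iterators only.
pub fn ridge_beta(x: &[f64], y: &[f64], lambda2: f64) -> f64 {
    assert!(!x.is_empty(), "x must not be empty");
    assert_eq!(x.len(), y.len(), "x and y must have the same length");
    let n = x.len() as f64;

    // Counterpart of `x.mean()` and `y.mean()`.
    let x_mean = x.iter().sum::<f64>() / n;
    let y_mean = y.iter().sum::<f64>() / n;

    // Counterpart of `(x - x_mean).dot(&(y - y_mean))`.
    let num: f64 = x
        .iter()
        .zip(y.iter())
        .map(|(xi, yi)| (xi - x_mean) * (yi - y_mean))
        .sum();

    // Counterpart of `(x - x_mean).mapv(|z| z.powi(2)).sum() + lambda2 * n`.
    let denom: f64 = x.iter().map(|xi| (xi - x_mean).powi(2)).sum::<f64>() + lambda2 * n;

    num / denom
}
```

Each iterator chain corresponds to one ndarray call in fit, which is why the ndarray version reads closer to the mathematical formula.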
2. Representing model state with Option<f64>
The model's coefficient beta is only available after fitting. To represent this, we use:
```rust
beta: Option<f64>
```
This means beta can be either:
- Some(value): if the model is trained
- None: if the model has not been fitted yet
This eliminates the possibility of using an uninitialized value.
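As a small illustration of working with such a field, the sketch below uses a simplified stand-in type. The names Model, set_beta, is_fitted, and beta_or are hypothetical and chosen for this example only; is_some and unwrap_or are standard Option methods.

```rust
/// Hypothetical, simplified stand-in for `RidgeEstimator` (std only).
#[derive(Debug, Default)]
pub struct Model {
    beta: Option<f64>,
}

impl Model {
    /// Stores a coefficient, moving the model to the "fitted" state.
    pub fn set_beta(&mut self, value: f64) {
        self.beta = Some(value);
    }

    /// Returns true once a coefficient has been stored.
    pub fn is_fitted(&self) -> bool {
        self.beta.is_some()
    }

    /// Returns the coefficient, or `fallback` if the model is unfitted.
    pub fn beta_or(&self, fallback: f64) -> f64 {
        self.beta.unwrap_or(fallback)
    }
}
```

Because the value lives inside an Option, code that wants the raw f64 has to say what happens in the None case, here by supplying a fallback.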
3. Pattern matching with match
To check whether the model has been fitted, we use pattern matching:
```rust
match self.beta {
    Some(beta) => Ok(x * beta),
    None => Err("Model not fitted".into()),
}
```
Pattern matching ensures that all possible cases of the Option type are handled explicitly. Here, the prediction is computed only if beta is Some; otherwise an error is returned.
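The same decision can also be expressed with Option combinators. The sketch below is a scalar, std-only illustration; the function names predict_match and predict_combinators are hypothetical, and String is used as the error type only to keep the example short.

```rust
/// Hypothetical scalar version of `predict`, written with `match`.
pub fn predict_match(beta: Option<f64>, x: f64) -> Result<f64, String> {
    match beta {
        Some(b) => Ok(b * x),
        None => Err(String::from("Model not fitted")),
    }
}

/// The same logic using `Option::map` and `Option::ok_or_else`.
pub fn predict_combinators(beta: Option<f64>, x: f64) -> Result<f64, String> {
    beta.map(|b| b * x)
        .ok_or_else(|| String::from("Model not fitted"))
}
```

Both functions behave identically. The match form makes the two cases visible at a glance, which is why it is a reasonable choice in a tutorial setting.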
Error handling is explained in the next section.
4. Error handling with Box<dyn Error> and .into()
Rust requires functions to return a single concrete error type. In practice, this can be achieved in several ways. Here we use a trait object:
```rust
Result<Array1<f64>, Box<dyn Error>>
```
If the function succeeds, it must return an Array1<f64>.
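If it fails, the function may return any error type that implements the Error trait. The .into() method converts a string literal into a Box<dyn Error>. This works because the standard library implements From<&str> for Box<dyn Error>, so writing:
```rust
"Model not fitted".into()
```
is equivalent to:
```rust
Box::<dyn Error>::from("Model not fitted")
```
Note that a plain Box::new(String::from("Model not fitted")) would not do: String does not implement Error, so the conversion wraps the message in an internal error type.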
It is worth emphasizing that Box<dyn Error> means that the error is heap-allocated.
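The benefit of a boxed trait object is that errors of different concrete types can flow through the same return type. The following std-only sketch illustrates this; the function parse_lambda and its error message are hypothetical and not part of the original code.

```rust
use std::error::Error;

/// Hypothetical helper: parses a regularization strength from text.
pub fn parse_lambda(text: &str) -> Result<f64, Box<dyn Error>> {
    // `?` converts the `ParseFloatError` into a `Box<dyn Error>`.
    let value: f64 = text.trim().parse()?;
    if value < 0.0 {
        // `.into()` converts the `&str` message into a `Box<dyn Error>`.
        return Err("lambda2 must be non-negative".into());
    }
    Ok(value)
}
```

The function can fail with either a parse error or a custom message, and the caller sees a single error type in both cases. The trade-off is that the caller can no longer match on the concrete error without downcasting.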
5. Using #[derive(...)] for common traits
Rust allows us to automatically implement certain traits using the #[derive(...)] attribute. In this example, we write:
```rust
#[derive(Debug, Clone, Default)]
pub struct RidgeEstimator {
    beta: Option<f64>,
}
```
This provides the following implementations:
- Debug: Enables printing the struct with {:?}, useful for debugging.
- Clone: Allows duplicating the struct with .clone().
- Default: Provides a default constructor (RidgeEstimator::default()), which sets beta to None.
By deriving these traits, we avoid writing repetitive code and ensure that the model follows common Rust conventions, such as default initialization and explicit duplication with .clone().
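A short sketch exercising the three derived traits follows. The struct is repeated so that the example is self-contained; the function derive_demo is hypothetical and exists only for illustration.

```rust
#[derive(Debug, Clone, Default)]
pub struct RidgeEstimator {
    beta: Option<f64>,
}

/// Hypothetical demo function exercising the three derived traits.
pub fn derive_demo() -> (String, String) {
    // `Default`: each field takes its default value; `Option` defaults to `None`.
    let unfitted = RidgeEstimator::default();

    // `Clone`: an independent copy; changing it leaves `unfitted` untouched.
    let mut fitted = unfitted.clone();
    fitted.beta = Some(0.5);

    // `Debug`: formatting with `{:?}`.
    (format!("{:?}", unfitted), format!("{:?}", fitted))
}
```

The first returned string is `RidgeEstimator { beta: None }` and the second is `RidgeEstimator { beta: Some(0.5) }`, showing that the clone was modified independently of the original.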
We could have gone even further by defining a custom ModelError type as follows.
```rust
use thiserror::Error;

#[derive(Debug, Error)]
pub enum ModelError {
    #[error("Model is not fitted yet")]
    NotFitted,
    #[error("Dimension mismatch")]
    DimensionMismatch,
}
```
This approach uses the thiserror crate to simplify the implementation of the standard Error trait. By deriving Debug and Error and annotating each variant with #[error("...")], we define the error messages right away, next to the variants they describe.
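For comparison, the sketch below writes out by hand, with the standard library only, roughly what those annotations save us from typing. It is an illustrative equivalent, not the literal code generated by thiserror.

```rust
use std::error::Error;
use std::fmt;

#[derive(Debug)]
pub enum ModelError {
    NotFitted,
    DimensionMismatch,
}

// Hand-written counterpart of the `#[error("...")]` annotations.
impl fmt::Display for ModelError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ModelError::NotFitted => write!(f, "Model is not fitted yet"),
            ModelError::DimensionMismatch => write!(f, "Dimension mismatch"),
        }
    }
}

// `Error` requires `Debug` and `Display`; all of its methods have defaults.
impl Error for ModelError {}
```

With these two impl blocks in place, ModelError can be used anywhere a standard error is expected, including inside a Box<dyn Error>.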
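The predict function would be rewritten as follows. Note that this version only compiles for a multivariate variant of the model, in which beta is stored as an Option<Array1<f64>> and the prediction for one sample is a dot product:
```rust
// Assumes a multivariate variant of the struct in which
// `beta: Option<Array1<f64>>`. With the scalar `Option<f64>` field shown
// earlier, `beta.len()` and `beta.dot(x)` would not compile.
pub fn predict(&self, x: &Array1<f64>) -> Result<f64, ModelError> {
    match &self.beta {
        Some(beta) => {
            // The coefficient vector and the input must have the same length.
            if beta.len() != x.len() {
                return Err(ModelError::DimensionMismatch);
            }
            Ok(beta.dot(x))
        }
        None => Err(ModelError::NotFitted),
    }
}
```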
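One practical advantage of a dedicated error enum is that callers can react to each failure mode separately. The sketch below illustrates this with slices instead of ndarray arrays so that it stays std-only; the free functions predict and describe, and the messages in describe, are hypothetical.

```rust
#[derive(Debug, PartialEq)]
pub enum ModelError {
    NotFitted,
    DimensionMismatch,
}

/// Hypothetical slice-based counterpart of `predict` (std only).
pub fn predict(beta: Option<&[f64]>, x: &[f64]) -> Result<f64, ModelError> {
    // `ok_or` turns a missing coefficient vector into `NotFitted`.
    let beta = beta.ok_or(ModelError::NotFitted)?;
    if beta.len() != x.len() {
        return Err(ModelError::DimensionMismatch);
    }
    // Dot product of the coefficients and the input.
    Ok(beta.iter().zip(x).map(|(b, xi)| b * xi).sum::<f64>())
}

/// A caller can handle each error variant in its own match arm.
pub fn describe(result: Result<f64, ModelError>) -> String {
    match result {
        Ok(value) => format!("prediction = {}", value),
        Err(ModelError::NotFitted) => String::from("call fit() first"),
        Err(ModelError::DimensionMismatch) => String::from("check the input length"),
    }
}
```

With Box<dyn Error>, the same distinction would require downcasting or comparing messages, so the enum is the better fit when callers need to branch on the cause.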