Generic Ridge estimator
In this section, we implement a Ridge estimator that works with both f32 and f64 using generics and trait bounds.
To make this work, we need to import the following:
    use num_traits::Float;
    use std::iter::Sum;
Ridge model trait
We start by defining a trait, RidgeModel, that describes the core behavior expected from any Ridge regression model. We tell the compiler that F must implement the traits Float and Sum.
    pub trait RidgeModel<F: Float + Sum> {
        /// Fits the model to the given data using Ridge regression.
        fn fit(&mut self, x: &[F], y: &[F], lambda2: F);

        /// Predicts output values for a slice of new input features.
        fn predict(&self, x: &[F]) -> Vec<F>;
    }
Ridge estimator structure
Next, we define the Ridge structure as usual, but with our generic type F. The model stores only the Ridge coefficient beta.
    pub struct GenRidgeEstimator<F: Float + Sum> {
        pub beta: F,
    }
We implement the constructor as usual as well.
    impl<F: Float + Sum> GenRidgeEstimator<F> {
        /// Creates a new estimator with the given initial beta coefficient.
        pub fn new(init_beta: F) -> Self {
            Self { beta: init_beta }
        }
    }
Fit and predict methods
We can finally implement the trait RidgeModel for our GenRidgeEstimator.
    impl<F: Float + Sum> RidgeModel<F> for GenRidgeEstimator<F> {
        /// Fits the Ridge regression model to 1D data using closed-form solution.
        ///
        /// This method computes the regression coefficient `beta` by minimizing
        /// the Ridge-regularized least squares loss.
        ///
        /// # Arguments
        /// - `x`: Input features.
        /// - `y`: Target values.
        /// - `lambda2`: The regularization parameter (λ²).
        fn fit(&mut self, x: &[F], y: &[F], lambda2: F) {
            let n: usize = x.len();
            let n_f: F = F::from(n).unwrap();
            assert_eq!(x.len(), y.len(), "x and y must have the same length");

            let x_mean: F = x.iter().copied().sum::<F>() / n_f;
            let y_mean: F = y.iter().copied().sum::<F>() / n_f;

            let num: F = x
                .iter()
                .zip(y.iter())
                .map(|(xi, yi)| (*xi - x_mean) * (*yi - y_mean))
                .sum::<F>();

            let denom: F = x.iter().map(|xi| (*xi - x_mean).powi(2)).sum::<F>() + lambda2 * n_f;

            self.beta = num / denom;
        }

        /// Applies the trained model to input features to generate predictions.
        ///
        /// # Arguments
        /// - `x`: Input features to predict from.
        ///
        /// # Returns
        /// A vector of predicted values, one for each input in `x`.
        fn predict(&self, x: &[F]) -> Vec<F> {
            x.iter().map(|xi| *xi * self.beta).collect()
        }
    }
Notice that the trait bounds <F: Float + Sum> are declared right after the name of a trait or struct, or right after the impl keyword.
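To see the generic estimator at work with both precisions, here is a minimal, dependency-free sketch. It makes some assumptions: MiniFloat is a hypothetical stand-in for num_traits::Float, reduced to the operations the closed-form formula needs, and SketchRidge and fit_both_precisions are illustrative names that do not appear in the chapter's code. The formula itself is the one used in fit above.

```rust
use std::iter::Sum;
use std::ops::{Add, Div, Mul, Sub};

/// Hypothetical stand-in for `num_traits::Float`, reduced to what this sketch needs.
pub trait MiniFloat:
    Copy
    + Sum
    + Add<Output = Self>
    + Sub<Output = Self>
    + Mul<Output = Self>
    + Div<Output = Self>
{
    /// Converts a slice length to the float type (rounds for very large values).
    fn from_len(n: usize) -> Self;
}

impl MiniFloat for f32 {
    fn from_len(n: usize) -> Self {
        n as f32
    }
}

impl MiniFloat for f64 {
    fn from_len(n: usize) -> Self {
        n as f64
    }
}

/// Condensed version of the estimator (illustrative name): stores only `beta`.
pub struct SketchRidge<F: MiniFloat> {
    pub beta: F,
}

impl<F: MiniFloat> SketchRidge<F> {
    /// Same closed-form update as `fit` in the text.
    pub fn fit(&mut self, x: &[F], y: &[F], lambda2: F) {
        assert_eq!(x.len(), y.len(), "x and y must have the same length");
        let n_f = F::from_len(x.len());
        let x_mean = x.iter().copied().sum::<F>() / n_f;
        let y_mean = y.iter().copied().sum::<F>() / n_f;
        let num = x
            .iter()
            .zip(y.iter())
            .map(|(xi, yi)| (*xi - x_mean) * (*yi - y_mean))
            .sum::<F>();
        let denom = x
            .iter()
            .map(|xi| (*xi - x_mean) * (*xi - x_mean))
            .sum::<F>()
            + lambda2 * n_f;
        self.beta = num / denom;
    }

    /// Multiplies each input by the fitted coefficient.
    pub fn predict(&self, x: &[F]) -> Vec<F> {
        x.iter().map(|xi| *xi * self.beta).collect()
    }
}

/// Fits the same data once in `f32` and once in `f64`.
pub fn fit_both_precisions() -> (f32, f64) {
    let mut m32 = SketchRidge { beta: 0.0_f32 };
    m32.fit(&[1.0, 2.0, 3.0, 4.0], &[2.0, 4.0, 6.0, 8.0], 1.25);

    let mut m64 = SketchRidge { beta: 0.0_f64 };
    m64.fit(&[1.0, 2.0, 3.0, 4.0], &[2.0, 4.0, 6.0, 8.0], 1.25);

    (m32.beta, m64.beta)
}
```

The only thing that changes between the two calls is the type of the literals: the compiler picks F = f32 or F = f64 from the estimator we construct, and the same fit code is compiled for each.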
Without Sum, the compiler does not allow .sum() on iterators of F. Try removing Sum from the bound while keeping a call to .sum() (the trait and struct declarations carry the same bound, so it has to be relaxed there as well):
    impl<F: Float> RidgeModel<F> for GenRidgeEstimator<F>
The compiler should complain with an error along these lines:
error[E0599]: the method `sum` exists for iterator `std::slice::Iter<'_, F>`,
but its trait bounds were not satisfied
The copied() method in .iter().copied().sum::<F>() is needed because x.iter() yields references &F, while the bound F: Sum only covers summing owned values of type F. Float already requires Copy, so copied() turns each &F into an F without any extra trait bound.
We could have used cloned() instead, since every Copy type is also Clone, but copied() makes it explicit that only a cheap bitwise copy takes place.
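The following sketch compares three ways of turning the &F items into owned values before summing. The function names are ours, and the bound F: Copy + Sum is an assumption that stands in for Float + Sum so the example depends only on the standard library.

```rust
use std::iter::Sum;

/// `copied()` turns each `&F` into an `F`; it requires `F: Copy`.
pub fn sum_copied<F: Copy + Sum>(xs: &[F]) -> F {
    xs.iter().copied().sum()
}

/// `cloned()` also compiles with the same bound, because `Copy` implies `Clone`.
pub fn sum_cloned<F: Copy + Sum>(xs: &[F]) -> F {
    xs.iter().cloned().sum()
}

/// Dereferencing by hand inside `map` is a third equivalent spelling.
pub fn sum_deref<F: Copy + Sum>(xs: &[F]) -> F {
    xs.iter().map(|x| *x).sum()
}
```

All three compile to the same behavior for f32 and f64. Dropping + Sum from any of the bounds makes the .sum() call fail to compile.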
Note that in the predict function, we don't need copied() because we manually dereference each item, *xi, inside the .map(), which is allowed because F is Copy. We could have used copied() there as well and adjusted the mapping closure accordingly.
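As a rough illustration, here is what the copied() variant of predict could look like. It is written as a free function with a hypothetical name, predict_with_copied, and with the bound Copy + Mul standing in for Float so that it compiles on its own.

```rust
use std::ops::Mul;

/// Variant of `predict`: `copied()` hands the closure owned values,
/// so the closure multiplies `xi` directly instead of writing `*xi`.
pub fn predict_with_copied<F: Copy + Mul<Output = F>>(x: &[F], beta: F) -> Vec<F> {
    x.iter().copied().map(|xi| xi * beta).collect()
}
```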
A word on the unwrap in F::from(n).unwrap(). The length n of the slice x is of type usize, as usual. We need n_f to be of type F so that we can perform operations like division with other F-typed values.
The conversion is done with F::from(n), which returns an Option<F>, not a plain F. By calling unwrap(), we assume that the conversion succeeds and let the program panic if it does not. Since n comes from x.len(), it can be converted to f32 or f64 (possibly with rounding for very large lengths), so unwrapping is safe here.
Note that we could have made the failure explicit with a descriptive panic message:
    let n_f: F = F::from(n).expect("Length too large to convert to float");
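If we prefer not to panic at all, the Option can be turned into a Result and passed up to the caller. The sketch below is only one possible shape: FromLen is a hypothetical stand-in for the Option-returning conversion used in the text, and len_as_float is an illustrative helper, not part of the chapter's estimator.

```rust
/// Hypothetical stand-in for the `Option`-returning conversion
/// that `F::from(n)` provides in the text.
pub trait FromLen: Sized {
    fn from_len(n: usize) -> Option<Self>;
}

impl FromLen for f32 {
    fn from_len(n: usize) -> Option<Self> {
        Some(n as f32)
    }
}

impl FromLen for f64 {
    fn from_len(n: usize) -> Option<Self> {
        Some(n as f64)
    }
}

/// Converts the length of `x` to `F`, returning an `Err` with a
/// message instead of panicking when the conversion yields `None`.
pub fn len_as_float<F: FromLen>(x: &[F]) -> Result<F, String> {
    F::from_len(x.len())
        .ok_or_else(|| format!("length {} cannot be converted to a float", x.len()))
}
```

A fit method built on such a helper would return a Result as well and use the ? operator, leaving the decision about how to react to the caller.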