Hyperparameter tuning with LOO-CV

This section focuses on hyperparameter selection: tuning the kernel lengthscale using Leave-One-Out Cross-Validation (LOO-CV). We implement two key functions in the model_selection.rs module:

  • loo_cv_error: computes the LOO-CV error for a given model and training dataset.
  • tune_lengthscale: evaluates multiple candidate lengthscales and returns the one with the lowest LOO-CV error.

Leave-One-Out Cross-Validation (LOO-CV)

LOO-CV is a classical cross-validation strategy in which each data point is left out once as a test sample while the model is trained on the remaining samples. The final error is the average squared difference between the predicted and true values.
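To make the definition concrete, here is a minimal standalone sketch (not part of the KRR crate) that computes the LOO-CV error of the simplest possible predictor, the mean of the remaining points:

```rust
// Naive LOO-CV for a mean predictor: leave out point i, "train" on the
// rest (i.e. take their mean), and average the squared residuals.
fn loo_cv_mean(y: &[f64]) -> f64 {
    let n = y.len();
    let total: f64 = y.iter().sum();
    let mut err = 0.0;
    for i in 0..n {
        // Mean of the n-1 remaining values.
        let pred = (total - y[i]) / (n as f64 - 1.0);
        err += (y[i] - pred).powi(2);
    }
    err / n as f64
}

fn main() {
    // Leave out 1.0 -> predict 2.5; leave out 2.0 -> predict 2.0;
    // leave out 3.0 -> predict 1.5. Mean squared residual: 1.5.
    println!("{}", loo_cv_mean(&[1.0, 2.0, 3.0])); // prints 1.5
}
```

The naive version retrains once per point; for most models that costs n full fits, which is what the shortcut below avoids.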

In KRR, thanks to the closed-form solution, the LOO-CV error can be computed efficiently without retraining the model n times. The formula is:

$$\mathrm{LOO\text{-}CV} = \frac{1}{n} \sum_{i=1}^{n} \left( \frac{y_i - \hat{y}_i}{1 - H_{ii}} \right)^2$$

Where:

  • $H = K (K + \lambda I)^{-1}$ is the hat matrix,
  • $\hat{y} = H y$ is the prediction on the training data,
  • $H_{ii}$ is the i-th diagonal element of the hat matrix.

In practice the residuals simplify further: with $A = K + \lambda I$ and $\alpha = A^{-1} y$, we have $y - \hat{y} = \lambda \alpha$ and $1 - H_{ii} = \lambda A^{-1}_{ii}$, so each LOO residual is just $\alpha_i / A^{-1}_{ii}$.
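As a sanity check of the shortcut, here is a self-contained 2-point example (with a made-up Gram matrix and ridge parameter, no external crates) verifying that $\alpha_i / A^{-1}_{ii}$ equals the residual obtained by actually retraining without point i:

```rust
// For a 2-point problem we can invert A = K + lambda*I by hand and compare
// the LOO shortcut against naive retraining.
fn shortcut_and_naive(i: usize) -> (f64, f64) {
    let k = [[1.0, 0.5], [0.5, 1.0]]; // made-up Gram matrix
    let lambda = 0.1;
    let y = [1.0, 2.0];

    // A = K + lambda*I, inverted explicitly (2x2 closed form).
    let a = [[k[0][0] + lambda, k[0][1]], [k[1][0], k[1][1] + lambda]];
    let det = a[0][0] * a[1][1] - a[0][1] * a[1][0];
    let a_inv = [
        [a[1][1] / det, -a[0][1] / det],
        [-a[1][0] / det, a[0][0] / det],
    ];

    // alpha = A^{-1} y
    let alpha = [
        a_inv[0][0] * y[0] + a_inv[0][1] * y[1],
        a_inv[1][0] * y[0] + a_inv[1][1] * y[1],
    ];

    // Shortcut: LOO residual without retraining.
    let shortcut = alpha[i] / a_inv[i][i];

    // Naive: retrain on the other point alone, then predict at point i.
    let j = 1 - i;
    let alpha_j = y[j] / (k[j][j] + lambda);
    let naive = y[i] - k[i][j] * alpha_j;

    (shortcut, naive)
}

fn main() {
    for i in 0..2 {
        let (s, n) = shortcut_and_naive(i);
        assert!((s - n).abs() < 1e-12);
    }
    println!("shortcut matches naive LOO retraining");
}
```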

The loo_cv_error function implements this logic:

use ndarray::Array2;
use ndarray_linalg::Inverse;

pub fn loo_cv_error<K: Kernel>(model: &KRRModel<K>) -> Result<f64, KRRPredictError> {
    let alpha = model.alpha.as_ref().ok_or(KRRPredictError::NotFitted)?;
    let x_train = model.x_train.as_ref().ok_or(KRRPredictError::NotFitted)?;

    let n = x_train.nrows();
    let mut k_train = Array2::zeros((n, n));

    // Rebuild the symmetric Gram matrix, filling both triangles at once.
    for i in 0..n {
        for j in 0..=i {
            let kxy = model.kernel.compute(x_train.row(i), x_train.row(j));
            k_train[(i, j)] = kxy;
            k_train[(j, i)] = kxy;
        }
    }

    let identity_n = Array2::eye(n);
    let a = k_train + model.lambda * identity_n;
    let a_inv = a.inv().expect("Inversion failed");

    // LOO residual for point i: alpha_i / (A^{-1})_{ii}.
    let mut loo_error = 0.0;
    for i in 0..n {
        let ai = alpha[i];
        let di = a_inv[(i, i)];
        let res = ai / di;
        loo_error += res.powi(2);
    }

    Ok(loo_error / (n as f64))
}

It returns the mean squared error over the training set based on the LOO-CV formula.

Note that we got a bit lazy here:

  • We use the .expect() method, which panics if the inversion fails. This crashes the program instead of returning an error, as we did with our KRRFitError and KRRPredictError enums.
  • We re-compute the Gram matrix, even though it could be stored within the KRRModel like we did for alpha and x_train.
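For the first point, the usual fix is to convert the failure into the error enum with ok_or or map_err rather than panicking. A minimal standalone illustration of the pattern (the SingularMatrix variant and try_invert stand-in are hypothetical, not part of the crate):

```rust
#[derive(Debug, PartialEq)]
enum KRRPredictError {
    SingularMatrix, // hypothetical new variant for failed inversions
}

// Stand-in for a fallible inversion: None when the determinant vanishes.
fn try_invert(det: f64) -> Option<f64> {
    if det.abs() < 1e-12 { None } else { Some(1.0 / det) }
}

// ok_or turns the Option into our error type, so the caller decides
// what to do instead of the program crashing.
fn checked_inverse(det: f64) -> Result<f64, KRRPredictError> {
    try_invert(det).ok_or(KRRPredictError::SingularMatrix)
}

fn main() {
    assert_eq!(checked_inverse(0.0), Err(KRRPredictError::SingularMatrix));
    assert!(checked_inverse(2.0).is_ok());
    println!("singular matrices now surface as Err, not a panic");
}
```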

Tuning the lengthscale

We now want to search for the optimal lengthscale of the RBF kernel that minimizes the LOO-CV error. The function:

pub fn tune_lengthscale<K: Kernel + Clone>(
    x_train: Array2<f64>,
    y_train: Array1<f64>,
    lambda: f64,
    lengthscales: &[f64],
    kernel_builder: impl Fn(f64) -> K,
) -> Result<(K, f64), String> {
    let mut best_error = f64::INFINITY;
    let mut best_kernel = None;

    for &l in lengthscales {
        let kernel = kernel_builder(l);
        let mut model = KRRModel::new(kernel.clone(), lambda);

        // Skip candidates for which the fit fails.
        if model.fit(x_train.clone(), y_train.clone()).is_err() {
            continue;
        }

        // Keep the kernel with the lowest LOO-CV error seen so far.
        if let Ok(err) = loo_cv_error(&model) {
            if err < best_error {
                best_error = err;
                best_kernel = Some(kernel);
            }
        }
    }

    best_kernel
        .map(|k| (k, best_error))
        .ok_or_else(|| "Tuning failed".to_string())
}

takes in a list of candidate lengthscales, fits a model for each, and selects the one with the lowest LOO-CV error. Internally, it uses:

  1. The RBFKernel struct to instantiate kernels with varying lengthscales.
  2. The KRRModel::fit method to train each model.
  3. The loo_cv_error function to evaluate them.
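The selection loop itself is just an argmin over candidates. The sketch below isolates that skeleton in plain Rust, with the model fit and loo_cv_error replaced by a mock error function (chosen arbitrarily so that the minimum falls at l = 1.0) so the example runs on its own:

```rust
// Hypothetical stand-in for fitting a model and calling loo_cv_error.
fn mock_loo_error(lengthscale: f64) -> f64 {
    lengthscale.ln().powi(2)
}

// Same shape as tune_lengthscale: iterate, score, keep the best pair.
fn pick_best(candidates: &[f64]) -> Option<(f64, f64)> {
    let mut best: Option<(f64, f64)> = None;
    for &l in candidates {
        let err = mock_loo_error(l);
        if best.map_or(true, |(_, e)| err < e) {
            best = Some((l, err));
        }
    }
    best
}

fn main() {
    let candidates = [0.01, 0.1, 1.0, 10.0];
    let (best_l, _err) = pick_best(&candidates).unwrap();
    assert_eq!(best_l, 1.0); // ln(1)^2 = 0 is the smallest mock error
    println!("best lengthscale: {best_l}");
}
```

Returning Option (or Result, as the real function does) covers the edge case where every candidate fails to fit.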

Unit test

The test_tune_lengthscale test verifies that the tuning function works correctly:

#[test]
fn test_tune_lengthscale() {
    let x_train = array![[0.0], [1.0], [2.0]];
    let y_train = array![0.0, 1.0, 2.0];
    let candidates = vec![0.01, 0.1, 1.0, 10.0];

    // Lambda is fixed to a small value; the kernel builder assumes an
    // RBFKernel::new(lengthscale) constructor and a public lengthscale field.
    let (best_kernel, _err) =
        tune_lengthscale(x_train, y_train, 1e-3, &candidates, RBFKernel::new)
            .expect("tuning should succeed");
    assert!(candidates.contains(&best_kernel.lengthscale));
}

This test confirms that the selected best lengthscale is indeed one of the candidates.