Predict function

This section describes the implementation of the predict method, which uses the trained model to make predictions on new inputs. The full code is given below, and the rest of this section breaks it down.

impl<K: Kernel> KRRModel<K> {
    pub fn predict(&self, x_test: &Array2<f64>) -> Result<Array1<f64>, KRRPredictError> {
        // Both fields are None until the model has been fitted.
        let alpha = self.alpha.as_ref().ok_or(KRRPredictError::NotFitted)?;
        let x_train = self.x_train.as_ref().ok_or(KRRPredictError::NotFitted)?;

        let n_train: usize = x_train.nrows();
        let n_test: usize = x_test.nrows();
        let mut y_pred: Array1<f64> = Array::zeros(n_test);

        // For each test point, accumulate the kernel values against all
        // training points, weighted by the dual coefficients.
        for i in 0..n_test {
            for j in 0..n_train {
                let k_val = self.kernel.compute(x_train.row(j), x_test.row(i));
                y_pred[i] += alpha[j] * k_val;
            }
        }
        Ok(y_pred)
    }
}

The method takes a reference to a two-dimensional test array, with one sample per row, and returns either the predicted values as a one-dimensional array or an error if the model is not yet fitted.

Logic of the predict method

The function first checks whether the model has been fitted. This involves verifying that the fields self.alpha and self.x_train are set (i.e., not None). If either is missing, the function returns a KRRPredictError::NotFitted variant.

let alpha = self.alpha.as_ref().ok_or(KRRPredictError::NotFitted)?;
let x_train = self.x_train.as_ref().ok_or(KRRPredictError::NotFitted)?;

The actual prediction then proceeds, for each test point, by computing the kernel value against every training point and accumulating these values weighted by the corresponding dual coefficient:

for i in 0..n_test {
    for j in 0..n_train {
        let k_val = self.kernel.compute(x_train.row(j), x_test.row(i));
        y_pred[i] += alpha[j] * k_val;
    }
}

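This implements the inference equation for a test point $x$:

$$\hat{y}(x) = \sum_{j=1}^{n} \alpha_j \, k(x_j, x)$$

where $x_j$ are the $n$ training samples, $\alpha_j$ are the learned dual coefficients, and $k$ is the kernel function.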
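The weighted kernel sum described above can be sketched for a single test point using only the standard library. This is a minimal illustration, not the book's implementation: the helper names rbf and predict_one are hypothetical, plain slices stand in for ndarray arrays, and the RBF parameterization exp(-gamma * ||a - b||^2) is an assumption that may differ from the book's RBFKernel.

```rust
/// Hypothetical RBF kernel: k(a, b) = exp(-gamma * ||a - b||^2).
/// The parameterization is an assumption and may differ from the book's RBFKernel.
fn rbf(a: &[f64], b: &[f64], gamma: f64) -> f64 {
    let sq_dist: f64 = a.iter().zip(b).map(|(p, q)| (p - q).powi(2)).sum();
    (-gamma * sq_dist).exp()
}

/// Prediction for one test point x: the sum over j of alpha[j] * k(x_train[j], x).
fn predict_one(alpha: &[f64], x_train: &[Vec<f64>], x: &[f64], gamma: f64) -> f64 {
    alpha
        .iter()
        .zip(x_train)
        .map(|(a_j, x_j)| a_j * rbf(x_j, x, gamma))
        .sum()
}
```

With training points 0 and 1, coefficients 1 and 2, and gamma = 1, the prediction at x = 0 is 1 * exp(0) + 2 * exp(-1), since the first training point coincides with the test point and the second lies at squared distance 1.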

KRRPredictError

KRRPredictError is an enum with a single variant, NotFitted, which indicates that the model has not been fitted yet. It is defined in the errors.rs module using the thiserror crate:

use thiserror::Error;

#[derive(Debug, Error)]
pub enum KRRPredictError {
    #[error("Model not fitted")]
    NotFitted,
}

Returning this enum inside a Result lets callers of predict handle the unfitted case explicitly, making error propagation idiomatic and clean.
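To make the role of thiserror more concrete, the sketch below writes out by hand roughly what the derive provides: a Display implementation carrying the message and an implementation of std::error::Error. The name PredictError is a stand-in for illustration, and the code actually generated by the macro may differ in detail.

```rust
use std::error::Error;
use std::fmt;

/// Stand-in for the book's KRRPredictError, written without thiserror.
#[derive(Debug)]
pub enum PredictError {
    NotFitted,
}

// Roughly what #[error("Model not fitted")] provides: a Display message.
impl fmt::Display for PredictError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            PredictError::NotFitted => write!(f, "Model not fitted"),
        }
    }
}

// Roughly what #[derive(Error)] provides: the std::error::Error impl,
// which lets the type be used as Box<dyn Error>.
impl Error for PredictError {}
```

The derive removes this boilerplate, which matters more as the number of variants grows.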

Use of the ? operator

The function uses the ? operator to simplify error handling. For example:

let alpha = self.alpha.as_ref().ok_or(KRRPredictError::NotFitted)?;

This line either extracts the alpha reference if it exists or returns early with an error. This pattern keeps the code concise and expressive.
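The early return performed by ? can be spelled out with a match. The sketch below compares the two forms on a hypothetical, stripped-down Model that holds only the optional coefficients; the type and method names are assumptions made for illustration and are not part of the book's code.

```rust
#[derive(Debug, PartialEq)]
enum PredictError {
    NotFitted,
}

/// Hypothetical, stripped-down model holding only the optional coefficients.
struct Model {
    alpha: Option<Vec<f64>>,
}

impl Model {
    /// Uses `?`: returns early with NotFitted when alpha is None.
    fn n_coeffs(&self) -> Result<usize, PredictError> {
        let alpha = self.alpha.as_ref().ok_or(PredictError::NotFitted)?;
        Ok(alpha.len())
    }

    /// The same logic with the early return written explicitly.
    fn n_coeffs_explicit(&self) -> Result<usize, PredictError> {
        let alpha = match self.alpha.as_ref() {
            Some(a) => a,
            None => return Err(PredictError::NotFitted),
        };
        Ok(alpha.len())
    }
}
```

The ? operator additionally converts the error through From when the function's error type differs from the one produced; the explicit version omits that step because both types are the same here.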

Unit tests

The test_unfitted_predict_error_type unit test checks that the correct error is returned when attempting to call predict before fitting the model:

#[test]
fn test_unfitted_predict_error_type() {
    use crate::errors::KRRPredictError;

    let kernel = RBFKernel::new(1.0);
    let model: KRRModel<RBFKernel> = KRRModel::new(kernel, 1.0);
    let x_test: Array2<f64> = array![[1.0, 2.0, 3.0]];

    let result = model.predict(&x_test);
    assert!(matches!(result, Err(KRRPredictError::NotFitted)));
}