Solver module

This module contains all the core logic required to assemble and solve the finite element system for the 2D Poisson equation.
It supports both dense and sparse matrix formulations and includes routines for:

  • Assembling the stiffness matrix and load vector from a given mesh and source function.
  • Applying Dirichlet boundary conditions to either dense or sparse systems.
  • Solving the resulting linear system using either a Cholesky decomposition (dense) or the Conjugate Gradient method (sparse).
  • High-level “assemble-and-solve” functions for convenience.

We proceed in the same order as the implementation, explaining each part in detail.

Mathematical background

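We consider the Poisson problem $-\Delta u = f$ in its weak form: find $u \in V$ such that

$$
\int_\Omega \nabla u \cdot \nabla v \,\mathrm{d}x = \int_\Omega f\, v \,\mathrm{d}x \qquad \text{for all } v \in V_0,
$$

where:

  • $\Omega$ is the computational domain.
  • $V$ is the trial space satisfying Dirichlet boundary conditions.
  • $V_0$ is the test space with homogeneous Dirichlet conditions.
  • $f$ is the source term.

In the finite element method, the solution is expanded in terms of basis functions $\varphi_j$ as $u_h = \sum_j u_j \varphi_j$, and we obtain the linear system:

$$
A\mathbf{u} = \mathbf{b}, \qquad A_{ij} = \int_\Omega \nabla\varphi_j \cdot \nabla\varphi_i \,\mathrm{d}x, \qquad b_i = \int_\Omega f\,\varphi_i \,\mathrm{d}x.
$$

The stiffness matrix $A$ and load vector $\mathbf{b}$ are computed by assembling contributions from each element in the mesh. Numerical integration is carried out using quadrature rules on a reference element, with transformations to the physical element through the Jacobian.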
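Written out for a single element $e$, with reference shape functions $\hat\varphi_i$, quadrature points $\hat{\boldsymbol\xi}_q$ with weights $w_q$, element node coordinates $\mathbf{x}_k$, and Jacobian $J_e$ of the map from the reference element, the element contributions accumulated by the assembly code below take the following standard form:

$$
\nabla\varphi_i(\mathbf{x}_q) = J_e^{-T}(\hat{\boldsymbol\xi}_q)\,\hat\nabla\hat\varphi_i(\hat{\boldsymbol\xi}_q),
\qquad
\mathbf{x}_q = \sum_{k} \hat\varphi_k(\hat{\boldsymbol\xi}_q)\,\mathbf{x}_k,
$$

$$
A^e_{ij} \approx \sum_q w_q\,\lvert\det J_e(\hat{\boldsymbol\xi}_q)\rvert\;\nabla\varphi_i(\mathbf{x}_q)\cdot\nabla\varphi_j(\mathbf{x}_q),
\qquad
b^e_i \approx \sum_q w_q\,\lvert\det J_e(\hat{\boldsymbol\xi}_q)\rvert\; f(\mathbf{x}_q)\,\hat\varphi_i(\hat{\boldsymbol\xi}_q).
$$

The local entries $A^e_{ij}$ and $b^e_i$ are then added to the global rows and columns given by the element's vertex indices.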

Dense system assembly

The first function, assemble_system_dense, constructs the stiffness matrix and right-hand side vector for the given mesh and source term using a dense matrix representation.
It:

  1. Selects the appropriate reference element (Tri3 or Quad4) based on the mesh element type.
  2. Chooses a second-order quadrature rule to integrate element-level matrices.
  3. Loops over each element, computes local stiffness ke and local load fe, and assembles them into the global system.
use nalgebra::{DMatrix, DVector, Point2, Vector2};
// Mesh2d, ElementType, ReferenceElement and QuadRule are crate-local types not shown here.

pub fn assemble_system_dense<F>(mesh: &Mesh2d, source_fn: &F) -> (DMatrix<f64>, DVector<f64>)
where
    F: Fn(f64, f64) -> f64,
{
    let num_vertices = mesh.vertices().len();
    let mut a = DMatrix::zeros(num_vertices, num_vertices);
    let mut b = DVector::zeros(num_vertices);

    // Pick the right reference element based on the element type in the mesh.
    let ref_element = match mesh.element_type() {
        ElementType::P1 => ReferenceElement::Tri3,
        ElementType::Q1 => ReferenceElement::Quad4,
    };

    // Pick the right quadrature rule based on the element type in the mesh.
    // We use second-order quadrature rules by default.
    let quad_rule = match mesh.element_type() {
        ElementType::P1 => QuadRule::triangle(2),
        ElementType::Q1 => QuadRule::quadrilateral(2),
    };

    let n: usize = ref_element.num_nodes();
    for element in mesh.elements() {
        // Get the coordinates of the element nodes
        let mut nodes: Vec<Point2<f64>> = Vec::with_capacity(n);
        for vid in &element.indices {
            let vertex = mesh.vertices()[*vid];
            nodes.push(vertex);
        }

        // Compute the local stiffness matrix and load vector
        let mut ke = vec![vec![0.0; n]; n];
        let mut fe = vec![0.0; n];
        for (quad_point, quad_weight) in quad_rule.points.iter().zip(quad_rule.weights.iter()) {
            // Compute local quantities in the reference element
            let grads_ref = ref_element.shape_gradients(quad_point);
            let jac_ref = ref_element.jacobian(&nodes, quad_point);
            let det_jac_ref = jac_ref.determinant();
            let jac_inv_t = jac_ref
                .try_inverse()
                .expect("singular Jacobian: degenerate element")
                .transpose();

            // Compute gradients in the physical space: grad = J^{-T} * grad_ref
            let mut grads_global: Vec<Vector2<f64>> = Vec::with_capacity(n);
            for grad_ref in grads_ref {
                let grad = jac_inv_t * grad_ref;
                grads_global.push(grad);
            }

            // Evaluate physical coordinates of the quadrature point
            let shape_vals = ref_element.shape_functions(quad_point);
            let mut x = 0.0;
            let mut y = 0.0;
            for (val, vtx) in shape_vals.iter().zip(&nodes) {
                x += val * vtx.x;
                y += val * vtx.y;
            }

            // Fill ke and fe, scaling the quadrature weight by |det J|
            let f_val = source_fn(x, y);
            let weight = quad_weight * det_jac_ref.abs();
            for i in 0..n {
                for j in 0..n {
                    ke[i][j] += grads_global[i].dot(&grads_global[j]) * weight;
                }
                fe[i] += shape_vals[i] * f_val * weight;
            }
        }

        // Assemble into global matrix/vector
        for (local_i, &global_i) in element.indices.iter().enumerate() {
            b[global_i] += fe[local_i];
            for (local_j, &global_j) in element.indices.iter().enumerate() {
                a[(global_i, global_j)] += ke[local_i][local_j];
            }
        }
    }

    (a, b)
}
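To make the element loop concrete without the crate's types, the following std-only sketch computes the local stiffness matrix of one linear triangle (Tri3) with the same steps: map the reference gradients with $J^{-T}$ and scale the quadrature weight by $|\det J|$. The function name p1_local_stiffness is hypothetical. Because P1 gradients are constant, the sketch uses a single quadrature point with weight 1/2 (the area of the reference triangle) instead of the module's second-order rule; both integrate this constant integrand exactly.

```rust
/// Local stiffness matrix of a linear (P1) triangle, computed like the
/// assembly loop: physical gradients = J^{-T} * reference gradients,
/// weighted by |det J|. One quadrature point suffices for P1.
fn p1_local_stiffness(nodes: [[f64; 2]; 3]) -> [[f64; 3]; 3] {
    // Reference gradients of N0 = 1 - xi - eta, N1 = xi, N2 = eta.
    let grads_ref = [[-1.0, -1.0], [1.0, 0.0], [0.0, 1.0]];

    // Jacobian J = [[dx/dxi, dx/deta], [dy/dxi, dy/deta]] of the affine map.
    let (x0, y0) = (nodes[0][0], nodes[0][1]);
    let j = [
        [nodes[1][0] - x0, nodes[2][0] - x0],
        [nodes[1][1] - y0, nodes[2][1] - y0],
    ];
    let det = j[0][0] * j[1][1] - j[0][1] * j[1][0];
    assert!(det.abs() > 1e-14, "degenerate element");

    // J^{-T} = (1/det) * [[j11, -j10], [-j01, j00]].
    let inv_t = [
        [j[1][1] / det, -j[1][0] / det],
        [-j[0][1] / det, j[0][0] / det],
    ];

    // Physical gradients: grad = J^{-T} * grad_ref.
    let mut grads = [[0.0; 2]; 3];
    for (g, gr) in grads.iter_mut().zip(grads_ref.iter()) {
        g[0] = inv_t[0][0] * gr[0] + inv_t[0][1] * gr[1];
        g[1] = inv_t[1][0] * gr[0] + inv_t[1][1] * gr[1];
    }

    // One-point rule: weight 0.5 (reference area) times |det J|.
    let weight = 0.5 * det.abs();
    let mut ke = [[0.0; 3]; 3];
    for i in 0..3 {
        for k in 0..3 {
            ke[i][k] = (grads[i][0] * grads[k][0] + grads[i][1] * grads[k][1]) * weight;
        }
    }
    ke
}
```

For the reference triangle with vertices (0, 0), (1, 0), (0, 1) this gives the rows (1, -1/2, -1/2), (-1/2, 1/2, 0) and (-1/2, 0, 1/2). Every row sums to zero because the shape functions sum to one, which is a useful sanity check for any element routine.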

Sparse system assembly

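The assemble_system_sparse function is analogous to the dense version, but it collects the element contributions as (row, column, value) triplets in a COO (coordinate) matrix and then converts it to CSR (compressed sparse row) format, which supports fast row access and matrix-vector products. When several elements contribute to the same entry, the duplicate triplets are summed during the conversion.
This is better suited for large problems: each node couples only to the nodes of its neighboring elements, so the number of nonzero entries grows roughly linearly with the number of vertices, whereas dense storage grows quadratically.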

use nalgebra::{DVector, Point2, Vector2};
use nalgebra_sparse::{CooMatrix, CsrMatrix};
// Mesh2d, ElementType, ReferenceElement and QuadRule are crate-local types not shown here.

pub fn assemble_system_sparse<F>(mesh: &Mesh2d, source_fn: &F) -> (CsrMatrix<f64>, DVector<f64>)
where
    F: Fn(f64, f64) -> f64,
{
    let num_vertices = mesh.vertices().len();
    let mut coo = CooMatrix::new(num_vertices, num_vertices);
    let mut b = DVector::zeros(num_vertices);

    // Pick the right reference element based on the element type in the mesh.
    let ref_element = match mesh.element_type() {
        ElementType::P1 => ReferenceElement::Tri3,
        ElementType::Q1 => ReferenceElement::Quad4,
    };

    // Pick the right quadrature rule based on the element type in the mesh.
    // We use second-order quadrature rules by default.
    let quad_rule = match mesh.element_type() {
        ElementType::P1 => QuadRule::triangle(2),
        ElementType::Q1 => QuadRule::quadrilateral(2),
    };

    let n: usize = ref_element.num_nodes();
    for element in mesh.elements() {
        // Get the coordinates of the element nodes
        let mut nodes: Vec<Point2<f64>> = Vec::with_capacity(n);
        for vid in &element.indices {
            let vertex = mesh.vertices()[*vid];
            nodes.push(vertex);
        }

        // Compute the local stiffness matrix and load vector
        let mut ke = vec![vec![0.0; n]; n];
        let mut fe = vec![0.0; n];
        for (quad_point, quad_weight) in quad_rule.points.iter().zip(quad_rule.weights.iter()) {
            // Compute local quantities in the reference element
            let grads_ref = ref_element.shape_gradients(quad_point);
            let jac_ref = ref_element.jacobian(&nodes, quad_point);
            let det_jac_ref = jac_ref.determinant();
            let jac_inv_t = jac_ref
                .try_inverse()
                .expect("singular Jacobian: degenerate element")
                .transpose();

            // Compute gradients in the physical space: grad = J^{-T} * grad_ref
            let mut grads_global: Vec<Vector2<f64>> = Vec::with_capacity(n);
            for grad_ref in grads_ref {
                let grad = jac_inv_t * grad_ref;
                grads_global.push(grad);
            }

            // Evaluate physical coordinates of the quadrature point
            let shape_vals = ref_element.shape_functions(quad_point);
            let mut x = 0.0;
            let mut y = 0.0;
            for (val, vtx) in shape_vals.iter().zip(&nodes) {
                x += val * vtx.x;
                y += val * vtx.y;
            }

            // Fill ke and fe, scaling the quadrature weight by |det J|
            let f_val = source_fn(x, y);
            let weight = quad_weight * det_jac_ref.abs();
            for i in 0..n {
                for j in 0..n {
                    ke[i][j] += grads_global[i].dot(&grads_global[j]) * weight;
                }
                fe[i] += shape_vals[i] * f_val * weight;
            }
        }

        // Push the local contributions as COO triplets; entries that several
        // elements contribute to the same (i, j) are summed by the CSR conversion.
        for (local_i, &global_i) in element.indices.iter().enumerate() {
            b[global_i] += fe[local_i];
            for (local_j, &global_j) in element.indices.iter().enumerate() {
                let stiffness: f64 = ke[local_i][local_j];
                coo.push(global_i, global_j, stiffness);
            }
        }
    }

    let a = CsrMatrix::from(&coo);
    (a, b)
}
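The sparse path is only correct because duplicate triplets for the same (i, j) position are summed when the COO matrix becomes a CSR matrix. The following std-only sketch shows one way to perform such a conversion: sort the triplets, merge equal positions, and turn per-row counts into row offsets. The Csr struct and coo_to_csr function are hypothetical illustrations, not the nalgebra-sparse implementation.

```rust
/// Minimal CSR storage (hypothetical, for illustration only).
struct Csr {
    row_offsets: Vec<usize>,
    col_indices: Vec<usize>,
    values: Vec<f64>,
}

/// Convert COO triplets (row, col, value) to CSR, summing duplicate
/// (row, col) entries. Column indices within each row come out sorted.
fn coo_to_csr(nrows: usize, triplets: &[(usize, usize, f64)]) -> Csr {
    // Sort a copy by (row, col) so that duplicates become adjacent.
    let mut sorted = triplets.to_vec();
    sorted.sort_by(|a, b| (a.0, a.1).cmp(&(b.0, b.1)));

    let mut row_offsets = vec![0; nrows + 1];
    let mut col_indices = Vec::new();
    let mut values: Vec<f64> = Vec::new();
    let mut last: Option<(usize, usize)> = None;

    for &(r, c, v) in &sorted {
        assert!(r < nrows, "row index out of bounds");
        if last == Some((r, c)) {
            // Same position as the previous triplet: accumulate.
            *values.last_mut().unwrap() += v;
        } else {
            col_indices.push(c);
            values.push(v);
            row_offsets[r + 1] += 1; // count entries per row
            last = Some((r, c));
        }
    }
    // Prefix sum turns per-row counts into row offsets.
    for r in 0..nrows {
        row_offsets[r + 1] += row_offsets[r];
    }
    Csr { row_offsets, col_indices, values }
}
```

For example, the triplets (0, 0, 1.0) and (0, 0, 0.5), pushed by two different elements, end up as a single stored entry 1.5 at position (0, 0).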

Applying Dirichlet boundary conditions (dense)

Dirichlet boundary conditions are enforced by:

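  1. Moving the known boundary values to the right-hand side: $b_i \leftarrow b_i - A_{ij} g_j$ for every boundary node $j$ with prescribed value $g_j$.
  2. Zeroing out the corresponding rows and columns of the stiffness matrix. Zeroing the columns as well as the rows keeps the matrix symmetric, so it can still be solved with Cholesky or CG.
  3. Setting the diagonal entries to 1 and the right-hand side entries to the Dirichlet values.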
pub fn apply_dirichlet_dense<G>(
    a: &mut DMatrix<f64>,
    b: &mut DVector<f64>,
    boundary_nodes: &[usize],
    mesh: &Mesh2d,
    g: G,
) where
    G: Fn(f64, f64) -> f64,
{
    // Compute the Dirichlet value at each boundary node
    let mut values = Vec::with_capacity(boundary_nodes.len());
    for &i in boundary_nodes {
        let v = &mesh.vertices()[i];
        values.push((i, g(v.x, v.y)));
    }
    let n = a.nrows();

    // For each boundary node j, update rhs: b_i -= a_ij * g_j for all i
    for &(j, g_j) in &values {
        for i in 0..n {
            b[i] -= a[(i, j)] * g_j;
        }
    }

    // Zero out rows and columns and set diagonal
    for &(j, g_j) in &values {
        for k in 0..n {
            a[(j, k)] = 0.0;
            a[(k, j)] = 0.0;
        }
        a[(j, j)] = 1.0;
        b[j] = g_j;
    }
}
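The three steps can be checked on a tiny example without nalgebra. The sketch below mirrors apply_dirichlet_dense on a plain row-major Vec<Vec<f64>>; the name apply_dirichlet and the (node, value) pair input are assumptions of this sketch, replacing the mesh and closure of the original.

```rust
/// Symmetric Dirichlet elimination on a dense row-major matrix (sketch).
/// `bcs` holds (node index, prescribed value) pairs.
fn apply_dirichlet(a: &mut [Vec<f64>], b: &mut [f64], bcs: &[(usize, f64)]) {
    let n = b.len();
    // Step 1: lift known values to the right-hand side, b_i -= a_ij * g_j,
    // using the unmodified matrix.
    for &(j, g_j) in bcs {
        for i in 0..n {
            b[i] -= a[i][j] * g_j;
        }
    }
    // Steps 2 and 3: zero row and column j, put 1 on the diagonal,
    // and prescribe the value in the right-hand side.
    for &(j, g_j) in bcs {
        for k in 0..n {
            a[j][k] = 0.0;
            a[k][j] = 0.0;
        }
        a[j][j] = 1.0;
        b[j] = g_j;
    }
}
```

Applied to the 1D stiffness matrix of two unit elements, [[1, -1, 0], [-1, 2, -1], [0, -1, 1]], with zero load and $u_0 = 0$, $u_2 = 1$, the system becomes diag(1, 2, 1) u = (0, 1, 1). The interior value is therefore $u_1 = 1/2$, the linear interpolant of the boundary data, and the matrix stays symmetric.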

Applying Dirichlet boundary conditions (sparse)

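The sparse version performs the same three steps but works directly on the CSR storage.
Entries are overwritten with explicit zeros rather than removed, so the sparsity pattern of the matrix is preserved. For each boundary node the routine scans every row looking for column j, so the cost grows with the number of boundary nodes times the number of stored entries; this is simple but may become noticeable on large meshes.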

pub fn apply_dirichlet_sparse<G>(
    a: &mut CsrMatrix<f64>,
    b: &mut DVector<f64>,
    boundary_nodes: &[usize],
    mesh: &Mesh2d,
    g: G,
) where
    G: Fn(f64, f64) -> f64,
{
    // Compute the Dirichlet value at each boundary node
    let mut bc_vals = Vec::with_capacity(boundary_nodes.len());
    for &j in boundary_nodes {
        let v = &mesh.vertices()[j];
        bc_vals.push((j, g(v.x, v.y)));
    }

    let n = a.nrows();

    for &(j, g_j) in &bc_vals {
        // Update the RHS: b_i -= a_ij * g_j for every row i storing an entry in column j
        for i in 0..n {
            let row = a.row(i);
            let cols = row.col_indices();
            let vals = row.values();
            if let Some(pos) = cols.iter().position(|&c| c == j) {
                b[i] -= vals[pos] * g_j;
            }
        }

        // Zero out row j (entries become explicit zeros; the pattern is unchanged)
        for v in a.row_mut(j).values_mut() {
            *v = 0.0;
        }

        // Zero out column j to preserve symmetry.
        // First collect the positions, then write, to avoid overlapping borrows.
        let mut to_zero: Vec<(usize, usize)> = Vec::new();
        for i in 0..n {
            let row_i = a.row(i);
            let cols = row_i.col_indices();
            if let Some(pos) = cols.iter().position(|&c| c == j) {
                to_zero.push((i, pos));
            }
        }
        // Zero out the collected positions
        for (i, pos) in to_zero {
            a.row_mut(i).values_mut()[pos] = 0.0;
        }

        // Set the diagonal entry to 1.0
        if let Some(pos) = a.row(j).col_indices().iter().position(|&c| c == j) {
            a.row_mut(j).values_mut()[pos] = 1.0;
        }

        b[j] = g_j;
    }
}
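Because nalgebra-sparse keeps the column indices of each CSR row sorted, the linear position searches above could be replaced by binary searches. The sketch below illustrates this on a hypothetical self-contained Csr struct (not nalgebra-sparse's CsrMatrix) and also merges the RHS update and the column zeroing into one pass per boundary node; apply_dirichlet_csr and the (node, value) pair input are illustrative assumptions.

```rust
/// Minimal CSR storage with sorted column indices per row (hypothetical).
struct Csr {
    row_offsets: Vec<usize>,
    col_indices: Vec<usize>,
    values: Vec<f64>,
}

impl Csr {
    /// Position of entry (i, j) in `values`, found by binary search
    /// because column indices within a row are sorted.
    fn find(&self, i: usize, j: usize) -> Option<usize> {
        let (start, end) = (self.row_offsets[i], self.row_offsets[i + 1]);
        self.col_indices[start..end]
            .binary_search(&j)
            .ok()
            .map(|p| start + p)
    }
}

/// Symmetric Dirichlet elimination on CSR storage; the sparsity pattern is
/// kept, eliminated entries are stored as explicit zeros.
fn apply_dirichlet_csr(a: &mut Csr, b: &mut [f64], bcs: &[(usize, f64)]) {
    let n = b.len();
    for &(j, g_j) in bcs {
        // Lift the known value into the RHS and zero column j in one pass.
        for i in 0..n {
            if let Some(p) = a.find(i, j) {
                b[i] -= a.values[p] * g_j;
                a.values[p] = 0.0;
            }
        }
        // Zero row j, then put 1 on the diagonal and prescribe b_j.
        let (start, end) = (a.row_offsets[j], a.row_offsets[j + 1]);
        for v in &mut a.values[start..end] {
            *v = 0.0;
        }
        if let Some(p) = a.find(j, j) {
            a.values[p] = 1.0;
        }
        b[j] = g_j;
    }
}
```

For the CSR form of [[1, -1, 0], [-1, 2, -1], [0, -1, 1]] with $u_0 = 0$ and $u_2 = 1$, the stored values become (1, 0, 0, 2, 0, 0, 1), the column indices are unchanged, and the right-hand side becomes (0, 1, 1).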

Dense solver

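The dense solver uses a Cholesky factorization $A = LL^T$ of the stiffness matrix. After the symmetric Dirichlet elimination the matrix is symmetric positive-definite, provided at least one Dirichlet node is present on a connected mesh (without any, constant functions lie in its kernel and the matrix is singular). Cholesky needs roughly half the work of an LU factorization, and a failed factorization is reported as None.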

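pub fn dense_solver(a: &DMatrix<f64>, b: &DVector<f64>) -> Option<DVector<f64>> {
    // cholesky() consumes its input, so factor a copy; it returns None if the
    // factorization fails (the matrix is not positive-definite).
    let chol = a.clone().cholesky()?;
    Some(chol.solve(b))
}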
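For reference, the factorization that cholesky() performs can be sketched in a few lines of plain Rust. This is a textbook Cholesky-Banachiewicz loop on a row-major Vec<Vec<f64>>, not nalgebra's implementation; the names cholesky and cholesky_solve are hypothetical.

```rust
/// Cholesky factorization A = L L^T of a symmetric positive-definite matrix
/// (only the lower triangle of `a` is read). Returns None on a
/// non-positive pivot.
fn cholesky(a: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
    let n = a.len();
    let mut l = vec![vec![0.0; n]; n];
    for i in 0..n {
        for j in 0..=i {
            let s: f64 = (0..j).map(|k| l[i][k] * l[j][k]).sum();
            if i == j {
                let d = a[i][i] - s;
                if d <= 0.0 {
                    return None; // not (numerically) positive-definite
                }
                l[i][j] = d.sqrt();
            } else {
                l[i][j] = (a[i][j] - s) / l[j][j];
            }
        }
    }
    Some(l)
}

/// Solve A x = b given the factor L: forward then backward substitution.
fn cholesky_solve(l: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
    let n = b.len();
    // Forward substitution: L y = b.
    let mut y = vec![0.0; n];
    for i in 0..n {
        let s: f64 = (0..i).map(|k| l[i][k] * y[k]).sum();
        y[i] = (b[i] - s) / l[i][i];
    }
    // Backward substitution: L^T x = y.
    let mut x = vec![0.0; n];
    for i in (0..n).rev() {
        let s: f64 = (i + 1..n).map(|k| l[k][i] * x[k]).sum();
        x[i] = (y[i] - s) / l[i][i];
    }
    x
}
```

Like the nalgebra version, the sketch returns None when it meets a non-positive pivot, that is, when the matrix is not numerically positive-definite.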

Sparse solver

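The sparse solver uses the iterative Conjugate Gradient (CG) method. CG only needs matrix-vector products with the stiffness matrix, so it works directly on the CSR storage without the fill-in that a sparse factorization would introduce, and it converges for symmetric positive-definite systems such as this one. This keeps memory usage low and makes it the better choice for large meshes.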

pub fn sparse_solver(a: &CsrMatrix<f64>, b: &DVector<f64>) -> Option<DVector<f64>> {
    conjugate_gradient::solve(a, b, 1000, 1e-10)
}
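The crate's conjugate_gradient::solve is not shown in this section. As an illustration of the method only, here is a minimal unpreconditioned CG on a hypothetical self-contained Csr struct. The parameter meanings in this sketch (max_iter as an iteration cap, tol as a tolerance on the residual norm relative to the norm of b) are assumptions and may not match the crate's function.

```rust
/// Minimal CSR storage (hypothetical, for illustration only).
struct Csr {
    row_offsets: Vec<usize>,
    col_indices: Vec<usize>,
    values: Vec<f64>,
}

impl Csr {
    /// y = A x, the only matrix operation CG needs.
    fn matvec(&self, x: &[f64]) -> Vec<f64> {
        (0..self.row_offsets.len() - 1)
            .map(|i| {
                (self.row_offsets[i]..self.row_offsets[i + 1])
                    .map(|p| self.values[p] * x[self.col_indices[p]])
                    .sum::<f64>()
            })
            .collect()
    }
}

fn dot(u: &[f64], v: &[f64]) -> f64 {
    u.iter().zip(v).map(|(a, b)| a * b).sum()
}

/// Unpreconditioned conjugate gradient for SPD A, starting from x = 0.
/// Stops when ||r|| <= tol * ||b||; returns None if max_iter is reached.
fn cg_solve(a: &Csr, b: &[f64], max_iter: usize, tol: f64) -> Option<Vec<f64>> {
    let mut x = vec![0.0; b.len()];
    let mut r = b.to_vec(); // r = b - A x with x = 0
    let mut p = r.clone();
    let mut rs_old = dot(&r, &r);
    let threshold = tol * dot(b, b).sqrt();
    if rs_old.sqrt() <= threshold {
        return Some(x);
    }
    for _ in 0..max_iter {
        let ap = a.matvec(&p);
        let alpha = rs_old / dot(&p, &ap); // step length along p
        for i in 0..x.len() {
            x[i] += alpha * p[i];
            r[i] -= alpha * ap[i];
        }
        let rs_new = dot(&r, &r);
        if rs_new.sqrt() <= threshold {
            return Some(x);
        }
        let beta = rs_new / rs_old; // keeps directions A-conjugate
        for i in 0..p.len() {
            p[i] = r[i] + beta * p[i];
        }
        rs_old = rs_new;
    }
    None
}
```

Only matvec touches the matrix, which is why CG pairs naturally with CSR storage; the remaining memory is a handful of vectors of length n.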

High-level assemble-and-solve (dense)

This function combines the assembly, boundary condition application, and solve phases into a single call for dense systems.

pub fn assemble_and_solve_dense<F, G>(
    mesh: &Mesh2d,
    boundary_nodes: &[usize],
    boundary_fn: G,
    source_fn: F,
) -> DVector<f64>
where
    F: Fn(f64, f64) -> f64,
    G: Fn(f64, f64) -> f64,
{
    // Assemble dense system
    let (mut a, mut b) = assemble_system_dense(mesh, &source_fn);

    // Apply BCs
    apply_dirichlet_dense(&mut a, &mut b, boundary_nodes, mesh, boundary_fn);

    // Solve linear system (panics if the Cholesky factorization fails)
    dense_solver(&a, &b).expect("failed to solve")
}
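Every Rust closure has its own anonymous type, so a single type parameter shared by boundary_fn and source_fn only accepts two arguments of the same type: passing two different closures fails to compile unless both are explicitly cast to the same fn(f64, f64) -> f64 pointer type. Declaring one type parameter per callback removes that restriction, as this hypothetical stand-in (evaluate_both, for illustration only) shows:

```rust
/// Hypothetical stand-in for the high-level functions: it only evaluates the
/// two callbacks at a point, to show how the generic parameters are declared.
fn evaluate_both<F, G>(boundary_fn: G, source_fn: F, x: f64, y: f64) -> (f64, f64)
where
    F: Fn(f64, f64) -> f64, // source term type
    G: Fn(f64, f64) -> f64, // boundary data type, independent of F
{
    (boundary_fn(x, y), source_fn(x, y))
}
```

A call such as evaluate_both(|x, y| x * y, |x, y| x + y, 2.0, 3.0) compiles because F and G are inferred separately, and it returns (6.0, 5.0).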

High-level assemble-and-solve (sparse)

Similarly, this high-level function handles all the steps for sparse systems in one call.

pub fn assemble_and_solve_sparse<F, G>(
    mesh: &Mesh2d,
    boundary_nodes: &[usize],
    boundary_fn: G,
    source_fn: F,
) -> DVector<f64>
where
    F: Fn(f64, f64) -> f64,
    G: Fn(f64, f64) -> f64,
{
    // Assemble sparse system
    let (mut a, mut b) = assemble_system_sparse(mesh, &source_fn);

    // Apply BCs
    apply_dirichlet_sparse(&mut a, &mut b, boundary_nodes, mesh, boundary_fn);

    // Solve linear system (panics if the CG solver returns no solution)
    sparse_solver(&a, &b).expect("failed to solve")
}

Unit tests

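The tests check that both the dense and sparse assembly functions produce systems of the expected size for a mesh made of a single Q1 (bilinear quadrilateral) element on the unit square. The mesh has four vertices, so both the matrix and the right-hand side must have four rows; the entries themselves are not checked.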

#[cfg(test)]
mod tests {
    use super::*;
    use crate::element::Element;

    #[test]
    fn test_assemble_system_dense() {
        let vertices = vec![
            Point2::new(0.0, 0.0),
            Point2::new(1.0, 0.0),
            Point2::new(1.0, 1.0),
            Point2::new(0.0, 1.0),
        ];
        let elements = vec![Element {
            indices: vec![0, 1, 2, 3],
        }];
        let mesh = Mesh2d::new(vertices, elements, ElementType::Q1);

        let source_fn = |x: f64, y: f64| x + y;
        let (a, b) = assemble_system_dense(&mesh, &source_fn);

        assert_eq!(a.nrows(), 4);
        assert_eq!(b.len(), 4);
    }

    #[test]
    fn test_assemble_system_sparse() {
        let vertices = vec![
            Point2::new(0.0, 0.0),
            Point2::new(1.0, 0.0),
            Point2::new(1.0, 1.0),
            Point2::new(0.0, 1.0),
        ];
        let elements = vec![Element {
            indices: vec![0, 1, 2, 3],
        }];
        let mesh = Mesh2d::new(vertices, elements, ElementType::Q1);

        let source_fn = |x: f64, y: f64| x + y;
        let (a, b) = assemble_system_sparse(&mesh, &source_fn);

        assert_eq!(a.nrows(), 4);
        assert_eq!(b.len(), 4);
    }
}
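Size checks do not catch errors in the entries. A cheap extra check is that, before boundary conditions are applied, the stiffness matrix is symmetric and each row sums to zero, since the shape functions sum to one and the gradient of a constant vanishes. The std-only sketch below assembles P1 triangles with closed-form gradients (instead of the crate's quadrature loop) and checks both properties; assemble_p1 and stiffness_properties_hold are hypothetical names.

```rust
/// Assemble the global P1 stiffness matrix (dense, row-major) using the
/// closed-form constant gradients of each linear triangle.
fn assemble_p1(vertices: &[[f64; 2]], triangles: &[[usize; 3]]) -> Vec<Vec<f64>> {
    let n = vertices.len();
    let mut k = vec![vec![0.0; n]; n];
    for tri in triangles {
        let p = [vertices[tri[0]], vertices[tri[1]], vertices[tri[2]]];
        // Twice the signed area of the triangle.
        let det = (p[1][0] - p[0][0]) * (p[2][1] - p[0][1])
            - (p[2][0] - p[0][0]) * (p[1][1] - p[0][1]);
        let area = 0.5 * det.abs();
        // grad(phi_i) = (y_a - y_b, x_b - x_a) / det, a = i+1, b = i+2 (mod 3).
        let mut g = [[0.0; 2]; 3];
        for (i, gi) in g.iter_mut().enumerate() {
            let (a, b) = ((i + 1) % 3, (i + 2) % 3);
            *gi = [(p[a][1] - p[b][1]) / det, (p[b][0] - p[a][0]) / det];
        }
        // Scatter area * grad(phi_i) . grad(phi_j) into the global matrix.
        for i in 0..3 {
            for j in 0..3 {
                k[tri[i]][tri[j]] += area * (g[i][0] * g[j][0] + g[i][1] * g[j][1]);
            }
        }
    }
    k
}

/// True if `k` is symmetric and every row sums to zero, up to `tol`.
fn stiffness_properties_hold(k: &[Vec<f64>], tol: f64) -> bool {
    let n = k.len();
    let symmetric = (0..n).all(|i| (0..n).all(|j| (k[i][j] - k[j][i]).abs() < tol));
    let zero_row_sums = k.iter().all(|row| row.iter().sum::<f64>().abs() < tol);
    symmetric && zero_row_sums
}
```

For the unit square split along its diagonal into the triangles (0, 1, 2) and (0, 2, 3), the assembled matrix has K[0][0] = 1, K[0][1] = -1/2 and a zero coupling K[0][2] across the diagonal. Similar assertions in test_assemble_system_dense and test_assemble_system_sparse, such as comparing the two assembled matrices entry by entry, would catch quadrature or COO-to-CSR errors that a size check cannot.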