basis_matrix()

liesel.contrib.splines.basis_matrix(x, knots, order=3, outer_ok=False)

Builds a B-spline basis matrix.

Parameters:
  • x (Any) – A 1d array of input data.

  • knots (Any) – A 1d array of knots. The knots will be sorted.

  • order (int) – A positive integer giving the order of the spline function. A cubic spline has an order of 3. (default: 3)

  • outer_ok (bool) – If False, values of x outside the range of interior knots cause an error. If True, they are allowed. (default: False)

Return type:

Any

Returns:

A 2d array, the B-spline basis matrix.

Notes

Under the hood, instead of applying the recursive definition of B-splines, the function creates, for each value in x, a matrix with (order + 1) rows and (dim(knots) - order - 1) columns. This matrix stores the evaluation of the basis function at the observed value for the m-th order and the i-th knot.
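For reference, the recursive definition mentioned above is the textbook Cox-de Boor recursion. The following is a minimal pure-Python sketch of it, not Liesel's implementation. The names bspline_basis and basis_matrix_reference are hypothetical, and the sketch assumes that "order" denotes the polynomial degree (3 for cubic), which matches the column count (dim(knots) - order - 1) stated above.

```python
def bspline_basis(x, knots, i, degree):
    """Evaluate the i-th B-spline of the given degree at x (Cox-de Boor recursion)."""
    if degree == 0:
        # Degree 0: indicator function of the half-open knot span [t_i, t_{i+1}).
        return 1.0 if knots[i] <= x < knots[i + 1] else 0.0

    left_den = knots[i + degree] - knots[i]
    right_den = knots[i + degree + 1] - knots[i + 1]

    # By convention, terms with a zero denominator (repeated knots) count as zero.
    left = 0.0
    if left_den != 0:
        left = (x - knots[i]) / left_den * bspline_basis(x, knots, i, degree - 1)
    right = 0.0
    if right_den != 0:
        right = (
            (knots[i + degree + 1] - x)
            / right_den
            * bspline_basis(x, knots, i + 1, degree - 1)
        )
    return left + right


def basis_matrix_reference(xs, knots, order=3):
    """Hypothetical reference helper: one row per x, one column per basis function."""
    knots = sorted(knots)
    n_basis = len(knots) - order - 1
    return [[bspline_basis(x, knots, i, order) for i in range(n_basis)] for x in xs]


# 11 equidistant knots from -3 to 7; for order 3 the basis is complete on [0, 4).
knots = [float(k) for k in range(-3, 8)]
xs = [0.0, 0.5, 1.0]
B = basis_matrix_reference(xs, knots, order=3)
```

Inside the range where the basis is complete, every row of B sums to one, which is a convenient sanity check when comparing against a faster implementation.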

Jit-compilation

The basis_matrix function internally uses a jit-compiled function to do the heavy lifting. However, you may want to make basis_matrix itself jit-compilable. In this case, you need to define the arguments order and outer_ok as static arguments. Further, outer_ok needs to be fixed to True.

If you just want to set up a basis matrix once, it is usually not necessary to go through this process.
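Why static arguments are needed can be illustrated with a toy function. This is a sketch using plain JAX, not Liesel code: the function powers and its strict flag are hypothetical stand-ins for basis_matrix and outer_ok. The assumption, not confirmed by the source, is that the restriction on outer_ok stems from a value-dependent check of this kind, which cannot run on traced arrays.

```python
import jax
import jax.numpy as jnp


def powers(x, order, strict):
    """Toy stand-in: stack x**0, ..., x**order as columns."""
    if strict:
        # Value-dependent Python check. Under jit, x is a tracer without
        # concrete values, so this branch only works outside of jit.
        if bool(jnp.any(x < 0)):
            raise ValueError("x contains negative values")
    # `order` drives a Python-level loop, so it must be a concrete int.
    return jnp.stack([x**k for k in range(order + 1)], axis=1)


# Marking `order` and `strict` as static makes them ordinary Python values
# at trace time; each distinct combination is compiled separately.
powers_jit = jax.jit(powers, static_argnames=("order", "strict"))

x = jnp.linspace(0.0, 1.0, 5)
P = powers_jit(x, order=3, strict=False)
```

Calling powers_jit with strict=True fails during tracing, because the check needs concrete array values. In this sketch, that is the analogue of having to fix outer_ok to True.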

Example:

import jax
import jax.numpy as jnp
from liesel.contrib.splines import equidistant_knots, basis_matrix

x = jnp.linspace(-2.0, 2.0, 30)
knots = equidistant_knots(x, n_param=10, order=3)

# order and outer_ok are static: each distinct value triggers a recompilation
basis_matrix_jit = jax.jit(basis_matrix, static_argnames=("order", "outer_ok"))

B = basis_matrix_jit(x, knots, order=3, outer_ok=True)

Alternatively, fix order and outer_ok with functools.partial before jit-compiling:

from functools import partial

import jax
import jax.numpy as jnp
from liesel.contrib.splines import equidistant_knots, basis_matrix

x = jnp.linspace(-2.0, 2.0, 30)
knots = equidistant_knots(x, n_param=10, order=3)

basis_matrix_fix = partial(basis_matrix, order=3, outer_ok=True)
basis_matrix_jit = jax.jit(basis_matrix_fix)

B = basis_matrix_jit(x, knots)