basis_matrix()
- liesel.contrib.splines.basis_matrix(x, knots, order=3, outer_ok=False)
Builds a B-spline basis matrix.
- Parameters:
  - x (Any) – A 1d array of input data.
  - knots (Any) – A 1d array of knots. The knots will be sorted.
  - order (int) – A positive integer giving the order of the spline function. A cubic spline has an order of 3. (default: 3)
  - outer_ok (bool) – If False (default), values of x outside the range of interior knots cause an error. If True, they are allowed. (default: False)
- Returns:
  A 2d array, the B-spline basis matrix, with one row per value in x and dim(knots) - order - 1 columns.
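As a worked example of the bookkeeping behind order, outer_ok, and the shape of the result, the snippet below uses a hypothetical knot vector of 14 equally spaced knots with order = 3. It assumes, as is standard for B-splines whose order equals the polynomial degree, that there are dim(knots) - order - 1 basis functions, and that the "range of interior knots" runs from knots[order] to knots[dim(knots) - order - 1]. How liesel treats values exactly on the boundary is an assumption here, not something this page states.

```python
import jax.numpy as jnp

order = 3
knots = jnp.arange(-3.0, 11.0)  # hypothetical knots: -3, -2, ..., 10 (14 knots)
n_knots = knots.shape[0]

# Number of B-spline basis functions, i.e. columns of the basis matrix.
n_basis = n_knots - order - 1  # 14 - 3 - 1 = 10

# Assumed interior range: the `order` outermost knots on each side are excluded.
lower, upper = knots[order], knots[n_knots - order - 1]

x = jnp.array([-0.5, 0.0, 3.2, 7.0, 7.5])
# Assumed check behind outer_ok=False: boundaries count as inside.
inside = (x >= lower) & (x <= upper)

print(n_basis, float(lower), float(upper))  # 10 0.0 7.0
print([bool(v) for v in inside])  # [False, True, True, True, False]
```

Under these assumptions, outer_ok=False would reject -0.5 and 7.5, while outer_ok=True would let them through.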
Notes
Under the hood, the basis is not computed by applying the recursive definition of B-splines directly. Instead, for each value in x, a matrix with (order + 1) rows and (dim(knots) - order - 1) columns is created. Its entry in row m and column i stores the evaluation, at that value, of the basis function of order m for the i-th knot.
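The order-by-order scheme described in the note can be sketched as follows. This is a hypothetical, vectorized re-implementation for illustration only, not liesel's code: the name bspline_basis is made up, and it assumes that order is the polynomial degree and that the knots are strictly increasing (repeated knots would produce 0/0 here). Instead of keeping the full (order + 1)-row table per value, it keeps only the row of the current order and computes order m from order m - 1 with the Cox-de Boor recursion.

```python
import jax.numpy as jnp


def bspline_basis(x, knots, order=3):
    """Hypothetical sketch: B-spline basis matrix via a bottom-up Cox-de Boor scheme.

    Assumes `order` is the polynomial degree and `knots` are strictly increasing.
    Returns an array of shape (len(x), len(knots) - order - 1).
    """
    x = jnp.asarray(x)
    knots = jnp.sort(jnp.asarray(knots))

    # Order 0: indicators of the half-open intervals [t_i, t_{i+1}).
    b = (x[:, None] >= knots[None, :-1]) & (x[:, None] < knots[None, 1:])
    b = b.astype(x.dtype)

    # Order m from order m - 1:
    # B_{i,m}(x) = (x - t_i) / (t_{i+m} - t_i) * B_{i,m-1}(x)
    #            + (t_{i+m+1} - x) / (t_{i+m+1} - t_{i+1}) * B_{i+1,m-1}(x)
    for m in range(1, order + 1):
        left = (x[:, None] - knots[None, : -m - 1]) / (knots[m:-1] - knots[: -m - 1])
        right = (knots[m + 1 :] - x[:, None]) / (knots[m + 1 :] - knots[1:-m])
        # Each step drops one column: B_{i,m} needs B_{i,m-1} and B_{i+1,m-1}.
        b = left * b[:, :-1] + right * b[:, 1:]

    return b


knots = jnp.arange(-3.0, 11.0)  # 14 equally spaced knots
x = jnp.linspace(0.0, 7.0, 29)  # values inside the interior range [0, 7]
B = bspline_basis(x, knots, order=3)
print(B.shape)  # (29, 10)
```

Inside the interior range, each row of such a matrix sums to 1 (the partition-of-unity property of B-splines), which makes a convenient sanity check.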
Jit-compilation
The basis_matrix function internally uses a jit-compiled function to do the heavy lifting. However, you may want to make basis_matrix itself jit-compilable. In this case, you need to define the arguments order and outer_ok as static arguments. Further, outer_ok needs to be fixed to True. If you just want to set up a basis matrix once, it is usually not necessary to go through this process.
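To see why these restrictions arise, consider the toy function below. It is hypothetical (toy_basis is not part of liesel) and only assumes two things about basis_matrix: that order determines the shape of the output, and that the check triggered by outer_ok=False inspects the values of x in plain Python. Under jax.jit, such a check would run on traced values that have no concrete value, so JAX raises an error; with outer_ok=True the check is skipped.

```python
import jax
import jax.numpy as jnp


def toy_basis(x, knots, order=3, outer_ok=False):
    """Hypothetical stand-in for basis_matrix; illustrates jit constraints only."""
    # `order` fixes the output shape, so it must be a concrete Python int.
    n_cols = knots.shape[0] - order - 1
    if not outer_ok:
        # Python-level check on array values: needs concrete (non-traced) arrays.
        if not bool(jnp.all((x >= knots[order]) & (x <= knots[n_cols]))):
            raise ValueError("x outside the range of interior knots")
    # Placeholder result with the shape a basis matrix would have.
    return jnp.zeros((x.shape[0], n_cols))


toy_jit = jax.jit(toy_basis, static_argnames=("order", "outer_ok"))

x = jnp.linspace(0.5, 6.5, 30)
knots = jnp.arange(-3.0, 11.0)

B = toy_jit(x, knots, order=3, outer_ok=True)  # compiles and runs
print(B.shape)  # (30, 10)

try:
    toy_jit(x, knots, order=3, outer_ok=False)  # value check on traced x
except jax.errors.ConcretizationTypeError as err:
    print("outer_ok=False fails under jit:", type(err).__name__)
```

Because order and outer_ok are static, JAX compiles a separate version for each combination of their values, which is cheap as long as they take only a few distinct values.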
Example:

```python
import jax
import jax.numpy as jnp
from liesel.contrib.splines import basis_matrix, equidistant_knots

x = jnp.linspace(-2.0, 2.0, 30)
knots = equidistant_knots(x, n_param=10, order=3)
# order and outer_ok are static arguments; outer_ok must be True under jit
basis_matrix_jit = jax.jit(basis_matrix, static_argnames=("order", "outer_ok"))
B = basis_matrix_jit(x, knots, order=3, outer_ok=True)
```
Another suitable way to go is to use functools.partial:

```python
from functools import partial

import jax
import jax.numpy as jnp
from liesel.contrib.splines import basis_matrix, equidistant_knots

x = jnp.linspace(-2.0, 2.0, 30)
knots = equidistant_knots(x, n_param=10, order=3)
# Bind the static arguments first, then jit the remaining (array) arguments
basis_matrix_fix = partial(basis_matrix, order=3, outer_ok=True)
basis_matrix_jit = jax.jit(basis_matrix_fix)
B = basis_matrix_jit(x, knots)
```