"""
Basic functionality for using B-splines in Liesel.
"""
from functools import partial
import jax.numpy as jnp
from jax import jit, lax, vmap
from liesel.model import Array
[docs]
def equidistant_knots(
    x: Array, n_param: int, order: int = 3, eps: float = 0.01
) -> Array:
    """
    Create equidistant knots for a B-spline of the specified order.

    Parameters
    ----------
    x
        A 1d array of input data.
    n_param
        Number of parameters of the B-spline. Must be greater than ``order``, so
        that there are at least two interior knots to define the knot spacing.
    order
        A non-negative integer giving the order of the spline function.
        A cubic spline has an order of 3.
    eps
        A factor by which the range of the interior knots is stretched. The
        interior knots extend ``eps / 2 * (jnp.max(x) - jnp.min(x))`` beyond each
        end of the data, so they span ``(1 + eps) * (jnp.max(x) - jnp.min(x))``.

    Returns
    -------
    A 1d array of ``n_param + order + 1`` knots.

    Raises
    ------
    ValueError
        If ``order`` is negative, or if ``n_param`` is not greater than ``order``.

    Notes
    -----
    Some additional info:

    - ``dim(knots) = n_param + order + 1``
    - ``n_param = dim(knots) - order - 1``
    - ``n_interior_knots = n_param - order + 1``
    """
    if order < 0:
        raise ValueError(f"Invalid {order=}.")

    # With n_param == order there would be a single interior knot. The spacing
    # below is read from the first two interior knots, and an out-of-bounds read
    # on a JAX array is clamped instead of raising, which would silently yield a
    # spacing of zero and a degenerate knot vector.
    if n_param <= order:
        raise ValueError(f"{n_param=} must be greater than {order=}.")

    n_internal_knots = n_param - order + 1

    a = jnp.min(x)
    b = jnp.max(x)
    range_ = b - a

    # Stretch the interior range by eps / 2 on each side of the data.
    min_k = a - range_ * (eps / 2)
    max_k = b + range_ * (eps / 2)

    internal_knots = jnp.linspace(min_k, max_k, n_internal_knots)
    step = internal_knots[1] - internal_knots[0]

    # ``order`` additional knots on each side, continuing the same spacing.
    left_knots = jnp.linspace(min_k - (step * order), min_k - step, order)
    right_knots = jnp.linspace(max_k + step, max_k + (step * order), order)

    knots = jnp.concatenate((left_knots, internal_knots, right_knots))
    return knots
@partial(jit, static_argnames="order")
def _build_basis_vector(x: Array, knots: Array, order: int) -> Array:
    """
    Evaluates the B-spline basis functions of the given order at a single point.

    A working vector with one slot per knot interval (``dim(knots) - 1`` slots) is
    updated in place, one spline order at a time. After the sweep for order ``m``,
    slot ``i`` holds the value at ``x`` of the order-``m`` basis function that
    starts at knot ``i``. The outer loop runs over the orders ``0, ..., order``
    and the inner loop runs over the slots. Only the first
    ``dim(knots) - order - 1`` slots are returned; the remaining slots are
    working space.
    """
    n_knots = knots.shape[0]
    n_basis = n_knots - order - 1

    # Every slot is overwritten during the order-0 sweep, so NaN is only a
    # placeholder here.
    initial = jnp.full(n_knots - 1, jnp.nan)

    def sweep_order(m, vec):
        def update_slot(i, vec):
            def indicator(v):
                # Order 0: 1 on the half-open interval [knots[i], knots[i + 1]).
                inside = (x >= knots[i]) & (x < knots[i + 1])
                return v.at[i].set(jnp.where(inside, 1.0, 0.0))

            def recursion(v):
                # Order m: weighted sum of slots i and i + 1 from order m - 1.
                # Slots are visited in increasing order, so slot i + 1 still
                # holds its value from the previous order when it is read here.
                left_weight = (x - knots[i]) / (knots[i + m] - knots[i])
                right_weight = (knots[i + m + 1] - x) / (
                    knots[i + m + 1] - knots[i + 1]
                )
                return v.at[i].set(left_weight * v[i] + right_weight * v[i + 1])

            return lax.cond(m == 0, indicator, recursion, vec)

        return lax.fori_loop(0, n_basis + order, update_slot, vec)

    final = lax.fori_loop(0, order + 1, sweep_order, initial)
    return final[:n_basis]
[docs]
def basis_matrix(
    x: Array, knots: Array, order: int = 3, outer_ok: bool = False
) -> Array:
    """
    Builds a B-spline basis matrix.

    Parameters
    ----------
    x
        A 1d array of input data. A scalar is also accepted and is treated as an
        array of length one.
    knots
        A 1d array of knots. The knots will be sorted.
    order
        A non-negative integer giving the order of the spline function.
        A cubic spline has an order of 3.
    outer_ok
        If ``False`` (default), values of x outside the range of interior knots
        cause an error. If ``True``, they are allowed.

    Returns
    -------
    A 2d array, the B-spline basis matrix, with one row per element of ``x`` and
    ``dim(knots) - order - 1`` columns.

    Raises
    ------
    ValueError
        If ``order`` is negative, or if ``outer_ok`` is ``False`` and any value of
        ``x`` lies below the smallest or above the largest interior knot.

    Notes
    -----
    Under the hood, instead of applying the recursive definition of B-splines
    directly, the basis functions are evaluated for each value in ``x`` by
    iterating over the orders ``0, ..., order`` and over the knots.

    .. rubric:: Jit-compilation

    The ``basis_matrix`` function internally uses a jit-compiled function to do the
    heavy lifting. However, you may want to make ``basis_matrix`` itself
    jit-compilable. In this case, you need to define the arguments ``order`` and
    ``outer_ok`` as static arguments. Further, ``outer_ok`` needs to be fixed to
    ``True``, because the range check branches in Python on the values of ``x``
    and ``knots``.

    If you just want to set up a basis matrix once, it is usually not necessary to
    go through this process.

    Example:

    .. code-block:: python

        import jax
        import jax.numpy as jnp
        from liesel.contrib.splines import equidistant_knots, basis_matrix

        x = jnp.linspace(-2.0, 2.0, 30)
        knots = equidistant_knots(x, n_param=10, order=3)

        basis_matrix_jit = jax.jit(basis_matrix, static_argnames=("order", "outer_ok"))

        B = basis_matrix_jit(x, knots, 3, outer_ok=True)

    Another suitable way to go is to use ``functools.partial``::

        from functools import partial

        import jax
        import jax.numpy as jnp
        from liesel.contrib.splines import equidistant_knots, basis_matrix

        x = jnp.linspace(-2.0, 2.0, 30)
        knots = equidistant_knots(x, n_param=10, order=3)

        basis_matrix_fix = partial(basis_matrix, order=3, outer_ok=True)
        basis_matrix_jit = jax.jit(basis_matrix_fix)

        B = basis_matrix_jit(x, knots)
    """
    if order < 0:
        raise ValueError(f"Invalid {order=}.")

    # if x is a scalar, this ensures that the function still works
    x = jnp.atleast_1d(x)
    knots = jnp.sort(knots)

    if not outer_ok:
        # The interior knots are what remains after dropping ``order`` knots
        # from each end of the sorted knot vector.
        min_ = knots[order]
        max_ = knots[knots.shape[0] - order - 1]

        geq_min = jnp.min(x) >= min_
        leq_max = jnp.max(x) <= max_

        # Raise as soon as either bound is violated. The negation has to cover
        # the whole conjunction; ``not geq_min and leq_max`` would only trigger
        # when the lower bound alone is violated.
        if not (geq_min and leq_max):
            raise ValueError(
                f"Values of x are not inside the range of interior knots, [{min_},"
                f" {max_}]"
            )

    design_matrix = vmap(lambda x_i: _build_basis_vector(x_i, knots, order))(x)
    return design_matrix
[docs]
def pspline_penalty(d: int, diff: int = 2):
    """
    Builds a P-spline penalty matrix of shape ``(d, d)``.

    The penalty is ``D.T @ D``, where ``D`` is obtained by taking differences of
    order ``diff`` along the rows of the ``d``-dimensional identity matrix.

    Parameters
    ----------
    d
        Integer, dimension of the matrix. Corresponds to the number of parameters
        in a P-spline.
    diff
        Order of the differences used in constructing the penalty matrix. The
        default of ``diff=2`` corresponds to the common P-spline default of
        penalizing second differences.

    Returns
    -------
    A 2d array, the penalty matrix.
    """
    identity = jnp.identity(d)
    difference_matrix = jnp.diff(identity, n=diff, axis=0)
    return jnp.matmul(difference_matrix.T, difference_matrix)