Source code for liesel.model.legacy

"""
Imitates the API from v0.1.
"""

from __future__ import annotations

from typing import Any, cast

import jax.numpy as jnp

from .nodes import Bijector as TFPBijector
from .nodes import Calc, Dist, Node, Var

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Strong variables ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


def DesignMatrix(
    value: Any | Calc, distribution: Dist | None = None, name: str = ""
) -> Var:
    """
    A strong variable representing a design matrix.

    Parameters
    ----------
    value
        The value of the variable, or a calculator node computing it.
    distribution
        An optional distribution node attached to the variable.
    name
        The name of the variable.

    Returns
    -------
    A variable created with :meth:`Var.new_obs`, tagged with the role
    ``"DesignMatrix"``.
    """
    design_matrix = Var.new_obs(value, distribution, name)
    design_matrix.role = "DesignMatrix"
    return design_matrix


def Hyperparameter(
    value: Any | Calc, distribution: Dist | None = None, name: str = ""
) -> Var:
    """
    A strong variable representing a hyperparameter.

    Parameters
    ----------
    value
        The value of the variable, or a calculator node computing it.
    distribution
        An optional distribution node attached to the variable.
    name
        The name of the variable.

    Returns
    -------
    A variable created with the plain :class:`Var` constructor, tagged with
    the role ``"Hyperparameter"``.
    """
    hyperparameter = Var(value, distribution, name)
    hyperparameter.role = "Hyperparameter"
    return hyperparameter


def Parameter(
    value: Any | Calc, distribution: Dist | None = None, name: str = ""
) -> Var:
    """
    A strong variable representing a model parameter.

    Parameters
    ----------
    value
        The value of the variable, or a calculator node computing it.
    distribution
        An optional distribution node attached to the variable.
    name
        The name of the variable.

    Returns
    -------
    A variable created with :meth:`Var.new_param`, tagged with the role
    ``"Parameter"``.
    """
    parameter = Var.new_param(value, distribution, name)
    parameter.role = "Parameter"
    return parameter


def RegressionCoef(
    value: Any | Calc, distribution: Dist | None = None, name: str = ""
) -> Var:
    """
    A strong variable representing a vector of regression coefficients.

    Parameters
    ----------
    value
        The value of the variable, or a calculator node computing it.
    distribution
        An optional distribution node attached to the variable.
    name
        The name of the variable.

    Returns
    -------
    A variable created with :meth:`Var.new_param`, tagged with the role
    ``"RegressionCoef"``.
    """
    coefficients = Var.new_param(value, distribution, name)
    coefficients.role = "RegressionCoef"
    return coefficients


def Response(
    value: Any | Calc, distribution: Dist | None = None, name: str = ""
) -> Var:
    """
    A strong variable representing a response vector.

    Parameters
    ----------
    value
        The value of the variable, or a calculator node computing it.
    distribution
        An optional distribution node attached to the variable.
    name
        The name of the variable.

    Returns
    -------
    A variable created with :meth:`Var.new_obs`, tagged with the role
    ``"Response"``.
    """
    response = Var.new_obs(value, distribution, name)
    response.role = "Response"
    return response


def SmoothingParam(
    value: Any | Calc, distribution: Dist | None = None, name: str = ""
) -> Var:
    """
    A strong variable representing a smoothing parameter.

    Parameters
    ----------
    value
        The value of the variable, or a calculator node computing it.
    distribution
        An optional distribution node attached to the variable.
    name
        The name of the variable.

    Returns
    -------
    A variable created with :meth:`Var.new_param`, tagged with the role
    ``"SmoothingParam"``.
    """
    smoothing_param = Var.new_param(value, distribution, name)
    smoothing_param.role = "SmoothingParam"
    return smoothing_param


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Weak variables ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


def Addition(
    *inputs: Var | Node, distribution: Dist | None = None, name: str = ""
) -> Var:
    """
    A weak variable calculating an element-wise sum.

    Parameters
    ----------
    *inputs
        The variables or nodes whose values are added up.
    distribution
        An optional distribution node attached to the variable.
    name
        The name of the variable.

    Returns
    -------
    A variable tagged with the role ``"Addition"`` whose value is computed by
    a :class:`Calc` over ``inputs``.
    """

    def _add(*args, **kwargs):
        # Positional and keyword inputs are summed separately, then combined.
        # ``sum`` starts from 0, so a call with no inputs evaluates to 0.
        return sum(args) + sum(kwargs.values())

    total = Var(Calc(_add, *inputs), distribution, name)
    total.role = "Addition"
    return total


def Bijector(
    _input: Var | Node,
    bijector: type[TFPBijector],
    inverse: bool = False,
    distribution: Dist | None = None,
    name: str = "",
) -> Var:
    """
    A weak variable evaluating the ``forward()`` or ``inverse()`` method
    of a TFP bijector.

    Parameters
    ----------
    _input
        The variable or node whose value is transformed.
    bijector
        A bijector *class*. It is instantiated without arguments every time
        the value is computed.
    inverse
        If true, apply ``inverse()``; otherwise apply ``forward()``.
    distribution
        An optional distribution node attached to the variable.
    name
        The name of the variable.

    Returns
    -------
    A variable tagged with the role ``"Bijector"``.
    """

    def _transform(x):
        instance = bijector()
        if inverse:
            return instance.inverse(x)
        return instance.forward(x)

    transformed = Var(Calc(_transform, _input), distribution, name)
    transformed.role = "Bijector"
    return transformed


def ColumnStack(
    *inputs: Var | Node, distribution: Dist | None = None, name: str = ""
) -> Var:
    """
    A weak variable stacking its inputs column-wise.

    Parameters
    ----------
    *inputs
        The variables or nodes whose values become the columns, in order.
    distribution
        An optional distribution node attached to the variable.
    name
        The name of the variable.

    Returns
    -------
    A variable tagged with the role ``"ColumnStack"`` whose value is computed
    with :func:`jax.numpy.column_stack`.
    """

    def _stack(*columns):
        return jnp.column_stack(columns)

    stacked = Var(Calc(_stack, *inputs), distribution, name)
    stacked.role = "ColumnStack"
    return stacked


def InverseLink(
    _input: Var | Node,
    bijector: type[TFPBijector],
    inverse: bool = False,
    distribution: Dist | None = None,
    name: str = "",
) -> Var:
    """
    A weak variable representing an inverse link function.

    This is a :func:`Bijector` variable whose role is relabeled to
    ``"InverseLink"``; all arguments are forwarded unchanged.

    Parameters
    ----------
    _input
        The variable or node whose value is transformed.
    bijector
        A bijector class, instantiated without arguments on evaluation.
    inverse
        If true, apply ``inverse()``; otherwise apply ``forward()``.
    distribution
        An optional distribution node attached to the variable.
    name
        The name of the variable.

    Returns
    -------
    A variable tagged with the role ``"InverseLink"``.
    """
    link = Bijector(
        _input,
        bijector,
        inverse=inverse,
        distribution=distribution,
        name=name,
    )
    link.role = "InverseLink"
    return link


class PITCalc(Node):
    """
    A probability integral transform (PIT) calculator node.

    Its value is the CDF of its input distribution, evaluated at the value of
    that distribution's ``at`` node.
    """

    def __init__(
        self,
        _input: Dist,
        _name: str = "",
        _needs_seed: bool = False,
    ):
        """
        Parameters
        ----------
        _input
            The distribution node whose CDF is evaluated.
        _name
            The name of the node.
        _needs_seed
            Forwarded unchanged to :class:`Node`.
        """
        super().__init__(_input, _name=_name, _needs_seed=_needs_seed)

    def update(self) -> PITCalc:
        """
        Recompute the PIT value from the current state of the input distribution.

        Returns
        -------
        The node itself, with its value refreshed and marked as up to date.

        Raises
        ------
        RuntimeError
            If the input distribution's ``at`` property is not set.
        """
        dist = cast(Dist, self.inputs[0])

        # Check for absence explicitly instead of relying on truthiness: ``at``
        # holds a node object (its ``.value`` is read below), and a set node
        # must never be rejected just because it happens to evaluate as falsy.
        if dist.at is None:
            raise RuntimeError(
                f"Cannot evaluate PIT on {repr(dist)}, property `at` not set"
            )

        self._value = dist.init_dist().cdf(dist.at.value)
        self._outdated = False
        return self


def PIT(_input: Var | Dist, distribution: Dist | None = None, name: str = "") -> Var:
    """
    A weak variable evaluating a probability integral transform (PIT).

    Parameters
    ----------
    _input
        Either a variable with a distribution node, or a distribution node
        itself. The PIT is evaluated with respect to that distribution.
    distribution
        An optional distribution node attached to the resulting variable.
    name
        The name of the resulting variable. If empty and ``_input`` has a
        name, it defaults to ``f"{_input.name}_pit"``.

    Returns
    -------
    A variable tagged with the role ``"PIT"`` whose value is computed by a
    :class:`PITCalc` node.

    Raises
    ------
    RuntimeError
        If ``_input`` is a variable without a distribution node.

    See Also
    --------
    .liesel.distribution.GaussianCopula : A bivariate Gaussian copula distribution.

    Examples
    --------
    >>> import liesel.model as lsl
    >>> import tensorflow_probability.substrates.jax.distributions as tfd
    >>> dist = lsl.Dist(tfd.Normal, loc=0.0, scale=1.0)
    >>> x = lsl.Var.new_obs(1.0, dist)
    >>> x_pit = lsl.PIT(x, name="PIT(x)").update()
    >>> x_pit.value.round(2)
    Array(0.84, dtype=float32)
    """
    dist = _input.dist_node if isinstance(_input, Var) else _input

    # Explicit ``is None``: only a missing distribution node is an error; the
    # truthiness of a node object says nothing about whether it exists.
    if dist is None:
        raise RuntimeError(
            f"Cannot evaluate PIT on {repr(_input)}, has no dist node"
        )

    if _input.name and not name:
        name = f"{_input.name}_pit"

    calc = PITCalc(dist)
    var = Var(calc, distribution, name)
    var.role = "PIT"
    return var
def Predictor(
    *inputs: Var | Node, distribution: Dist | None = None, name: str = ""
) -> Var:
    """
    A weak variable representing a regression predictor.

    This is an :func:`Addition` variable over ``inputs`` whose role is
    relabeled to ``"Predictor"``.

    Parameters
    ----------
    *inputs
        The variables or nodes whose values are added up.
    distribution
        An optional distribution node attached to the variable.
    name
        The name of the variable.

    Returns
    -------
    A variable tagged with the role ``"Predictor"``.
    """
    predictor = Addition(*inputs, distribution=distribution, name=name)
    predictor.role = "Predictor"
    return predictor


def Smooth(
    x: Var | Node, beta: Var | Node, distribution: Dist | None = None, name: str = ""
) -> Var:
    """
    A weak variable calculating the matrix-vector product ``x @ beta``.

    Parameters
    ----------
    x
        The variable or node providing the left operand of ``@``.
    beta
        The variable or node providing the right operand of ``@``.
    distribution
        An optional distribution node attached to the variable.
    name
        The name of the variable.

    Returns
    -------
    A variable tagged with the role ``"Smooth"``.
    """

    # The inputs are passed to ``Calc`` by keyword, so this function's
    # parameter names must stay ``x`` and ``beta``.
    def _matvec(x, beta):
        return x @ beta

    smooth = Var(Calc(_matvec, x=x, beta=beta), distribution, name)
    smooth.role = "Smooth"
    return smooth