from __future__ import annotations
import logging
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal, ParamSpec, Protocol, assert_never
import tensorflow_probability.substrates.jax.distributions as tfd
from .builder import EngineBuilder
from .interface import LieselInterface
from .types import Array, JitterFunctions, Kernel, KeyArray
if TYPE_CHECKING:
from liesel.model import Model, Var
logger = logging.getLogger(__name__)
[docs]
@dataclass
class LieselMCMC:
    """
    Manages the setup of MCMC specifications for a Liesel model.

    Parameters
    ----------
    model
        The Liesel model object containing the variables and their inference
        specifications.
    which
        A named inference configuration to use. If None, the default inference
        attached to each variable is used.

    Examples
    --------

    .. rubric:: Liesel Workflow

    For this example, we import ``tensorflow_probability`` as follows:

    >>> import tensorflow_probability.substrates.jax.distributions as tfd

    First, we set up a minimal model:

    >>> mu = lsl.Var.new_param(0.0, name="mu", inference=gs.MCMCSpec(gs.NUTSKernel))
    >>> dist = lsl.Dist(tfd.Normal, loc=mu, scale=1.0)
    >>> y = lsl.Var.new_obs(jnp.array([1.0, 2.0, 3.0]), dist, name="y")
    >>> model = lsl.Model([y])

    Now we initialize the EngineBuilder and set the desired number of warmup and
    posterior samples:

    >>> builder = gs.LieselMCMC(model).get_engine_builder(seed=1, num_chains=4)
    >>> builder.add_adaptation(1000)
    >>> builder.add_posterior(1000)

    Finally, we build the engine:

    >>> engine = builder.build()
    """

    model: Model
    which: str | None = None

    def get_spec(self, var: Var) -> MCMCSpec | None:
        """
        Retrieve the MCMC specification for a given variable.

        Parameters
        ----------
        var
            The model variable for which to get the MCMC specification.

        Returns
        -------
        The MCMC specification if available, otherwise None.

        Raises
        ------
        ValueError
            If the inference attached to the variable is not of type ``MCMCSpec``.
        """
        inference = var.get_inference(self.which)
        if inference is None:
            return None

        if not isinstance(inference, MCMCSpec):
            raise ValueError(
                f"Attribute 'inference' of variable {var} is of type"
                f" {type(inference)}, but expected type '{MCMCSpec}'."
            )

        return inference

    def get_kernel_groups(self) -> dict[str, _KernelGroup]:
        """
        Collect and organize model variables into kernel groups.

        Variables without an MCMC specification are skipped. A variable whose
        specification has no ``kernel_group`` gets its own group, stored under the
        variable name. Variables sharing a ``kernel_group`` are merged into one
        group whose ``position_keys`` lists all of their names.

        Kernel keyword arguments only need to be supplied on one member of a
        group; members whose ``kernel_kwargs`` is ``None`` or empty inherit them.
        If several members supply non-empty keyword arguments, they must all be
        the very same object.

        Returns
        -------
        A dictionary mapping group names or variable names to their corresponding
        kernel group specifications.

        Raises
        ------
        ValueError
            If variables in the same kernel group have inconsistent kernels or
            kernel arguments.
        """
        kernel_groups: dict[str, _KernelGroup] = {}

        for name, var in self.model.vars.items():
            spec = self.get_spec(var)
            if spec is None:
                continue

            group_name = spec.kernel_group

            if group_name is None:
                # Ungrouped variables form a singleton group keyed by the
                # variable name.
                # NOTE(review): this assignment replaces any entry already stored
                # under the same key (e.g. a named group whose name equals this
                # variable's name) - confirm that this is intended.
                kernel_groups[name] = _KernelGroup(
                    kernel=spec.kernel,
                    kwargs=spec.kernel_kwargs,
                    position_keys=[name],
                )
                continue

            group = kernel_groups.get(group_name)
            if group is None:
                # First member of a named group.
                kernel_groups[group_name] = _KernelGroup(
                    kernel=spec.kernel,
                    kwargs=spec.kernel_kwargs,
                    position_keys=[name],
                )
                continue

            if group.kernel is not spec.kernel:
                raise ValueError(
                    f"Found incoherent kernel classes for kernel group {group_name}."
                )

            if not spec.kernel_kwargs:
                # ``None`` or an empty dict (the ``MCMCSpec`` default) means that
                # this member supplies no keyword arguments, so it inherits the
                # group's. Checking only for ``None`` would make the outcome
                # depend on the order in which the variables are visited.
                pass
            elif not group.kwargs:
                group.kwargs = spec.kernel_kwargs
            elif group.kwargs is not spec.kernel_kwargs:
                raise ValueError(
                    "Found incoherent kernel keyword arguments for "
                    f"kernel group {group_name}. "
                    "When supplying kernel keyword arguments for multiple "
                    "inference objects, they all have to point to the "
                    "same object. "
                    "Alternatively, if you pass the kernel keyword arguments "
                    "to only "
                    "one inference object in the group, they will be applied "
                    "for the whole group."
                )

            group.position_keys.append(name)

        return kernel_groups

    def get_kernel_list(self) -> list[Kernel]:
        """
        Construct the list of MCMC kernels from kernel groups.

        Each group's kernel factory is called with the group's position keys as
        the first positional argument and the group's keyword arguments.

        Returns
        -------
        A list of initialized kernel instances ready to be added to the MCMC engine.
        """
        return [
            # ``kwargs`` may be ``None`` if a specification was created that way;
            # ``**None`` would raise a TypeError, so fall back to no kwargs.
            group.kernel(group.position_keys, **(group.kwargs or {}))  # type: ignore
            for group in self.get_kernel_groups().values()
        ]

    def get_jitter_functions(self) -> JitterFunctions:
        """
        Collect jitter functions for model variables that define a jitter distribution.

        Returns
        -------
        A dictionary mapping variable names to their jitter application functions.
        """
        jitter_functions: JitterFunctions = {}

        for name, var in self.model.vars.items():
            spec = self.get_spec(var)
            if spec is None or spec.jitter_dist is None:
                continue
            jitter_functions[name] = spec.apply_jitter

        return jitter_functions

    def get_engine_builder(
        self,
        seed: int,
        num_chains: int,
        apply_jitter: bool = True,
    ) -> EngineBuilder:
        """
        Create and configure an `EngineBuilder` for MCMC sampling.

        The builder receives the model (wrapped in a ``LieselInterface``), the
        model state as initial values, one kernel per kernel group and, if
        requested, the jitter functions of the model variables.

        Parameters
        ----------
        seed
            Random seed for reproducibility.
        num_chains
            Number of MCMC chains.
        apply_jitter
            Whether to apply jitter to the initial states, by default True.

        Returns
        -------
        EngineBuilder
            A configured ``EngineBuilder`` instance.
        """
        eb = EngineBuilder(seed=seed, num_chains=num_chains)
        eb.set_model(LieselInterface(self.model))
        eb.set_initial_values(self.model.state)

        for kernel in self.get_kernel_list():
            eb.add_kernel(kernel)

        if apply_jitter:
            eb.set_jitter_fns(self.get_jitter_functions())

        return eb
@dataclass
class _KernelGroup:
    """
    Internal record describing one kernel to be built for a set of variables.

    Parameters
    ----------
    kernel
        Callable that creates the kernel. It is called with ``position_keys`` as
        the first positional argument and ``kwargs`` as keyword arguments.
    kwargs
        Keyword arguments for ``kernel``. ``None`` is accepted and stored as an
        empty dict so that the arguments can always be expanded with ``**``.
    position_keys
        Names of the variables handled by the kernel.
    """

    kernel: Callable[..., Kernel]
    kwargs: dict[str, Any] = field(default_factory=dict)
    position_keys: list[str] = field(default_factory=list)

    def __post_init__(self) -> None:
        # Normalize a missing kwargs mapping; a non-None mapping is kept as the
        # very same object because callers compare kwargs by identity.
        if self.kwargs is None:
            self.kwargs = {}
P = ParamSpec("P")
class KernelFactory(Protocol[P]):
"""Create a kernel instance based on the provided position keys and arguments."""
def __call__(
self, position_keys: list[str], *args: P.args, **kwargs: P.kwargs
) -> Kernel: ...
[docs]
@dataclass
class MCMCSpec:
"""
Specification for the MCMC kernel and optional jitter distribution associated with a
model variable.
Parameters
----------
kernel
A KernelFactory that returns a ``Kernel`` instance when provided with position
keys and keyword arguments.
kernel_kwargs
Additional keyword arguments to be passed to the kernel callable.
kernel_group
Name of the kernel group this variable belongs to. Variables in the same group \
must share the same kernel type and arguments.
jitter_dist
A TensorFlow Probability distribution used to apply random jitter to the \
initial value of the variable.
jitter_method
The type of jitter to be applied. This can be one of the following:
- `none`: No jitter is applied.
- `additive`: Additive jitter is applied.
- `multiplicative`: Multiplicative jitter is applied.
- `replacement`: Value is replaced when jitter is applied.
Examples
--------
.. rubric:: Liesel Workflow
For this example, we import ``tensorflow_probability`` as follows:
>>> import tensorflow_probability.substrates.jax.distributions as tfd
First, we set up a minimal model:
>>> mu = lsl.Var.new_param(0.0, name="mu", inference=gs.MCMCSpec(gs.NUTSKernel))
>>> dist = lsl.Dist(tfd.Normal, loc=mu, scale=1.0)
>>> y = lsl.Var.new_obs(jnp.array([1.0, 2.0, 3.0]), dist, name="y")
>>> model = lsl.Model([y])
Now we initialize the EngineBuilder and set the desired number of warmup and
posterior samples:
>>> builder = gs.LieselMCMC(model).get_engine_builder(seed=1, num_chains=4)
>>> builder.add_adaptation(1000)
>>> builder.add_posterior(1000)
Finally, we build the engine:
>>> engine = builder.build()
"""
def __post_init__(self) -> None:
if self.jitter_method not in self._JITTER_METHODS:
raise ValueError(
f"Invalid jitter method: {self.jitter_method}. "
f"Expected one of {self._JITTER_METHODS}."
)
_JITTER_METHODS = ["additive", "multiplicative", "replacement"]
kernel: KernelFactory
kernel_kwargs: dict[str, Any] = field(default_factory=dict)
kernel_group: str | None = None
jitter_dist: tfd.Distribution | None = None
jitter_method: Literal["additive", "multiplicative", "replacement"] = "additive"
def __repr__(self) -> str:
return f"{type(self).__name__}({self.kernel}, {self.kernel_group=})"
[docs]
def apply_jitter(self, seed: KeyArray, value: Array) -> Array:
"""
Apply random jitter to a given value using the specified jitter distribution.
If a jitter distribution is set, a random sample from the distribution is added
to the original value. If no jitter distribution is set, the original value is
returned unchanged.
Parameters
----------
seed
A PRNG key used for random sampling.
value
The value to which jitter should be applied.
Returns
-------
The jittered value with the same shape as the input.
"""
if self.jitter_dist is None:
return value
# check compatibility of shapes
if (
self.jitter_dist.batch_shape + self.jitter_dist.event_shape != value.shape
) and (
self.jitter_dist.batch_shape.rank + self.jitter_dist.event_shape.rank > 0
):
raise ValueError(
f"Jitter distribution shapes "
f"(batch shape {self.jitter_dist.batch_shape} "
f"and event shape {self.jitter_dist.event_shape}) "
f"do not match variable shape {value.shape}."
)
sample_shape = (
value.shape
if self.jitter_dist.batch_shape + self.jitter_dist.event_shape == ()
else ()
)
jitter = self.jitter_dist.sample(sample_shape=sample_shape, seed=seed)
match self.jitter_method:
case "additive":
value = value + jitter
case "multiplicative":
value = value * jitter
case "replacement":
value = jitter
case _:
assert_never(self.jitter_method)
return value