from __future__ import annotations
import logging
from collections.abc import Callable
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, ParamSpec, Protocol, assert_never
import tensorflow_probability.substrates.jax.distributions as tfd
from .builder import EngineBuilder
from .engine import SamplingResults
from .interface import LieselInterface
from .types import Array, JitterFunctions, Kernel, KeyArray
if TYPE_CHECKING:
    # Only needed for type annotations (annotations are lazy strings here because of
    # ``from __future__ import annotations``); TYPE_CHECKING is False at runtime, so
    # ``liesel.model`` is not imported when this module is loaded. Presumably this
    # avoids an import cycle with ``liesel.model`` -- TODO confirm.
    from liesel.model import Model, Var

# Module-level logger used for the warnings/info messages emitted by LieselMCMC.
logger = logging.getLogger(__name__)
@dataclass
class LieselMCMC:
    """
    Manages the setup of MCMC specifications for a Liesel model.

    Parameters
    ----------
    model
        The Liesel model object containing the variables and their inference
        specifications.
    which
        A named inference configuration to use. If None, the default inference
        attached to each variable is used.

    Examples
    --------

    .. rubric:: Liesel Workflow

    For this example, we import ``tensorflow_probability`` as follows:

    >>> import tensorflow_probability.substrates.jax.distributions as tfd

    First, we set up a minimal model:

    >>> mu = lsl.Var.new_param(0.0, name="mu", inference=gs.MCMCSpec(gs.NUTSKernel))
    >>> dist = lsl.Dist(tfd.Normal, loc=mu, scale=1.0)
    >>> y = lsl.Var.new_obs(jnp.array([1.0, 2.0, 3.0]), dist, name="y")
    >>> model = lsl.Model([y])

    Now we run MCMC:

    >>> results = gs.LieselMCMC(model).run_for_epochs(  # doctest: +SKIP
    ...     seed=1, num_chains=4, adaptation=250, posterior=100  # doctest: +SKIP
    ... )  # doctest: +SKIP

    The method returns a :class:`.SamplingResults` object.

    .. rubric:: More control

    For additional control, we initialize an :class:`.EngineBuilder` and continue
    from there.

    >>> builder = gs.LieselMCMC(model).get_engine_builder(seed=1, num_chains=4)
    >>> builder.add_adaptation(1000)
    >>> builder.add_posterior(1000)
    """

    model: Model
    which: str | None = None

    def get_spec(self, var: Var) -> MCMCSpec | None:
        """
        Retrieve the MCMC specification for a given variable.

        Parameters
        ----------
        var
            The model variable for which to get the MCMC specification.

        Returns
        -------
        The MCMC specification if available, otherwise None.

        Raises
        ------
        ValueError
            If the inference attached to the variable is not of type ``MCMCSpec``.
        """
        inference = var.get_inference(self.which)
        if inference is None or isinstance(inference, MCMCSpec):
            return inference
        raise ValueError(
            f"Attribute 'inference' of variable {var} is of type"
            f" {type(inference)}, but expected type '{MCMCSpec}'."
        )

    def get_kernel_groups(self) -> dict[str, _KernelGroup]:
        """
        Collect and organize model variables into kernel groups.

        Variables without a ``kernel_group`` get a kernel of their own, keyed by the
        variable name. Variables sharing a ``kernel_group`` are merged into one group,
        keyed by the group name. The result is stably sorted by ``order``, so groups
        with equal ``order`` keep the reverse order of ``self.model.vars``.

        Returns
        -------
        A dictionary mapping group names or variable names to their corresponding
        kernel group specifications.

        Raises
        ------
        ValueError
            If variables in the same kernel group have inconsistent kernels or kernel
            arguments.
        """
        kernel_groups: dict[str, _KernelGroup] = {}

        # Walk the variables in reverse order of ``self.model.vars``; per the MCMCSpec
        # documentation this yields a bottom-up default ordering of the kernels.
        for name, var in reversed(list(self.model.vars.items())):
            inference = self.get_spec(var)
            if inference is None:
                continue

            group_name = inference.kernel_group

            if group_name is not None and group_name in kernel_groups:
                group = kernel_groups[group_name]

                # All members of a group must use the very same kernel object.
                if group.kernel is not inference.kernel:
                    raise ValueError(
                        "Found incoherent kernel classes for kernel group"
                        f" {group_name}."
                    )

                # Kernel kwargs may be given on a single member (applied to the
                # whole group) or on several members sharing one dict object.
                if inference.kernel_kwargs:
                    if not group.kwargs:
                        group.kwargs = inference.kernel_kwargs
                    elif group.kwargs is not inference.kernel_kwargs:
                        raise ValueError(
                            "Found incoherent kernel keyword arguments for "
                            f"kernel group {group_name}. "
                            "When supplying kernel keyword arguments for multiple "
                            "inference objects, they all have to point to the "
                            "same object. "
                            "Alternatively, if you pass the kernel keyword arguments "
                            "to only "
                            "one inference object in the group, they will be applied "
                            "for the whole group."
                        )

                group.position_keys.append(name)
                continue

            # NOTE(review): ungrouped variables are keyed by their own name, so a
            # kernel group sharing a variable's name uses the same dictionary key.
            key = name if group_name is None else group_name
            kernel_groups[key] = _KernelGroup(
                kernel=inference.kernel,
                kwargs=inference.kernel_kwargs,
                position_keys=[name],
                order=inference.order,
            )

        # ``sorted`` is stable: equal ``order`` values keep the traversal order.
        return dict(sorted(kernel_groups.items(), key=lambda item: item[1].order))

    def get_kernel_list(self) -> list[Kernel]:
        """
        Construct the list of MCMC kernels from kernel groups.

        Returns
        -------
        A list of initialized kernel instances ready to be added to the MCMC engine,
        in the order given by :meth:`.get_kernel_groups`.
        """
        return [
            group.kernel(group.position_keys, **group.kwargs)  # type: ignore
            for group in self.get_kernel_groups().values()
        ]

    def get_jitter_functions(self) -> JitterFunctions:
        """
        Collect jitter functions for model variables that define a jitter distribution.

        Returns
        -------
        A dictionary mapping variable names to their jitter application functions
        (the bound :meth:`.MCMCSpec.apply_jitter` of each variable's spec).
        """
        jitter_functions: JitterFunctions = {}
        for name, var in self.model.vars.items():
            spec = self.get_spec(var)
            if spec is not None and spec.jitter_dist is not None:
                jitter_functions[name] = spec.apply_jitter
        return jitter_functions

    def get_engine_builder(
        self,
        seed: int,
        num_chains: int,
        apply_jitter: bool = True,
    ) -> EngineBuilder:
        """
        Create and configure an :class:`.EngineBuilder` for MCMC sampling.

        Parameters
        ----------
        seed
            Random seed for reproducibility.
        num_chains
            Number of MCMC chains.
        apply_jitter
            Whether to apply jitter to the initial states, by default True. Note that
            initial values for a variable will only be jittered if the
            :class:`.MCMCSpec` for this variable was supplied with a ``jitter_dist``.

        Returns
        -------
        EngineBuilder
            A configured ``EngineBuilder`` instance with the model, initial values,
            kernels and (optionally) jitter functions set.
        """
        self.validate_inference_specs()

        eb = EngineBuilder(seed=seed, num_chains=num_chains)
        eb.set_model(LieselInterface(self.model))
        eb.set_initial_values(self.model.state)

        for kernel in self.get_kernel_list():
            eb.add_kernel(kernel)

        if apply_jitter:
            eb.set_jitter_fns(self.get_jitter_functions())

        return eb

    def validate_inference_specs(self) -> None:
        """
        Logs a warning if there are any parameters in the model that have no inference
        specification for MCMC.

        A parameter counts as covered if its ``inference`` attribute is an
        :class:`.MCMCSpec`, or a mapping that contains at least one
        :class:`.MCMCSpec`.
        """
        # NOTE(review): this inspects ``var.inference`` directly and does not take
        # ``self.which`` into account.
        no_inference: list[str] = []
        for name, var in self.model.parameters.items():
            inference = var.inference
            if isinstance(inference, MCMCSpec):
                continue

            if hasattr(inference, "values"):
                # Covered if *any* entry of the inference mapping is an MCMCSpec.
                # (Previously the inner ``continue`` only skipped to the next entry,
                # so parameters with an MCMCSpec were reported anyway.)
                if any(isinstance(spec, MCMCSpec) for spec in inference.values()):
                    continue

            # Reached for ``None``, unsupported types, and mappings without an
            # MCMCSpec entry.
            no_inference.append(name)

        for name in no_inference:
            logger.warning(
                f"No inference specification defined for {self.model.vars[name]}. "
                "If you do not add a kernel for this parameter manually to an "
                "EngineBuilder, it will not be"
                " sampled."
            )

    def run_for_epochs(
        self,
        *,
        seed: int,
        num_chains: int,
        adaptation: int,
        posterior: int,
        burnin: int = 0,
        adaptation_thinning: int = 1,
        burnin_thinning: int = 1,
        posterior_thinning: int = 1,
        apply_jitter: bool = True,
        store_kernel_states: bool = False,
        show_progress: bool = True,
        positions_included: list[str] | None = None,
        positions_excluded: list[str] | None = None,
        save_path: str | Path | None = None,
    ) -> SamplingResults:
        """
        Shorthand method for quickly running MCMC for a set number of epochs.

        Parameters
        ----------
        seed
            Random seed for reproducibility.
        num_chains
            Number of MCMC chains.
        adaptation, burnin, posterior
            Number of samples to be drawn in the respective epoch. Adaptation and
            burn-in epochs are only added if their number is positive.
        adaptation_thinning, burnin_thinning, posterior_thinning
            Thinning to be applied in the respective epoch.
        apply_jitter
            Whether to apply jitter to the initial states, by default True. Note that
            initial values for a variable will only be jittered if the
            :class:`.MCMCSpec` for this variable was supplied with a ``jitter_dist``.
            Think of this argument rather as an off-switch of existing jittering.
        store_kernel_states
            Whether to store kernel states in sampling results, which may be useful
            for debugging.
        show_progress
            Whether to show progress bars during sampling.
        positions_included
            List of additional position keys that should be tracked, see
            :attr:`.EngineBuilder.positions_included`.
        positions_excluded
            List of position keys that should not be tracked. Excluded keys override
            additional keys; see :attr:`.EngineBuilder.positions_excluded`.
        save_path
            Filepath to a pickle file in which results should be saved. If the file
            exists, results are loaded from this file and no sampling occurs. Only
            point this at files from a trusted source.

        Returns
        -------
        SamplingResults
            The sampling results, either freshly sampled or loaded from
            ``save_path``.

        Warnings
        --------
        This method is *only* appropriate, if your MCMC algorithm is fully specified via
        :class:`.MCMCSpec` objects in the :attr:`.Var.inference` attributes of the
        variables in your model.

        See Also
        --------
        .get_engine_builder : Method to obtain an :class:`.EngineBuilder` from the
            LieselMCMC object. The :class:`.EngineBuilder` allows for more detailed
            custom configuration; for example you can add additional MCMC kernels via
            :meth:`.EngineBuilder.add_kernel`.

        Notes
        -----
        Apart from the ``save_path`` handling, the method is equivalent to the
        following code::

            eb = LieselMCMC(model, which=which).get_engine_builder(
                seed=seed, num_chains=num_chains, apply_jitter=apply_jitter
            )
            eb.store_kernel_states = store_kernel_states
            eb.positions_included = positions_included or []
            eb.positions_excluded = positions_excluded or []
            eb.show_progress = show_progress

            if adaptation > 0:
                eb.add_adaptation(adaptation, thinning=adaptation_thinning)
            if burnin > 0:
                eb.add_burnin(burnin, burnin_thinning)
            eb.add_posterior(posterior, posterior_thinning)

            engine = eb.build()
            engine.sample_all_epochs()
            results = engine.get_results()
        """
        fp = None if save_path is None else Path(save_path)

        if fp is not None:
            logger.info(f"Save path provided: {fp}.")
            if fp.exists():
                # NOTE(review): pkl_load presumably unpickles the file, which can
                # execute arbitrary code; only load files you trust.
                logger.info(f"Loading results from {fp}. No sampling is happening.")
                return SamplingResults.pkl_load(fp)

        eb = self.get_engine_builder(
            seed=seed, num_chains=num_chains, apply_jitter=apply_jitter
        )
        eb.store_kernel_states = store_kernel_states
        eb.positions_included = positions_included or []
        eb.positions_excluded = positions_excluded or []
        eb.show_progress = show_progress

        # Adaptation and burn-in are optional; the posterior epoch is always added.
        if adaptation > 0:
            eb.add_adaptation(adaptation, thinning=adaptation_thinning)
        if burnin > 0:
            eb.add_burnin(burnin, burnin_thinning)
        eb.add_posterior(posterior, posterior_thinning)

        engine = eb.build()
        engine.sample_all_epochs()
        results = engine.get_results()

        if fp is not None:
            logger.info(f"Saving results to save path: {fp}.")
            results.pkl_save(fp)

        return results
@dataclass
class _KernelGroup:
kernel: Callable[..., Kernel]
kwargs: dict[str, Any] = field(default_factory=dict)
position_keys: list[str] = field(default_factory=list)
order: int = 99
P = ParamSpec("P")
class KernelFactory(Protocol[P]):
"""Create a kernel instance based on the provided position keys and arguments."""
def __call__(
self, position_keys: list[str], *args: P.args, **kwargs: P.kwargs
) -> Kernel: ...
[docs]
@dataclass
class MCMCSpec:
"""
Specification for the MCMC kernel and optional jitter distribution associated with a
model variable.
Parameters
----------
kernel
A KernelFactory that returns a ``Kernel`` instance when provided with position
keys and keyword arguments.
kernel_kwargs
Additional keyword arguments to be passed to the kernel callable.
kernel_group
Name of the kernel group this variable belongs to. Variables in the same group \
must share the same kernel type and arguments.
jitter_dist
A TensorFlow Probability distribution used to apply random jitter to the \
initial value of the variable.
jitter_method
The type of jitter to be applied. This can be one of the following: - `none`: No
jitter is applied. - `additive`: Additive jitter is applied. - `multiplicative`:
Multiplicative jitter is applied. - `replacement`: Value is replaced when jitter
is applied.
order
If you want to change the order in which parameter blocks are sampled. Blocks
will be ordered by default based on the topological order of the graph (from the
bottom up; i.e. the kernels for sampling parameters closest to the graph's leaf
nodes/responses come first), which is often a sensible default. After that,
blocks will be ordered based on the integer provided here. The kernel with the
smallest ``order`` integer will be used first.
Examples
--------
.. rubric:: Liesel Workflow
For this example, we import ``tensorflow_probability`` as follows:
>>> import tensorflow_probability.substrates.jax.distributions as tfd
First, we set up a minimal model:
>>> mu = lsl.Var.new_param(0.0, name="mu", inference=gs.MCMCSpec(gs.NUTSKernel))
>>> dist = lsl.Dist(tfd.Normal, loc=mu, scale=1.0)
>>> y = lsl.Var.new_obs(jnp.array([1.0, 2.0, 3.0]), dist, name="y")
>>> model = lsl.Model([y])
Now we initialize the EngineBuilder and set the desired number of warmup and
posterior samples:
>>> builder = gs.LieselMCMC(model).get_engine_builder(seed=1, num_chains=4)
>>> builder.add_adaptation(1000)
>>> builder.add_posterior(1000)
Finally, we build the engine:
>>> engine = builder.build()
"""
def __post_init__(self) -> None:
if self.jitter_method not in self._JITTER_METHODS:
raise ValueError(
f"Invalid jitter method: {self.jitter_method}. "
f"Expected one of {self._JITTER_METHODS}."
)
_JITTER_METHODS = ["additive", "multiplicative", "replacement"]
kernel: KernelFactory
kernel_kwargs: dict[str, Any] = field(default_factory=dict)
kernel_group: str | None = None
jitter_dist: tfd.Distribution | None = None
jitter_method: Literal["additive", "multiplicative", "replacement"] = "additive"
order: int = 99
def __repr__(self) -> str:
return f"{type(self).__name__}({self.kernel}, {self.kernel_group=})"
[docs]
def apply_jitter(self, seed: KeyArray, value: Array) -> Array:
"""
Apply random jitter to a given value using the specified jitter distribution.
If a jitter distribution is set, a random sample from the distribution is added
to the original value. If no jitter distribution is set, the original value is
returned unchanged.
Parameters
----------
seed
A PRNG key used for random sampling.
value
The value to which jitter should be applied.
Returns
-------
The jittered value with the same shape as the input.
"""
if self.jitter_dist is None:
return value
# check compatibility of shapes
if (
self.jitter_dist.batch_shape + self.jitter_dist.event_shape != value.shape
) and (
self.jitter_dist.batch_shape.rank + self.jitter_dist.event_shape.rank > 0
):
raise ValueError(
f"Jitter distribution shapes "
f"(batch shape {self.jitter_dist.batch_shape} "
f"and event shape {self.jitter_dist.event_shape}) "
f"do not match variable shape {value.shape}."
)
sample_shape = (
value.shape
if self.jitter_dist.batch_shape + self.jitter_dist.event_shape == ()
else ()
)
jitter = self.jitter_dist.sample(sample_shape=sample_shape, seed=seed)
match self.jitter_method:
case "additive":
value = value + jitter
case "multiplicative":
value = value * jitter
case "replacement":
value = jitter
case _:
assert_never(self.jitter_method)
return value