"""
MCMC engine
This module is experimental. Expect API changes.
"""
# mypy: check-untyped-defs
from __future__ import annotations
import logging
import pickle
import warnings
from collections.abc import Sequence
from dataclasses import dataclass
from functools import partial
from typing import NamedTuple, cast
import jax
import jax.lax
import jax.numpy as jnp
import jax.random
import jax.tree_util
import numpy as np
from deprecated.sphinx import deprecated
from tqdm import tqdm
from liesel.option import Option
from .chain import Chain, EpochChainManager, ListChain
from .epoch import EpochConfig, EpochManager, EpochState, EpochType
from .kernel_sequence import KernelSequence, KernelStates, TransitionInfos, TuningInfos
from .pytree import as_strong_pytree, register_dataclass_as_pytree
from .types import (
Array,
GeneratedQuantity,
KeyArray,
ModelInterface,
ModelState,
Position,
PyTree,
QuantityGenerator,
TransitionInfo,
)
logger = logging.getLogger(__name__)
class KernelErrorLog(NamedTuple):
"""
    Holds the numbers of the transitions in which an error occurred in at least one
    chain, together with an array of the error codes for each chain.
    Additionally, the kernel identifier is given, and optionally the class of the
    kernel.
"""
kernel_ident: str
kernel_cls: Option[type] # needed to use the error book
transition: np.ndarray
"""1-D array (time)."""
error_codes: np.ndarray
"""2-D array (chain, time)."""
ErrorLog = dict[str, KernelErrorLog]
def _expand_and_stack(chunk, *rest):
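    # Give every chunk a new leading axis and concatenate along it, so the
    # result gains a leading dimension of length len(chunks).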
chunks = [chunk]
chunks.extend(rest)
    expanded_chunks = [jnp.expand_dims(c, 0) for c in chunks]
    return jnp.concatenate(expanded_chunks, axis=0)
@deprecated(
version="0.1.0", reason="Use the functions from liesel.goose.pytree instead."
)
def stack_for_multi(chunks: list):
"""
Combine identically structured pytrees to be used in multichain.
    The function adds a new dimension (axis 0) to each leaf and stacks the leaves
    along the new axis.
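
    A minimal sketch (the pytrees below are only illustrative)::

        chain_a = {"beta": jnp.zeros(3)}
        chain_b = {"beta": jnp.ones(3)}
        stacked = stack_for_multi([chain_a, chain_b])
        # stacked["beta"].shape == (2, 3)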
"""
warnings.warn(
"``stack_for_multi`` is deprecated. Please use the functions"
" in the :mod:`.pytree` module.",
DeprecationWarning,
)
return jax.tree_util.tree_map(
lambda x, *xs: _expand_and_stack(x, *xs), chunks[0], *chunks[1:]
)
@partial(jax.jit, static_argnums=1)
def _split_keys(keys, n):
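    # `keys` is a stack of per-chain PRNG keys; each key is split into `n`
    # subkeys, so the result has one extra leading axis of length n per key.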
keys = jax.lax.map(lambda key: jax.random.split(key, n), keys)
return keys
def _initialze_prng(seed: int | KeyArray) -> KeyArray:
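    # Accept either an integer seed or an already constructed PRNG key of
    # shape (2,), and always return a key array.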
if jnp.isscalar(seed):
return jax.random.PRNGKey(seed) # type: ignore
elif jnp.shape(seed) == (2,): # type: ignore
return seed # type: ignore
else:
raise ValueError("Seed has an unsupported shape")
def _add_time_dimension(x: PyTree) -> PyTree:
"""
Adds a new dimension for time to each leaf.
The returned tree has the same structure with one additional dimension of
size 1. The new dimension is ``axis=1``. Each leaf must have at least one
dimension (representing the chain index).
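
    For example, a leaf of shape ``(4,)`` with one value per chain becomes a
    leaf of shape ``(4, 1)``: four chains, one time point.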
"""
initial_position = jax.tree_util.tree_map(
lambda y, *_ys: jnp.expand_dims(y, 1),
x,
)
return initial_position
@register_dataclass_as_pytree
@dataclass(frozen=True)
class Carry:
"""
    Holds the state that needs to be carried between MCMC iterations.
"""
kernel_states: KernelStates
model_state: ModelState
epoch: EpochState
@dataclass
class SamplingResults:
"""
Contains the results of the MCMC engine.
Easy access to the samples is provided via the methods
:func:`.get_samples` and :func:`.get_posterior_samples`.
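
    A minimal usage sketch (``results`` is assumed to come from
    :func:`.Engine.get_results`; ``"beta"`` is only an illustrative position
    key)::

        samples = results.get_posterior_samples()
        beta = samples["beta"]  # leading dimensions are (chain, iteration)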
"""
positions: EpochChainManager
"""EpochChainManager giving access to monitored variables."""
transition_infos: EpochChainManager
"""EpochChainManager storing all transition infos."""
generated_quantities: Option[EpochChainManager]
"""
    Option[EpochChainManager] storing all generated quantities.
    Is none if no quantities have been generated.
"""
tuning_infos: Option[Chain]
"""
    Option[Chain] storing all tuning infos.
    Is none if no tuning was executed.
"""
kernel_states: Option[EpochChainManager]
"""
    Option[EpochChainManager] holding all kernel states.
    Is none if monitoring of kernel states was not requested.
"""
full_model_states: Option[EpochChainManager]
"""
    Option[EpochChainManager] holding the full model state of each iteration.
    Is none if monitoring was not explicitly requested.
"""
kernel_classes: Option[dict[str, type]]
"""
Optional map of kernel identifier to the respective kernel type.
"""
kernels_by_pos_key: Option[dict[str, str]]
"""
    Optional map of position key to the identifier of the kernel responsible for
    sampling it.
"""
    def get_samples(self) -> Position:
"""
Returns a dictionary of all samples for all parameters included in the
position.
"""
opt: Option[Position] = self.positions.combine_all()
return opt.expect(f"No samples in {repr(self)}")
    def get_posterior_samples(self) -> Position:
"""
Returns a dictionary of posterior samples for all parameters included in the
position.
"""
opt = self.positions.combine_filtered(
lambda config: config.type == EpochType.POSTERIOR
)
return opt.expect(f"No posterior samples in {repr(self)}")
    def get_kernels_by_pos_key(self) -> dict[str, str]:
"""
        Returns a dict identifying the kernel used to sample each position key.
The dict has the format ``{"position name": "kernel identifier"}``.
"""
return self.kernels_by_pos_key.expect(
f"No position-kernel associations in {repr(self)}"
)
    def get_posterior_transition_infos(self) -> dict[str, TransitionInfo]:
"""
Returns a dictionary of posterior transition information for all parameters
included in the position.
"""
opt = self.transition_infos.combine_filtered(
lambda config: config.type == EpochType.POSTERIOR
)
return opt.expect(f"No posterior transition infos in {repr(self)}")
    def get_tuning_times(self) -> Option[Array]:
"""
        Returns an array of tuning times.
"""
if self.tuning_infos.is_none():
return Option.none()
# opt_tis is not None since self.tuning_infos is not None
opt_tis = self.tuning_infos.unwrap().get().unwrap()
time: Array = next(iter(opt_tis.values())).time
return Option(time)
    def get_error_log(self, posterior_only=False) -> Option[ErrorLog]:
"""
        Returns the error log, which is a dict mapping kernel names to
        :class:`.KernelErrorLog` objects.
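
        A minimal sketch of how the log can be inspected (``results`` is a
        :class:`.SamplingResults` instance and the returned option is assumed
        to be non-empty)::

            for name, log in results.get_error_log().unwrap().items():
                print(name, log.transition, log.error_codes)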
"""
opt: Option[TransitionInfos]
if posterior_only:
opt = self.transition_infos.combine_filtered(
lambda config: config.type == EpochType.POSTERIOR
)
if opt.is_none():
return Option(None)
else:
tis = opt.expect(f"No posterior transition infos in {repr(self)}")
else:
opt = self.transition_infos.combine_all()
tis = opt.expect(f"No transition infos in {repr(self)}")
error_log: ErrorLog = {}
for ker_name in tis:
mask = np.any(tis[ker_name].error_code != 0, axis=0)
transition: np.ndarray = np.where(mask)[0]
# cast is ok since the object has more dimensions in the leaf
error_codes: np.ndarray = cast(np.ndarray, tis[ker_name].error_code)[
:, mask
]
cls = self.kernel_classes.map(lambda d: d[ker_name])
error_log[ker_name] = KernelErrorLog(ker_name, cls, transition, error_codes)
return Option(error_log)
    def pkl_save(self, path) -> None:
"""Save result as a pickled object under :attr:`.path`."""
with open(path, "wb") as f:
pickle.dump(self, f)
    @staticmethod
def pkl_load(path) -> SamplingResults:
"""Loads the pickled object from :attr:`.path`."""
with open(path, "rb") as f:
return pickle.load(f)
@deprecated(reason="Use SamplingResults", version="0.1.4")
class SamplingResult(SamplingResults):
"""Alias of :class:`.SamplingResults` for backwards compatibility."""
positions: EpochChainManager
transition_infos: EpochChainManager
generated_quantities: Option[EpochChainManager]
tuning_infos: Option[Chain]
kernel_states: Option[EpochChainManager]
full_model_states: Option[EpochChainManager]
kernel_classes: Option[dict[str, type]]
kernels_by_pos_key: Option[dict[str, str]]
class Engine:
"""MCMC engine capable of combining multiple transition kernels."""
def __init__(
self,
seeds: KeyArray,
model_states: ModelState,
kernel_sequence: KernelSequence,
epoch_configs: Sequence[EpochConfig],
jitted_sample_duration: int,
model: ModelInterface,
position_keys: Sequence[str] | None,
minimize_transition_infos: bool = False,
store_kernel_states: bool = False,
        quantity_generators: Sequence[QuantityGenerator] = (),
show_progress: bool = True,
):
# fill slots that can be filled directly
self._inital_states = model_states
self._seeds = seeds
self._jitted_sample_duration = jitted_sample_duration
self._minimize_transition_infos = minimize_transition_infos
self._store_kernel_states = store_kernel_states
self._model_states = model_states
self._quantity_generators = quantity_generators
self._show_progress = show_progress
self._kernel_sequence = kernel_sequence
self._epoch_manager = EpochManager(epoch_configs)
self._warmup_has_ended = False
if not position_keys:
position_keys = [
key
for ker in self._kernel_sequence._kernels # FIXME: use of private field
for key in ker.position_keys
]
self._position_keys = position_keys
self._model = model
# feed in history if at least one kernel requires history for tuning
#
# FIXME: automatically fetch position keys
#
# fetch kernels' position keys and add them automatically to track them
# in the position chain
self._history_required_for_tuning = any(
[ker.needs_history for ker in self._kernel_sequence._kernels]
) # FIXME: use of private field
self._prng_key = seeds
# setup storage
self._position_chain: EpochChainManager = EpochChainManager(apply_thinning=True)
self._transition_info_chain: EpochChainManager = EpochChainManager()
self._tuning_info_chain: ListChain = ListChain()
self._kernel_state_chain: EpochChainManager = EpochChainManager()
self._quantities_chain: EpochChainManager = EpochChainManager(
apply_thinning=True
)
# initialize kernel state
keys = self._split_prng_key_one()
self._kernel_states = jax.vmap(self._kernel_sequence.init_states)(
keys, self._model_states
)
# current epoch
self._epoch: EpochState | None = None
# jit sample function
self._sample_many_jitted = jax.jit(
jax.vmap(
self._sample_many,
in_axes=(0, None, 0, 0),
out_axes=(None, 0, 0, 0, 0, 0, 0),
)
)
@property
def current_epoch(self) -> EpochState:
"""
Returns the current epoch.
Raises a :exc:`.RuntimeError` if no epoch is active.
"""
if self._epoch is None:
raise RuntimeError("No active epoch")
return self._epoch
    def sample_all_epochs(self):
"""
Runs sampling for all remaining epochs.
Auto-tuning methods are called automatically.
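
        A minimal usage sketch (assuming ``engine`` is a fully configured
        :class:`.Engine`)::

            engine.sample_all_epochs()
            results = engine.get_results()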
"""
while self._epoch_manager.has_more():
self.sample_next_epoch()
    def sample_next_epoch(self):
"""Runs sampling for the next epoch assuming no epoch is active."""
self._start_epoch()
# special treatment for the initial values
if self.current_epoch.config.type == EpochType.INITIAL_VALUES:
            self._handle_initial_values_epoch()
return
self._kernel_start_epoch()
duration = self.current_epoch.config.duration
epoch_type = self.current_epoch.config.type.name
jitted = self._jitted_sample_duration
if self._show_progress:
logger.info(
f"Starting epoch: {epoch_type}, {duration} transitions, "
f"{jitted} jitted together"
)
self._sample_for_duration(duration=duration)
self._end_epoch()
    def append_epoch(self, epoch: EpochConfig):
"""Appends an epoch to the epochs that should be sampled."""
self._epoch_manager.append(epoch)
    def is_sampling_done(self) -> bool:
"""Returns true if all configured epochs have been sampled."""
return not self._epoch_manager.has_more()
    def get_results(self) -> SamplingResults:
"""Returns the results of the sampling process."""
if self._store_kernel_states:
ksc = self._kernel_state_chain
else:
ksc = None
if self._quantity_generators:
gqs = self._quantities_chain
else:
gqs = None
kernels = self._kernel_sequence.get_kernels()
kernels_cls: dict[str, type] = {ker.identifier: type(ker) for ker in kernels}
kernels_by_position: dict[str, str] = dict()
for kernel in kernels:
kernels_by_position.update(
{key: kernel.identifier for key in kernel.position_keys}
)
return SamplingResults(
positions=self._position_chain,
transition_infos=self._transition_info_chain,
generated_quantities=Option(gqs),
tuning_infos=Option(self._tuning_info_chain),
kernel_states=Option(ksc),
full_model_states=Option(None),
kernel_classes=Option(kernels_cls),
kernels_by_pos_key=Option(kernels_by_position),
)
def _split_prng_key(self, n: int = 1) -> KeyArray:
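        # Split every per-chain key into n + 1 subkeys: the first subkey
        # becomes the engine's new key, the remaining n are handed out.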
keys = _split_keys(self._prng_key, n + 1)
self._prng_key = keys[:, 0, :]
return keys[:, 1:, :]
def _split_prng_key_one(self) -> KeyArray:
key = self._split_prng_key(1)
return key[:, 0, :]
def _generate_quantity(self):
if not self._quantity_generators:
return None
quants = {}
for qg in self._quantity_generators:
key = self._split_prng_key_one()
gen_f = jax.vmap(qg.generate, in_axes=(0, 0, None))
quant = gen_f(key, self._model_states, self.current_epoch)
quants[qg.identifier] = quant
return quants
    def _handle_initial_values_epoch(self):
assert self.current_epoch.config.type == EpochType.INITIAL_VALUES
self.current_epoch.advance_time(1)
initial_position = _add_time_dimension(
x=jax.vmap(self._model.extract_position, in_axes=(None, 0))(
self._position_keys, self._model_states
),
)
self._position_chain.append(initial_position)
if self._store_kernel_states:
ks = _add_time_dimension(x=self._kernel_states)
self._kernel_state_chain.append(ks)
if self._quantity_generators:
quants = self._generate_quantity()
quants = _add_time_dimension(x=quants)
self._quantities_chain.append(quants)
self._epoch = None
def _start_epoch(self):
"""Advances to the next epoch."""
if self._epoch is not None:
raise RuntimeError("Epoch is active and not completed")
self._epoch = self._epoch_manager.next()
# invoke end_warmup() for the first non-warmup epoch
if (
not self._warmup_has_ended
and self.current_epoch.config.type == EpochType.POSTERIOR
):
self._end_warmup()
# advance chains to next epoch
self._position_chain.advance_epoch(self.current_epoch.config)
self._transition_info_chain.advance_epoch(self.current_epoch.config)
self._kernel_state_chain.advance_epoch(self.current_epoch.config)
self._quantities_chain.advance_epoch(self.current_epoch.config)
def _kernel_start_epoch(self):
"""Inform kernels about new epoch."""
keys = self._split_prng_key_one()
self._kernel_states = jax.vmap(
self._kernel_sequence.start_epoch, in_axes=(0, 0, 0, None)
)(keys, self._kernel_states, self._model_states, self.current_epoch)
def _end_warmup(self):
"""
Ends the warmup sequence.
Calls :func:`.end_warmup` for each kernel. From now on, only epochs of type
posterior can follow.
"""
keys = self._split_prng_key_one()
tuning_infos: TuningInfos | None = self._tuning_info_chain.get().value
end_warmup_output = jax.vmap(self._kernel_sequence.end_warmup)(
keys, self._kernel_states, self._model_states, tuning_infos
)
self._kernel_states = end_warmup_output.kernel_states
        # warn the user if there are any non-zero error codes
for kid, ec in end_warmup_output.error_codes.items():
if jnp.any(ec != 0):
logger.warning(f"Warmup error code for {kid}: {ec}")
logger.info("Finished warmup")
def _end_epoch(self):
"""
End epoch.
Informs kernels about the end of the epoch and initializes the tuning
if required.
"""
# ensure that an epoch is active
epoch = self.current_epoch
# inform kernels about end of epoch
end_keys = self._split_prng_key_one()
self._kernel_states = jax.vmap(
self._kernel_sequence.end_epoch, in_axes=(0, 0, 0, None)
)(end_keys, self._kernel_states, self._model_states, epoch)
self._tune_kernels(epoch)
if self._show_progress:
ti_option = self._transition_info_chain.get_current_chain().get()
def count_non_zero_error_codes(tis: TransitionInfos):
cts = {}
for kernel_id, ti in tis.items():
nzero = jnp.sum(ti.error_code != 0, axis=1)
ntrans = ti.error_code.shape[1] # type: ignore
cts[kernel_id] = (nzero, ntrans)
return cts
error_info: dict[str, tuple[Array, int]] = ti_option.map_or(
{}, count_non_zero_error_codes
)
for kid, kcts in error_info.items():
if jnp.any(kcts[0] != 0):
logger.warning(
f"Errors per chain for {kid}: "
f"{', '.join(map(str, kcts[0]))} / {kcts[1]} transitions"
)
logger.info("Finished epoch")
# no epoch is active anymore
self._epoch = None
def _tune_kernels(self, epoch: EpochState):
"""Trigger tuning if epoch is an adaptation phase."""
if EpochType.is_adaptation(epoch.config.type):
tune_keys = self._split_prng_key_one()
if self._history_required_for_tuning:
history = (
self._position_chain.get_current_chain()
.get()
.expect("The history must contain samples.")
)
else:
history = None
tune_output = jax.vmap(
self._kernel_sequence.tune, in_axes=(0, 0, 0, None, 0)
)(tune_keys, self._kernel_states, self._model_states, epoch, history)
self._kernel_states = tune_output.kernel_states
# we need to add the time dimension
self._tuning_info_chain.append(_add_time_dimension(x=tune_output.infos))
def _sample_many(
self,
keys: KeyArray,
epoch: EpochState,
kernel_states: KernelStates,
model_state: ModelState,
) -> tuple[
EpochState,
KernelStates,
ModelState,
Position,
TransitionInfos,
None | KernelStates,
None | dict[str, GeneratedQuantity],
]:
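        # One jitted chunk of transitions: `scan_f` performs a single MCMC
        # transition, and `jax.lax.scan` stacks its per-step outputs along a
        # new leading time axis (position, infos, kernel states, quantities).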
def scan_f(
carry: Carry, key: KeyArray
) -> tuple[
Carry,
tuple[
Position,
TransitionInfos,
None | KernelStates,
None | dict[str, GeneratedQuantity],
],
]:
key_trans, key_quants = jax.random.split(key)
epoch = carry.epoch
out = self._kernel_sequence.transition(
key_trans, carry.kernel_states, carry.model_state, epoch
)
epoch.advance_time(1)
new_carry = Carry(out.kernel_states, out.model_state, epoch)
# extract the position specified to store in chain
position = self._model.extract_position(
self._position_keys, out.model_state
)
# minimize transition infos if requested
tinfos = out.infos
if self._minimize_transition_infos:
                for kernel_id in tinfos:
                    tinfos[kernel_id] = tinfos[kernel_id].minimize()
ks = None
if self._store_kernel_states:
ks = new_carry.kernel_states
quants = None
if self._quantity_generators:
quants = {}
keys = jax.random.split(key_quants, len(self._quantity_generators))
for i, qg in enumerate(self._quantity_generators):
key = keys[i]
quant = qg.generate(key, out.model_state, epoch)
quants[qg.identifier] = quant
return new_carry, (position, tinfos, ks, quants)
        initial_carry = Carry(kernel_states, model_state, epoch)
        carry, chain = jax.lax.scan(scan_f, initial_carry, keys)
kernel_states = carry.kernel_states
model_state = carry.model_state
epoch = carry.epoch
return (
epoch,
kernel_states,
model_state,
chain[0],
chain[1],
chain[2],
chain[3],
)
def _sample_for_duration(self, duration: int):
if self.current_epoch.time_left() < duration:
raise RuntimeError("Not enough time left in epoch")
if duration % self._jitted_sample_duration:
raise RuntimeError(
f"Duration {duration} is not a multiple of the "
f"jitted sampling duration {self._jitted_sample_duration}"
)
# convert to non-weak device arrays to avoid recompilation
self._epoch = as_strong_pytree(self._epoch)
self._kernel_states = as_strong_pytree(self._kernel_states)
self._model_states = as_strong_pytree(self._model_states)
it = range(duration // self._jitted_sample_duration)
if self._show_progress:
it = tqdm(it, ncols=80, disable=None, unit="chunk")
for dur_i in it:
# FIXME: split for entire duration instead of each loop iteration
keys = self._split_prng_key(self._jitted_sample_duration)
(
new_epoch,
new_ks,
new_ms,
position_chain,
infos,
ksc,
quants,
) = self._sample_many_jitted(
keys, self.current_epoch, self._kernel_states, self._model_states
)
self._epoch = new_epoch
self._kernel_states = new_ks
self._model_states = new_ms
self._position_chain.append(position_chain)
self._transition_info_chain.append(infos)
if self._store_kernel_states:
self._kernel_state_chain.append(ksc)
if self._quantity_generators:
self._quantities_chain.append(quants)