"""
MCMC chains
This module is experimental. Expect API changes.
"""
from typing import Callable, Generic, Protocol, Sequence, TypeVar
import jax
import numpy as np
from liesel.option import Option
from ..docs import usedocs
from .epoch import EpochConfig
from .pytree import concatenate_leaves, slice_leaves
from .types import PyTree
TPyTree = TypeVar("TPyTree", bound=PyTree)
__docformat__ = "numpy"
class Chain(Protocol[TPyTree]):
    """
    A ``Chain`` stores multiple chunks of pytrees and concatenates them along a
    time axis.

    A ``Chain`` always assumes multiple independent chains that are indexed via
    the first axis. The second axis represents the time. Consequently, every
    leaf in the pytree must have at least two dimensions (i.e.,
    ``[chain, time, ...]``).

    A ``Chain`` operates on the assumption that all chunks are pytrees with the
    same structure. However, the time dimension is allowed to vary in size.
    """

    def append(self, chunk: TPyTree) -> None:
        """Appends a chunk to the chain."""
        ...

    def get(self) -> Option[TPyTree]:
        """
        Returns all chunks combined into one pytree.

        The option is none if no samples are in the chain.
        """
        ...
# ``Protocol`` is listed explicitly: a class that merely inherits from a
# protocol is an ordinary class, not a protocol itself (PEP 544).
class EpochChain(Chain[TPyTree], Protocol):
    """
    An ``EpochChain`` is a :class:`.Chain` with an associated :class:`.EpochConfig`.

    The implementation must implement thinning. That is, if
    ``epoch.thinning > 1`` and thinning is enabled in the constructor, the chain
    must save only every ``epoch.thinning``-th element.
    """

    @property
    def epoch(self) -> EpochConfig:
        """Returns the associated :class:`.EpochConfig`."""
        ...
# implementations
@usedocs(Chain)
class ListChain(Generic[TPyTree]):
    """Implements the :class:`.Chain` protocol with a list as storage."""

    def __init__(self) -> None:
        # Chunks in insertion order. ``get`` collapses them into a single entry.
        self._chunks_list: list[TPyTree] = []

    # NOTE(review): ``append`` and ``get`` deliberately carry no docstring of
    # their own; presumably ``usedocs(Chain)`` supplies the documentation from
    # the protocol -- confirm against the ``usedocs`` helper.
    def append(self, chunk: TPyTree) -> None:
        self._chunks_list.append(chunk)

    def _concatenate(self) -> None:
        """Replaces the stored chunks by their concatenation along axis 1."""
        combined = concatenate_leaves(self._chunks_list, 1)
        # Leave the storage untouched if the helper has nothing to return.
        if combined is not None:
            self._chunks_list = [combined]

    def get(self) -> Option[TPyTree]:
        if not self._chunks_list:
            return Option(None)

        # Side effect: the chunks are merged in place, so the storage holds the
        # combined pytree as its first (and, after a merge, only) element.
        self._concatenate()
        return Option(self._chunks_list[0])
@usedocs(EpochChain)
class ListEpochChain(ListChain[TPyTree]):
    """Implements the :class:`.EpochChain` protocol with a list as storage."""

    def __init__(self, epoch: EpochConfig, apply_thinning: bool = False):
        super().__init__()
        self._epoch = epoch
        self._apply_thinning = apply_thinning
        # Running position of the next incoming state. Starting at 1 means that
        # the ``thinning``-th, ``2 * thinning``-th, ... states are the ones kept.
        self._states_counter = 1

    # NOTE(review): no docstring here; presumably ``usedocs(EpochChain)``
    # provides it -- confirm against the ``usedocs`` helper.
    @property
    def epoch(self) -> EpochConfig:
        return self._epoch

    def append(self, chunk: TPyTree) -> None:
        """Applies thinning and appends a chunk to the chain."""
        if not (self._apply_thinning and self._epoch.thinning > 1):
            # Thinning disabled or trivial: store the chunk unchanged.
            super().append(chunk)
            return

        thinning = self._epoch.thinning
        # The time axis is the second axis of the first leaf.
        size = jax.tree_util.tree_leaves(chunk)[0].shape[1]

        offsets = np.arange(size)
        keep = (self._states_counter + offsets) % thinning == 0
        idx = offsets[keep]

        # The counter advances by the full chunk size, so that thinning stays
        # aligned across chunk boundaries.
        self._states_counter += size

        # A chunk in which no state survives thinning is dropped entirely
        # instead of being stored unthinned.
        if len(idx) > 0:
            super().append(slice_leaves(chunk, np.s_[:, idx, ...]))
class EpochChainManager(Generic[TPyTree]):
    """
    An ``EpochChainManager`` is a container for multiple epoch chains.

    The chains can be concatenated over multiple epochs. Thinning defined in the
    epochs can be switched on or off with the constructor flag
    ``apply_thinning``.
    """

    def __init__(self, apply_thinning: bool = False) -> None:
        # One chain per epoch, in the order the epochs were started.
        self._chains: list[ListEpochChain[TPyTree]] = []
        self._apply_thinning = apply_thinning

    @property
    def current_epoch(self) -> EpochConfig:
        """
        Returns the current epoch.

        Raises an ``IndexError`` if no epoch has been started yet.
        """
        return self._chains[-1].epoch

    def advance_epoch(self, epoch: EpochConfig) -> None:
        """Creates and appends a :class:`.ListEpochChain` for the given ``epoch``."""
        new_chain: ListEpochChain[TPyTree] = ListEpochChain(epoch, self._apply_thinning)
        self._chains.append(new_chain)

    # No docstring on purpose: the decorator is given ``ListEpochChain.append``
    # as its documentation source (``usedocs`` helper not shown here).
    @usedocs(ListEpochChain.append)
    def append(self, chunk: TPyTree) -> None:
        # Chunks always go to the chain of the most recently started epoch.
        self._chains[-1].append(chunk)

    def get_epochs(self) -> Sequence[EpochConfig]:
        """Returns a list of all epochs."""
        return [c.epoch for c in self._chains]

    def get_specific_chain(self, epoch_number: int) -> ListEpochChain[TPyTree]:
        """
        Returns the chain for the given epoch number.

        The epoch number is used as a list index into the chains in the order in
        which the epochs were started.
        """
        return self._chains[epoch_number]

    def get_current_chain(self) -> ListEpochChain[TPyTree]:
        """Returns the current chain."""
        return self._chains[-1]

    def get_current_epoch(self) -> EpochConfig:
        """Returns the current epoch."""
        return self.current_epoch

    def _combine_chains(
        self, chains: Sequence[ListEpochChain[TPyTree]]
    ) -> Option[TPyTree]:
        """Combines the samples of the given chains, in order, into one pytree."""
        combined: ListChain[TPyTree] = ListChain()
        for epoch_chain in chains:
            chunk = epoch_chain.get()
            # Chains without samples contribute nothing.
            if chunk.is_some():
                combined.append(chunk.unwrap())
        return combined.get()

    def combine(self, epoch_numbers: Sequence[int]) -> Option[TPyTree]:
        """
        Combines the given epochs and returns all chunks combined into one pytree.

        The option is none if no samples are in the chain.
        """
        return self._combine_chains([self._chains[num] for num in epoch_numbers])

    def combine_all(self) -> Option[TPyTree]:
        """
        Combines all epochs and returns all chunks combined into one pytree.

        The option is none if no samples are in the chain.
        """
        return self._combine_chains(self._chains)

    def combine_filtered(
        self, predicate: Callable[[EpochConfig], bool]
    ) -> Option[TPyTree]:
        """
        Combines all epochs for which the predicate evaluates to ``True``.

        Returns all chunks combined into one pytree.
        The option is none if no samples are in the chain.
        """
        selected = [c for c in self._chains if predicate(c.epoch)]
        return self._combine_chains(selected)