# Source code for liesel.goose.builder

"""
MCMC engine builder

The purpose of the engine builder is to provide a simple API to gradually assemble
the components needed by the MCMC engine. The builder is responsible for returning an
engine in a well-defined state. Furthermore, the builder can return different engine
implementations.
"""

import math
from collections.abc import Iterable
from typing import cast

import jax
import jax.numpy as jnp

from liesel.option import Option

from .engine import Engine
from .epoch import EpochConfig, EpochManager
from .kernel_sequence import KernelSequence
from .pytree import stack_leaves
from .types import Kernel, KeyArray, ModelInterface, ModelState, QuantityGenerator
from .warmup import stan_epochs


def _find_duplicate(xs: list[str]) -> Option[str]:
    """
    Returns the first string in ``xs`` that repeats an earlier entry.

    The result is an ``Option`` holding the first element whose value has already
    been seen while scanning left to right, or an empty ``Option`` if all strings
    are distinct.
    """
    seen: set[str] = set()
    for candidate in xs:
        if candidate in seen:
            return Option(candidate)
        seen.add(candidate)
    return Option(None)


class EngineBuilder:
    """
    The :class:`.EngineBuilder` is used to construct an MCMC Engine.

    Currently, the :class:`.EngineBuilder` builds an object of the class
    :class:`.Engine`.

    By default, every position key associated with an MCMC kernel is tracked. This
    behavior can be adjusted with the fields :attr:`.positions_included` and
    :attr:`.positions_excluded`.

    Parameters
    ----------
    seed
        Used to initialize the PRNG for the building process. The PRNG state derived
        from it is split into a key kept by the builder and a key for the engine;
        the engine key can be replaced with :meth:`.set_engine_seed`.
    num_chains
        The number of chains to be used.
    """

    def __init__(self, seed: int, num_chains: int):
        keys = jax.random.split(jax.random.PRNGKey(seed))
        self._prng_key: KeyArray = keys[0]
        self._engine_key: KeyArray = keys[1]

        self._num_chains: int = num_chains

        self._kernels: list[Kernel] = []
        self._quantity_generators: list[QuantityGenerator] = []
        self._model_state: Option[ModelState] = Option(None)
        self._model: Option[ModelInterface] = Option(None)

        # Epochs are configured later via ``set_epochs`` or ``set_duration``.
        # ``None`` marks "not configured yet" so that ``build`` can fail with a
        # clear message instead of an AttributeError.
        self._epochs: EpochManager | None = None

        # public fields, only simple states
        self.store_kernel_states: bool = False
        self.minimize_transition_infos: bool = False
        self.show_progress: bool = True

        self.positions_included: list[str] = []
        """List of additional position keys that should be tracked."""

        self.positions_excluded: list[str] = []
        """List of position keys that should not be tracked.
        Excluded keys override additional keys."""

    def set_engine_seed(self, seed: int | KeyArray):
        """
        Sets a seed used to initialize the MCMC engine.

        A scalar seed is turned into a fresh PRNG key; anything else is used as the
        engine key directly.
        """
        if jnp.isscalar(seed):
            seed_int: int = cast(int, seed)
            self._engine_key = jax.random.PRNGKey(seed_int)
        else:
            seed_keyarray = cast(KeyArray, seed)
            self._engine_key = seed_keyarray

    @property
    def engine_seed(self) -> KeyArray:
        """The seed for the engine's pseudo-random number generation."""
        return self._engine_key

    def add_kernel(self, kernel: Kernel):
        """Adds a :class:`.Kernel`."""
        self._kernels.append(kernel)

    @property
    def kernels(self) -> tuple[Kernel, ...]:
        """Tuple of all Kernels that are present in the builder."""
        return tuple(self._kernels)

    def add_quantity_generator(self, generator: QuantityGenerator):
        """Adds a :class:`.QuantityGenerator`."""
        self._quantity_generators.append(generator)

    @property
    def quantity_generators(self) -> tuple[QuantityGenerator, ...]:
        """Tuple of all quantity generators present in the builder."""
        return tuple(self._quantity_generators)

    def set_initial_values(self, model_state: ModelState, multiple_chains: bool = False):
        """
        Sets the initial model state.

        If :attr:`.multiple_chains` is true the :attr:`.model_state` will be used as
        is; otherwise :attr:`.model_state` will be used as the initial values for each
        chain. Note that if :attr:`.multiple_chains` is true, the first axis of each
        leaf of :attr:`.model_state` refers to the chain.
        """
        if multiple_chains:
            # The caller already supplies one state per chain along the first axis.
            model_states = model_state
        else:
            # Replicate the single state once per chain so that, after stacking,
            # the first axis of each leaf refers to the chain.
            model_states = stack_leaves(
                [model_state for _ in range(self._num_chains)]
            )

        self._model_state = Option(model_states)

    @property
    def model_state(self) -> Option[ModelState]:
        """
        Initial model state for all chains, wrapped in an ``Option``.

        The ``Option`` is empty until :meth:`.set_initial_values` has been called.
        """
        return self._model_state

    def set_epochs(self, epochs: Iterable[EpochConfig]):
        """Sets epochs."""
        self._epochs = EpochManager(epochs)

    def set_duration(
        self,
        warmup_duration: int,
        posterior_duration: int,
        term_duration: int = 50,
        thinning_posterior: int = 1,
        thinning_warmup: int = 1,
    ):
        """
        Sets epochs using the :func:`.stan_epochs` function.

        Note that :attr:`.term_duration` needs to be long enough that tuning
        algorithms like dual averaging can converge.
        """
        epochs = stan_epochs(
            warmup_duration,
            posterior_duration,
            term_duration=term_duration,
            thinning_posterior=thinning_posterior,
            thinning_warmup=thinning_warmup,
        )
        self._epochs = EpochManager(epochs)

    @property
    def epochs(self) -> tuple[EpochConfig, ...]:
        """Tuple of epoch configurations (empty if no epochs have been set yet)."""
        if self._epochs is None:
            return ()
        return tuple(self._epochs._configs)

    def set_model(self, model: ModelInterface):
        """
        Sets the model interface.

        During :meth:`.build`, this model is assigned to every kernel and quantity
        generator that does not already have one.
        """
        self._model = Option(model)

    def build(self) -> Engine:
        """
        Builds the MCMC engine with the provided setup.

        All checks run before any kernel or quantity generator is modified, so a
        failed build leaves them untouched.

        Raises
        ------
        RuntimeError
            If a position key is claimed by multiple kernels, if no epochs have been
            set, if the engine seed has the wrong shape, or if a quantity generator
            identifier is used more than once. A missing model interface or initial
            model state is reported through ``Option.expect``.
        """
        pos_keys = self._position_keys()

        if self._epochs is None:
            raise RuntimeError("Epochs must be set")
        epochs = self._epochs._configs  # FIXME: use of private field
        jit_duration = self._jit_duration(epochs)

        seeds = self._chain_seeds()

        self._check_quantity_generator_identifiers()

        # Resolve both required inputs before mutating kernels / generators below.
        # The model is checked first, as before, so error precedence is unchanged.
        model = self._model.expect("Model interface must be set")
        model_states = self._model_state.expect("Model state must be set")

        # set model interface for all kernels and quantity generators
        for ker in self.kernels:
            if not ker.has_model():
                ker.set_model(model)

        for qg in self.quantity_generators:
            if not qg.has_model():
                qg.set_model(model)

        # assign identifiers to kernels
        for idx, ker in enumerate(self.kernels):
            if not ker.identifier:
                ker.identifier = f"kernel_{idx:02d}"

        return Engine(
            seeds=seeds,
            model_states=model_states,
            kernel_sequence=KernelSequence(self.kernels),
            epoch_configs=epochs,
            jitted_sample_duration=jit_duration,
            model=model,
            position_keys=pos_keys,
            minimize_transition_infos=self.minimize_transition_infos,
            store_kernel_states=self.store_kernel_states,
            quantity_generators=self.quantity_generators,
            show_progress=self.show_progress,
        )

    def _position_keys(self) -> list[str]:
        """Collects the position keys to track and checks kernel key ownership."""
        kernel_keys: list[str] = [
            key for ker in self._kernels for key in ker.position_keys
        ]

        dupl = _find_duplicate(kernel_keys)
        if dupl.is_some():
            raise RuntimeError(
                f"The position key {dupl.unwrap()} is claimed by multiple kernels"
            )

        excluded = set(self.positions_excluded)
        # dict.fromkeys drops repeats while keeping first-seen order, so a key that
        # is both claimed by a kernel and listed in positions_included appears once.
        candidates = dict.fromkeys([*kernel_keys, *self.positions_included])
        return [key for key in candidates if key not in excluded]

    @staticmethod
    def _jit_duration(epochs) -> int:
        """Returns the gcd of the durations of all epochs after the first one."""
        # NOTE(review): the first epoch is excluded from the gcd, as in the original
        # implementation; presumably the engine handles it separately -- confirm.
        durations = [e.duration for e in epochs[1:]]
        # math.gcd() with no arguments (or only zeros) returns 0, which is not a
        # usable chunk length; 1 divides every duration, so fall back to it.
        return math.gcd(*durations) or 1

    def _chain_seeds(self) -> KeyArray:
        """Returns the engine key(s) with one key per chain along the first axis."""
        seeds = self._engine_key
        if seeds.shape == (2,):
            # no multi-chain key: derive one key per chain from the single key
            seeds = jax.random.split(seeds, self._num_chains)

        expected = (self._num_chains, 2)
        if seeds.shape != expected:
            raise RuntimeError(
                f"MCMC seed has the wrong dimensions {seeds.shape}. "
                f"Expected is {expected}"
            )
        return seeds

    def _check_quantity_generator_identifiers(self) -> None:
        """Raises if two quantity generators share the same identifier."""
        idents = [qg.identifier for qg in self._quantity_generators]
        dupl = _find_duplicate(idents)
        if dupl.is_some():
            raise RuntimeError(
                f"The identifier {dupl.unwrap()} is used by multiple "
                "quantity generators"
            )