Source code for liesel.goose.builder

"""
MCMC engine builder

The purpose of the engine builder is to provide a simple API to gradually assemble
the components needed by the MCMC engine. The builder is responsible of returning an
engine in a well-defined state. Furthermore, the builder can return different engine
implementations.
"""

import logging
import math
from collections.abc import Iterable
from typing import cast

import jax
import jax.numpy as jnp

from liesel.option import Option

from .engine import Engine
from .epoch import EpochConfig, EpochManager
from .kernel_sequence import KernelSequence
from .pytree import stack_leaves
from .types import (
    JitterFunctions,
    Kernel,
    KeyArray,
    ModelInterface,
    ModelState,
    QuantityGenerator,
)
from .warmup import stan_epochs

logger = logging.getLogger(__name__)


def _find_duplicate(xs: list[str]) -> Option[str]:
    """Checks if a list of strings contains any duplicates."""
    set_ = set()

    for x in xs:
        if x in set_:
            return Option(x)
        else:
            set_.add(x)

    return Option(None)


[docs]class EngineBuilder: """ The :class:`.EngineBuilder` is used to construct an MCMC Engine. Currently, the :class:`.EngineBuilder` builds an object of the class :class:`.Engine`. By default, every position key associated with an MCMC kernel is tracked. This behavior can be adjusted with the fields :attr:`.positions_included` and :attr:`.positions_excluded`. Parameters ---------- seed Used to initialize the PRNG for the building process. The PRNG state num_chains The number of chains to be used. """ def __init__(self, seed: int, num_chains: int): keys = jax.random.split(jax.random.PRNGKey(seed), 3) self._prng_key: KeyArray = keys[0] self._engine_key: KeyArray = keys[1] self._jitter_key: KeyArray = keys[2] self._num_chains: int = num_chains self._kernels: list[Kernel] = [] self._quantity_generators: list[QuantityGenerator] = [] self._model_state: Option[ModelState] = Option(None) self._model: Option[ModelInterface] = Option(None) self._jitter_fns: Option[JitterFunctions] = Option(None) # public fields, only simple states self.store_kernel_states: bool = False self.minimize_transition_infos: bool = False self.show_progress: bool = True self.positions_included: list[str] = [] """List of additional position keys that should be tracked.""" self.positions_excluded: list[str] = [] """List of position keys that should not be tracked. Excluded keys override additional keys."""
[docs] def set_engine_seed(self, seed: int | KeyArray): """Sets a seed used to initialize the MCMC engine.""" if jnp.isscalar(seed): seed_int: int = cast(int, seed) self._engine_key = jax.random.PRNGKey(seed_int) else: seed_keyarray = cast(KeyArray, seed) self._engine_key = seed_keyarray
@property def engine_seed(self) -> KeyArray: """The seed for the engine's pseudo-random number generation.""" return self._engine_key
[docs] def add_kernel(self, kernel: Kernel): """Adds a :class:`.Kernel`.""" self._kernels.append(kernel)
@property def kernels(self) -> tuple[Kernel, ...]: """Tuple of all Kernels that are present in the builder.""" return tuple(self._kernels)
[docs] def add_quantity_generator(self, generator: QuantityGenerator): """Adds a :class:`.QuantityGenerator`.""" self._quantity_generators.append(generator)
@property def quantity_generators(self) -> tuple[QuantityGenerator, ...]: """Tuple of all quantity generators present in the builder.""" return tuple(self._quantity_generators)
[docs] def set_initial_values(self, model_state: ModelState, multiple_chains=False): """ Sets the initial model state. If :attr:`.multiple_chains` is true the :attr:`.model_state` will be used as is; otherwise :attr:`.model_state` will be used as the initial values for each chain. Note that if :attr:`.multiple_chains` is true, the first axis of each leaf of :attr:`.model_state` refers to the chain. """ if not multiple_chains: model_states = stack_leaves(model_state for _ in range(self._num_chains)) self._model_state = Option(model_states)
[docs] def set_jitter_fns(self, jitter_fns: JitterFunctions | None): """Set the jittering functions.""" self._jitter_fns = Option(jitter_fns)
@property def jitter_fns(self) -> Option[JitterFunctions]: """Jittering functions.""" return self._jitter_fns @property def model_state(self) -> Option[ModelState]: """Model state.""" return self._model_state
[docs] def set_epochs(self, epochs: Iterable[EpochConfig]): """Sets epochs.""" self._epochs = EpochManager(epochs)
[docs] def set_duration( self, warmup_duration: int, posterior_duration: int, term_duration: int = 50, thinning_posterior: int = 1, thinning_warmup: int = 1, ): """ Sets epochs using the :func:`.stan_epochs` function. Note that :attr:`.term_duration` needs to be long enough that tuning algorithms like dual averaging can converge. """ epochs = stan_epochs( warmup_duration, posterior_duration, term_duration=term_duration, thinning_posterior=thinning_posterior, thinning_warmup=thinning_warmup, ) self._epochs = EpochManager(epochs)
@property def epochs(self) -> tuple[EpochConfig, ...]: """Tuple of epoch configurations.""" return tuple(self._epochs._configs)
[docs] def set_model(self, model: ModelInterface): """Sets the model interface for all kernels and quantity generators.""" self._model = Option(model)
[docs] def build(self) -> Engine: """Builds the MCMC engine with the provided setup.""" # build list of position keys pos_keys: list[str] = [] for ker in self._kernels: pos_keys.extend(ker.position_keys) dupl = _find_duplicate(pos_keys) if dupl.is_some(): raise RuntimeError( f"The position key {dupl.unwrap()} is claimed by multiple kernels" ) # find good jittable number epochs = self._epochs._configs # FIXME: use of private field durations = [e.duration for e in epochs[1:]] jit_duration = math.gcd(*durations) # seeds seeds = self._engine_key if seeds.shape == (2,): # no multi-chain key seeds = jax.random.split(seeds, self._num_chains) if seeds.shape != (self._num_chains, 2): raise RuntimeError( f"MCMC seed has the wrong dimensions {seeds}. " f"Expected is {(self._num_chains, 2)}" ) # check for duplicated identifiers in self._quantity_generators idents = [] for qg in self._quantity_generators: idents.append(qg.identifier) dupl = _find_duplicate(idents) if dupl.is_some(): raise RuntimeError( f"The identifier {dupl.unwrap()} is used by multiple " "quantity generators" ) # set model interface for all kernels and quantity generators model = self._model.expect("Model interface must be set") for ker in self.kernels: if not ker.has_model(): ker.set_model(model) for qg in self.quantity_generators: if not qg.has_model(): qg.set_model(model) # assign identifiers to kernels for idx, ker in enumerate(self.kernels): if not ker.identifier: ker.identifier = f"kernel_{idx:02d}" # set initial values model_states = self._model_state.expect("Model state must be set") if self._jitter_fns.is_some(): jitter_fns = self._jitter_fns.unwrap() jitter_fns_pos_keys = list(jitter_fns.keys()) missing_keys = [] for pos_key in pos_keys: if pos_key not in jitter_fns_pos_keys: missing_keys.append(pos_key) if missing_keys: pretty_keys = str(missing_keys)[1:-1] logger.warning( f"No jitter functions provided for position keys {pretty_keys}. " "The initial values for these keys won't be jittered" ) jitter_keys = jax.random.split(self._jitter_key, len(jitter_fns_pos_keys)) current_position = model.extract_position(jitter_fns_pos_keys, model_states) jittered_position = {} for i, pos_key in enumerate(jitter_fns_pos_keys): jittered_position[pos_key] = jax.vmap(jitter_fns[pos_key])( jax.random.split(jitter_keys[i], self._num_chains), current_position[pos_key], ) model_states = jax.vmap(model.update_state)(jittered_position, model_states) else: logger.warning( "No jitter functions provided. The initial values won't be jittered" ) # extending position keys pos_keys.extend(self.positions_included) pos_keys = [key for key in pos_keys if key not in self.positions_excluded] return Engine( seeds=seeds, model_states=model_states, kernel_sequence=KernelSequence(self.kernels), epoch_configs=epochs, jitted_sample_duration=jit_duration, model=model, position_keys=pos_keys, minimize_transition_infos=self.minimize_transition_infos, store_kernel_states=self.store_kernel_states, quantity_generators=self.quantity_generators, show_progress=self.show_progress, )