# Source code for liesel.goose.gibbs

"""
Gibbs sampler.
"""

from collections.abc import Callable, Sequence
from typing import ClassVar

from .epoch import EpochState
from .kernel import (
    DefaultTransitionInfo,
    DefaultTuningInfo,
    ModelMixin,
    ReprMixin,
    TransitionOutcome,
    TuningOutcome,
    WarmupOutcome,
)
from .types import Kernel, KernelState, KeyArray, ModelState, Position, TuningInfo

# Type aliases used by the Gibbs kernel: it reuses the generic kernel state type
# and the default transition/tuning info classes rather than defining its own.
GibbsKernelState = KernelState
GibbsTransitionInfo = DefaultTransitionInfo
GibbsTuningInfo = DefaultTuningInfo


class GibbsKernel(
    ModelMixin,
    Kernel[GibbsKernelState, GibbsTransitionInfo, GibbsTuningInfo],
    ReprMixin,
):
    """
    A Gibbs kernel implementing the :class:`.Kernel` protocol.

    Parameters
    ----------
    position_keys
        Sequence of position keys (variable names) handled by this kernel.
    transition_fn
        Custom transition function that needs to be provided by the user.
    identifier
        A string acting as a unique identifier for this kernel.

    Examples
    --------
    For this example, we import ``tensorflow_probability``, ``jax`` and
    ``jax.numpy`` as follows:

    >>> import tensorflow_probability.substrates.jax.distributions as tfd
    >>> import jax
    >>> import jax.numpy as jnp

    First, we set up a minimal model:

    >>> mu = lsl.Var.new_param(0.0, name="mu")
    >>> dist = lsl.Dist(tfd.Normal, loc=mu, scale=1.0)
    >>> y = lsl.Var.new_obs(jnp.array([1.0, 2.0, 3.0]), dist, name="y")
    >>> model = lsl.Model([y])

    Now we initialize the EngineBuilder and set the desired number of warmup and
    posterior samples:

    >>> builder = gs.EngineBuilder(seed=1, num_chains=4)
    >>> builder.set_duration(warmup_duration=1000, posterior_duration=1000)

    Next, we set the model interface and initial values:

    >>> interface = gs.LieselInterface(model)
    >>> builder.set_model(interface)
    >>> builder.set_initial_values(model.state)

    We define a function to sample from the full conditional for the parameter
    ``"mu"``:

    >>> def sample_mu(prng_key, model_state):
    ...     # extract relevant values from model state
    ...     pos = interface.extract_position(
    ...         position_keys=["y", "mu"], model_state=model_state
    ...     )
    ...     # calculate relevant intermediate quantities
    ...     n = len(pos["y"])
    ...     y_mean = pos["y"].mean()
    ...     mu_new = (n * y_mean + pos["mu"]) / (n + 1)
    ...     # draw new value from full conditional
    ...     draw = mu_new + jax.random.normal(prng_key)
    ...     # return key-value pair of variable name and new value
    ...     return {"mu": draw}

    >>> builder.add_kernel(gs.GibbsKernel(["mu"], sample_mu))

    Finally, we build the engine:

    >>> engine = builder.build()

    From here, you can continue with :meth:`~.goose.Engine.sample_all_epochs` to
    draw samples from your posterior distribution.

    See Also
    --------
    :doc:`/tutorials/md/01d-gibbs-sampling`
    """

    # The only error code this kernel ever reports is 0.
    error_book: ClassVar[dict[int, str]] = {0: "no errors"}
    # The tuning step does not use the sampling history.
    needs_history: ClassVar[bool] = False
    # Identifier of the kernel; empty unless set in the constructor.
    identifier: str = ""
    # Names of the position entries handled by this kernel.
    position_keys: tuple[str, ...]

    def __init__(
        self,
        position_keys: Sequence[str],
        transition_fn: Callable[[KeyArray, ModelState], Position],
        identifier: str = "",
    ):
        # No model is attached at construction time.
        self._model = None
        # Stored as a tuple regardless of the sequence type passed in.
        self.position_keys = tuple(position_keys)
        self._transition_fn = transition_fn
        self.identifier = identifier

    @classmethod
    def with_transition_fn(
        cls, transition_fn: Callable[[KeyArray, ModelState], Position]
    ) -> Callable[[Sequence[str], str], "GibbsKernel"]:
        """
        Return a Gibbs kernel factory with a fixed transition function.

        This helper is useful when a Gibbs kernel should be configured through an
        :class:`.MCMCSpec`. The returned callable accepts the position keys and an
        optional identifier, matching the kernel factory interface expected by
        :class:`.LieselMCMC`.

        Parameters
        ----------
        transition_fn
            Function implementing one Gibbs update. It receives a PRNG key and the
            current model state, and must return a position dictionary with
            updated values for the variables handled by the kernel.

        Returns
        -------
        Callable
            A kernel factory that creates instances of the class this method was
            called on (:class:`.GibbsKernel` or a subclass) for a given sequence
            of position keys.
        """

        def gibbs_kernel_constructor(
            position_keys: Sequence[str],
            identifier: str = "",
        ) -> GibbsKernel:
            # Build through ``cls`` rather than a hard-coded ``GibbsKernel`` so
            # that a factory obtained from a subclass yields subclass instances.
            return cls(
                position_keys=position_keys,
                transition_fn=transition_fn,
                identifier=identifier,
            )

        return gibbs_kernel_constructor

    def init_state(
        self, prng_key: KeyArray, model_state: ModelState
    ) -> GibbsKernelState:
        """
        Initializes an (empty) kernel state.

        Neither argument is used; the returned state is an empty dictionary.
        """
        return {}

    def transition(
        self,
        prng_key: KeyArray,
        kernel_state: KernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> TransitionOutcome[KernelState, GibbsTransitionInfo]:
        """
        Performs an MCMC transition.

        The user-supplied transition function is called with the PRNG key and the
        current model state. The position it returns is written into the model
        state via ``self.model.update_state``. The kernel state is passed through
        unchanged, and ``epoch`` is not used.

        Returns
        -------
        TransitionOutcome
            The transition info, the unchanged kernel state and the updated model
            state. The info always reports error code 0, an acceptance
            probability of 1.0 and a moved position.
        """
        info = GibbsTransitionInfo(
            error_code=0,
            acceptance_prob=1.0,
            position_moved=1,
        )

        # NOTE(review): ``self.model`` is presumably provided by ``ModelMixin``
        # and requires that a model has been set on the kernel -- confirm.
        position = self._transition_fn(prng_key, model_state)
        model_state = self.model.update_state(position, model_state)
        return TransitionOutcome(info, kernel_state, model_state)

    def tune(
        self,
        prng_key: KeyArray,
        kernel_state: KernelState,
        model_state: ModelState,
        epoch: EpochState,
        history: Position | None = None,
    ) -> TuningOutcome[KernelState, GibbsTuningInfo]:
        """
        Currently does nothing.

        The kernel state is returned unchanged together with a tuning info that
        carries error code 0 and the time of the given epoch.
        """
        info = GibbsTuningInfo(error_code=0, time=epoch.time)
        return TuningOutcome(info, kernel_state)

    def start_epoch(
        self,
        prng_key: KeyArray,
        kernel_state: KernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> KernelState:
        """
        Currently does nothing. Returns the kernel state unchanged.
        """
        return kernel_state

    def end_epoch(
        self,
        prng_key: KeyArray,
        kernel_state: KernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> KernelState:
        """
        Currently does nothing. Returns the kernel state unchanged.
        """
        return kernel_state

    def end_warmup(
        self,
        prng_key: KeyArray,
        kernel_state: KernelState,
        model_state: ModelState,
        tuning_history: TuningInfo | None,
    ) -> WarmupOutcome[KernelState]:
        """
        Currently does nothing.

        Returns a warmup outcome with error code 0 and the unchanged kernel
        state.
        """
        return WarmupOutcome(error_code=0, kernel_state=kernel_state)