# Source code for liesel.goose.gibbs

"""
Gibbs sampler.
"""

from collections.abc import Callable, Sequence
from typing import ClassVar

from .epoch import EpochState
from .kernel import (
    DefaultTransitionInfo,
    DefaultTuningInfo,
    ModelMixin,
    TransitionOutcome,
    TuningOutcome,
    WarmupOutcome,
)
from .types import Kernel, KernelState, KeyArray, ModelState, Position, TuningInfo

# Gibbs-specific aliases for the generic kernel state and the default
# transition/tuning info types imported above. They introduce no new types.
GibbsKernelState = KernelState
GibbsTransitionInfo = DefaultTransitionInfo
GibbsTuningInfo = DefaultTuningInfo


class GibbsKernel(
    ModelMixin, Kernel[GibbsKernelState, GibbsTransitionInfo, GibbsTuningInfo]
):
    """
    A Gibbs kernel implementing the :class:`.Kernel` protocol.

    The sampling step itself is delegated to a user-supplied
    ``transition_fn``. Every transition is reported with an acceptance
    probability of ``1.0`` and an error code of ``0``. The kernel keeps no
    state of its own and performs no tuning.

    Parameters
    ----------
    position_keys
        Sequence of names, stored on the kernel as a tuple. Presumably the
        keys of the position entries this kernel updates (to be confirmed
        against the :class:`.Kernel` protocol).
    transition_fn
        Callable that takes a PRNG key and the current model state and
        returns a position, which is then written back into the model state.
    """

    error_book: ClassVar[dict[int, str]] = {0: "no errors"}
    needs_history: ClassVar[bool] = False
    identifier: str = ""
    position_keys: tuple[str, ...]

    def __init__(
        self,
        position_keys: Sequence[str],
        transition_fn: Callable[[KeyArray, ModelState], Position],
    ):
        # The model interface starts out unset; ``transition`` reads it
        # through ``self.model`` (presumably provided by ``ModelMixin``).
        self._model = None
        self.position_keys = tuple(position_keys)
        self._transition_fn = transition_fn

    def init_state(
        self, prng_key: KeyArray, model_state: ModelState
    ) -> GibbsKernelState:
        """
        Initializes an (empty) kernel state.

        Both arguments are ignored; an empty dict is returned.
        """
        return {}

    def transition(
        self,
        prng_key: KeyArray,
        kernel_state: KernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> TransitionOutcome[KernelState, GibbsTransitionInfo]:
        """
        Performs an MCMC transition.

        Calls the transition function with ``prng_key`` and ``model_state``
        and updates the model state with the returned position. The kernel
        state is passed through unchanged and ``epoch`` is not used.
        """
        # The result of the transition function is always used, hence the
        # fixed acceptance probability of one and ``position_moved=1``.
        info = GibbsTransitionInfo(
            error_code=0,
            acceptance_prob=1.0,
            position_moved=1,
        )

        position = self._transition_fn(prng_key, model_state)
        model_state = self.model.update_state(position, model_state)
        return TransitionOutcome(info, kernel_state, model_state)

    def tune(
        self,
        prng_key: KeyArray,
        kernel_state: KernelState,
        model_state: ModelState,
        epoch: EpochState,
        history: Position | None = None,
    ) -> TuningOutcome[KernelState, GibbsTuningInfo]:
        """
        Currently does nothing.

        Returns the unchanged kernel state together with a tuning info that
        carries error code ``0`` and the time of ``epoch``.
        """
        info = GibbsTuningInfo(error_code=0, time=epoch.time)
        return TuningOutcome(info, kernel_state)

    def start_epoch(
        self,
        prng_key: KeyArray,
        kernel_state: KernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> KernelState:
        """
        Currently does nothing.

        Returns ``kernel_state`` unchanged.
        """
        return kernel_state

    def end_epoch(
        self,
        prng_key: KeyArray,
        kernel_state: KernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> KernelState:
        """
        Currently does nothing.

        Returns ``kernel_state`` unchanged.
        """
        return kernel_state

    def end_warmup(
        self,
        prng_key: KeyArray,
        kernel_state: KernelState,
        model_state: ModelState,
        tuning_history: TuningInfo | None,
    ) -> WarmupOutcome[KernelState]:
        """
        Currently does nothing.

        Returns a warmup outcome with error code ``0`` and the unchanged
        kernel state.
        """
        return WarmupOutcome(error_code=0, kernel_state=kernel_state)