"""
Type aliases, type variables and protocols.
"""
from __future__ import annotations
from abc import abstractmethod
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, ClassVar, NewType, Protocol, TypeVar
if TYPE_CHECKING:
from .epoch import EpochState
from .kernel import (
DefaultTransitionInfo,
TransitionOutcome,
TuningOutcome,
WarmupOutcome,
)
# Simple type aliases.
#
# The first three names all resolve to ``Any``: they document intent in
# signatures rather than constrain the accepted types.
Array = Any
KeyArray = Any
PyTree = Any

# Model and kernel states are pytrees, hence also unconstrained.
ModelState = PyTree
KernelState = PyTree

# A position is a distinct (nominal) type whose supertype is a dict that maps
# string keys to arrays.
Position = NewType("Position", dict[str, Array])
class TuningInfo(Protocol):
    """
    Holds information about sampler tuning.

    This is a structural type: any object that provides the attributes
    declared below satisfies the protocol.
    """

    # Integer error code reported for a tuning step.
    # NOTE(review): presumably an entry of the kernel's error book, as for
    # ``TransitionInfo.error_code`` -- confirm.
    error_code: int

    # NOTE(review): the meaning and unit of ``time`` are not defined in this
    # module -- confirm against the kernels that fill it in.
    time: int
class TransitionInfo(Protocol):
    """Holds information about MCMC transitions."""

    error_code: int
    """
    An error code defined in the error book of the kernel.

    0 if no errors occurred during the transition.
    """

    acceptance_prob: float
    """
    The acceptance probability of a proposal of a Metropolis-Hastings kernel
    or the average acceptance probability across a trajectory of a NUTS-type
    kernel.

    99.0 if not used by the kernel.
    """

    position_moved: int
    """
    0 if the position did not move during the transition, 1 if it *did* move,
    and 99 if unknown.
    """

    @abstractmethod
    def minimize(self) -> DefaultTransitionInfo:
        """
        Converts this object into a ``DefaultTransitionInfo``.

        Implementations must override this method; the protocol itself only
        raises.

        NOTE(review): presumably the result keeps only the fields declared on
        this protocol -- confirm against ``DefaultTransitionInfo``.

        Raises
        ------
        NotImplementedError
            Always, unless overridden.
        """
        raise NotImplementedError
# Type variables that parameterize the generic protocols of this module.
# ``KernelState`` is an alias of ``Any``, so ``TKernelState`` is effectively
# unconstrained; the two info variables are bounded by their protocols.
TTuningInfo = TypeVar("TTuningInfo", bound=TuningInfo)
TTransitionInfo = TypeVar("TTransitionInfo", bound=TransitionInfo)
TKernelState = TypeVar("TKernelState", bound=KernelState)
class ModelInterface(Protocol):
    """
    Defines a standardized way for Goose to communicate with a statistical model.

    This means predominantly, to update the model state and to compute the
    log-probability.
    """

    @abstractmethod
    def update_state(self, position: Position, model_state: ModelState) -> ModelState:
        """
        Updates the model state with the values in the position.

        Parameters
        ----------
        position
            The values to be written into the model state.
        model_state
            The model state to be updated.

        Returns
        -------
        The updated model state.
        """
        raise NotImplementedError

    @abstractmethod
    def log_prob(self, model_state: ModelState) -> float:
        """
        Computes the unnormalized log-probability given the model state.

        Parameters
        ----------
        model_state
            The model state for which the log-probability is evaluated.
        """
        raise NotImplementedError
class Kernel(Protocol[TKernelState, TTransitionInfo, TTuningInfo]):
    """
    Protocol for a transition kernel.

    The protocol is generic in the type of the kernel state and in the types
    of the transition and tuning information.
    """

    error_book: ClassVar[dict[int, str]]
    """
    Maps error codes to error messages.

    We use -1 if an error code is required but the kernel is skipped.
    """

    needs_history: ClassVar[bool] = False
    """Is set to true if the kernel expects the history for tuning."""

    position_keys: tuple[str, ...]
    """Keys for which the kernel handles the transition."""

    identifier: str = ""
    """
    An identifier for the kernel object that is set by the `EngineBuilder`
    if it is an empty string.
    """

    @abstractmethod
    def set_model(self, model: ModelInterface):
        """
        Sets the model interface.

        Parameters
        ----------
        model
            The model interface to be used by the kernel.
        """
        ...

    @abstractmethod
    def has_model(self) -> bool:
        """Whether the model interface is set."""
        ...

    @abstractmethod
    def init_state(self, prng_key: KeyArray, model_state: ModelState) -> KernelState:
        """
        Creates the initial kernel state.

        Parameters
        ----------
        prng_key
            The key for JAX' pseudo-random number generator.
        model_state
            Current model state.
        """
        raise NotImplementedError

    @abstractmethod
    def transition(
        self,
        prng_key: KeyArray,
        kernel_state: TKernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> TransitionOutcome[TKernelState, TTransitionInfo]:
        """
        Handles one transition. Must be jittable.

        Parameters
        ----------
        prng_key
            The key for JAX' pseudo-random number generator.
        kernel_state
            Current kernel state.
        model_state
            Current model state.
        epoch
            Current epoch state.
        """
        raise NotImplementedError

    @abstractmethod
    def tune(
        self,
        prng_key: KeyArray,
        kernel_state: TKernelState,
        model_state: ModelState,
        epoch: EpochState,
        history: Position | None,
    ) -> TuningOutcome[TKernelState, TTuningInfo]:
        """
        Tunes the kernel after an adaptation epoch.

        The method can perform automatic tuning of the kernel and is called
        after each adaptation epoch. To tune the kernel, the method can return
        an altered kernel state. Must be jittable.

        Parameters
        ----------
        prng_key
            The key for JAX' pseudo-random number generator.
        kernel_state
            Current kernel state.
        model_state
            Current model state.
        epoch
            Current epoch state.
        history
            Holds the history of the position of the current epoch, i.e., that
            is the position but each leaf in the pytree is enhanced by one
            dimension (axis = 0) which represents the time or MCMC iteration.
            The value may be ``None`` if the class variable
            :py:attr:`~needs_history` is ``False``.
        """
        raise NotImplementedError

    @abstractmethod
    def start_epoch(
        self,
        prng_key: KeyArray,
        kernel_state: TKernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> TKernelState:
        """
        Called at the beginning of an epoch. Must be jittable.

        Parameters
        ----------
        prng_key
            The key for JAX' pseudo-random number generator.
        kernel_state
            Current kernel state.
        model_state
            Current model state.
        epoch
            Current epoch state.

        Returns
        -------
        The kernel state to be used from here on.
        """
        raise NotImplementedError

    @abstractmethod
    def end_epoch(
        self,
        prng_key: KeyArray,
        kernel_state: TKernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> TKernelState:
        """
        Called at the end of an epoch. Must be jittable.

        Parameters
        ----------
        prng_key
            The key for JAX' pseudo-random number generator.
        kernel_state
            Current kernel state.
        model_state
            Current model state.
        epoch
            Current epoch state.

        Returns
        -------
        The kernel state to be used from here on.
        """
        raise NotImplementedError

    @abstractmethod
    def end_warmup(
        self,
        prng_key: KeyArray,
        kernel_state: TKernelState,
        model_state: ModelState,
        tuning_history: TTuningInfo | None,
    ) -> WarmupOutcome[TKernelState]:
        """
        Asks the kernel to inspect the warmup history and react to it.

        This method is executed once the first non-warmup epoch is encountered
        and before :meth:`.start_epoch` is called.

        ``tuning_history`` is ``None`` if no tuning has happened prior to the
        first non-warmup epoch. Otherwise ``tuning_history`` has the same
        structure as returned from :meth:`.tune` but each leaf has an
        additional dimension. The first dimension refers to the n-th tuning.

        Must be jittable.

        The signature is likely to change.

        Parameters
        ----------
        prng_key
            The key for JAX' pseudo-random number generator.
        kernel_state
            Current kernel state.
        model_state
            Current model state.
        tuning_history
            The stacked tuning information described above, or ``None``.
        """
        raise NotImplementedError
class GeneratedQuantity(Protocol):
    """
    Protocol representing the data structure for quantities generated via
    implementations of :class:`.QuantityGenerator`.

    Concrete implementations should add additional attributes.

    The attribute ``error_code`` is reserved to store integers that map to
    error messages via the error book provided in the implementation of
    :class:`.QuantityGenerator`.
    """

    # Reserved: key into the generator's error book (see class docstring).
    error_code: int
# Covariant type variable for the quantity type produced by a generator;
# bounded by the ``GeneratedQuantity`` protocol.
TGeneratedQuantity = TypeVar(
    "TGeneratedQuantity",
    bound=GeneratedQuantity,
    covariant=True,
)
class QuantityGenerator(Protocol[TGeneratedQuantity]):
    """
    Protocol for a class that calculates a quantity based on the model
    state and a random seed.
    """

    # Maps the integer error codes of generated quantities to error messages.
    error_book: ClassVar[dict[int, str]]

    # Identifier of the generator object.
    # NOTE(review): who assigns it is not visible in this module -- confirm.
    identifier: str

    @abstractmethod
    def set_model(self, model: ModelInterface):
        """
        Sets a model.

        Parameters
        ----------
        model
            The model to be set.
        """
        raise NotImplementedError

    @abstractmethod
    def has_model(self) -> bool:
        """*True*, if a model has been set with :meth:`.set_model`."""
        raise NotImplementedError

    def generate(
        self,
        prng_key: KeyArray,
        model_state: ModelState,
        epoch: EpochState,
    ) -> TGeneratedQuantity:
        """
        Generates a new quantity based on the model and PRNG state.

        Parameters
        ----------
        prng_key
            The key for JAX' pseudo-random number generator.
        model_state
            Current model state.
        epoch
            Current epoch state.

        Returns
        -------
        The generated quantity.
        """