Source code for liesel.goose.models
"""
Model interfaces.
"""
import copy
from collections.abc import Callable, Sequence
from ..docs import usedocs
from .types import ModelInterface, ModelState, Position
# Signature of the user-supplied log-probability function used by the model
# interfaces below: it is called with a model state and returns a float.
LogProbFunction = Callable[[ModelState], float]
@usedocs(ModelInterface)
class DictModel:
    """
    A model interface for a model state represented by a ``dict[str, Array]`` and a
    corresponding log-probability function.

    The log-probability function is stored at construction time and is called with
    the full model state whenever :meth:`log_prob` is invoked.
    """

    def __init__(self, log_prob_fn: LogProbFunction):
        # Kept private; ``log_prob`` is a thin delegation to this callable.
        self._log_prob_fn = log_prob_fn

    def log_prob(self, model_state: ModelState) -> float:
        compute = self._log_prob_fn
        return compute(model_state)

    def update_state(self, position: Position, model_state: ModelState) -> ModelState:
        # The ``|`` operator builds a new mapping, so the caller's ``model_state``
        # is not modified; entries from ``position`` take precedence on key clashes.
        merged = model_state | position
        return merged
@usedocs(ModelInterface)
class DataClassModel:
    """
    A model interface for a model state represented by a :obj:`~dataclasses.dataclass`
    and a corresponding log-probability function.

    Position entries are written onto a shallow copy of the model state, one
    attribute per key; the original model state object is never modified.
    """

    def __init__(self, log_prob_fn: LogProbFunction):
        # Kept private; ``log_prob`` is a thin delegation to this callable.
        self._log_prob_fn = log_prob_fn

    def log_prob(self, model_state: ModelState) -> float:
        compute = self._log_prob_fn
        return compute(model_state)

    def update_state(self, position: Position, model_state: ModelState) -> ModelState:
        # Shallow copy: attributes are rebound on the copy, so the caller's
        # ``model_state`` keeps its original values.
        updated = copy.copy(model_state)

        for key, new_value in position.items():
            # Reject keys the state object has no attribute for, instead of
            # silently adding new attributes to the copy.
            if not hasattr(updated, key):
                raise RuntimeError(
                    f"ModelState {model_state!r} does not have field with name {key}"
                )
            setattr(updated, key, new_value)

        return updated