# Source code for liesel.goose.models
"""
Model interfaces.
"""
import copy
from collections.abc import Callable, Sequence
from ..docs import usedocs
from .types import ModelInterface, ModelState, Position
# Type of the user-supplied log-probability function accepted by the model
# interfaces below: it is called with a model state and returns a float.
LogProbFunction = Callable[[ModelState], float]
@usedocs(ModelInterface)
class DictModel:
    """
    A model interface for a model state represented by a ``dict[str, Array]`` and a
    corresponding log-probability function.

    Parameters
    ----------
    log_prob_fn
        A function that is called with a model state and returns its
        log-probability. It is evaluated by :meth:`log_prob`.
    """

    def __init__(self, log_prob_fn: LogProbFunction):
        self._log_prob_fn = log_prob_fn

    def extract_position(
        self, position_keys: Sequence[str], model_state: ModelState
    ) -> Position:
        # Plain dict lookup: a key that is missing from the state raises KeyError.
        return Position({key: model_state[key] for key in position_keys})

    def update_state(self, position: Position, model_state: ModelState) -> ModelState:
        # For dicts, ``|`` builds a new mapping, so the input state is not modified.
        # Values from ``position`` (the right operand) win over existing entries.
        # Keys not yet present in the state are added.
        return model_state | position

    def log_prob(self, model_state: ModelState) -> float:
        # Without this method the ``log_prob_fn`` passed to the constructor was
        # stored but never reachable through the model interface.
        return self._log_prob_fn(model_state)
@usedocs(ModelInterface)
class DataClassModel:
    """
    A model interface for a model state represented by a :obj:`~dataclasses.dataclass`
    and a corresponding log-probability function.

    Parameters
    ----------
    log_prob_fn
        A function that is called with a model state and returns its
        log-probability. It is evaluated by :meth:`log_prob`.
    """

    def __init__(self, log_prob_fn: LogProbFunction):
        self._log_prob_fn = log_prob_fn

    def extract_position(
        self, position_keys: Sequence[str], model_state: ModelState
    ) -> Position:
        # Attribute access: a name the state does not have raises AttributeError.
        return Position({key: getattr(model_state, key) for key in position_keys})

    def update_state(self, position: Position, model_state: ModelState) -> ModelState:
        # Mutate a shallow copy so the caller's state object is left unchanged.
        new_state = copy.copy(model_state)
        for key, value in position.items():
            # Only overwrite existing attributes; unknown names are an error.
            if not hasattr(new_state, key):
                raise RuntimeError(
                    f"ModelState {model_state!r} does not have field with name {key}"
                )
            # NOTE(review): ``setattr`` raises ``FrozenInstanceError`` for frozen
            # dataclasses; confirm whether such model states need to be supported.
            setattr(new_state, key, value)
        return new_state

    def log_prob(self, model_state: ModelState) -> float:
        # Without this method the ``log_prob_fn`` passed to the constructor was
        # stored but never reachable through the model interface.
        return self._log_prob_fn(model_state)