Source code for liesel.model.goose
"""
Goose model interface.
"""
from __future__ import annotations
from collections.abc import Iterable
from liesel.goose.types import ModelState, Position
from .model import Model
class GooseModel:
    """
    A :class:`.ModelInterface` for a Liesel :class:`.Model`.

    The interface does not operate on the ``model`` object passed in, but on
    the object returned by ``model._copy_computational_model()``.

    Parameters
    ----------
    model
        A Liesel :class:`.Model`.
    """

    def __init__(self, model: Model):
        self._model = model._copy_computational_model()

    def update_state(self, position: Position, model_state: ModelState) -> ModelState:
        """
        Updates and returns a model state given a position.

        Parameters
        ----------
        position
            A dictionary of variable or node names and values.
        model_state
            A dictionary of node names and their corresponding :class:`.NodeState`.

        Returns
        -------
        The state of the internal model after the values from ``position`` have
        been written and the model has been updated.

        Raises
        ------
        KeyError
            If a key of ``position`` is neither a node name nor a variable name
            of the internal model.

        Warnings
        --------
        The ``model_state`` must be up-to-date, i.e. it must *not* contain any outdated
        nodes. Updates can only be triggered through new variable or node values in the
        ``position``. If you supply a ``model_state`` with outdated nodes, these nodes
        and their outputs will not be updated.
        """
        model = self._model
        model.state = model_state

        # sets all outdated flags in the model state to false
        # this is required to make the function jittable
        for node in model.nodes.values():
            node._outdated = False

        for key, value in position.items():
            # Only the name lookup is guarded: a KeyError raised while
            # *assigning* the value must propagate instead of silently
            # redirecting the write to a variable of the same name.
            try:
                target = model.nodes[key]
            except KeyError:
                target = model.vars[key]
            target.value = value  # type: ignore  # data node

        model.update()
        return model.state

    def log_prob(self, model_state: ModelState) -> float:
        """
        Returns the log-probability from a model state.

        The value is read from the ``"_model_log_prob"`` entry of the state.

        Parameters
        ----------
        model_state
            A dictionary of node names and their corresponding :class:`.NodeState`.
        """
        return model_state["_model_log_prob"].value