Source code for liesel.model.goose

"""
Goose model interface.
"""

from __future__ import annotations

from collections.abc import Iterable

from liesel.goose.types import ModelState, Position

from .model import Model


class GooseModel:
    """
    A :class:`.ModelInterface` for a Liesel :class:`.Model`.

    Parameters
    ----------
    model
        A Liesel :class:`.Model`. The interface stores the object returned by \
        ``model._copy_computational_model()`` and operates on that object \
        rather than on ``model`` itself.
    """

    def __init__(self, model: Model):
        self._model = model._copy_computational_model()

    def extract_position(
        self, position_keys: Iterable[str], model_state: ModelState
    ) -> Position:
        """
        Extracts a position from a model state.

        Parameters
        ----------
        position_keys
            An iterable of variable or node names.
        model_state
            A dictionary of node names and their corresponding \
            :class:`.NodeState`.

        Returns
        -------
        A :class:`.Position` mapping each requested key to the ``value`` of the
        matching entry in ``model_state``.

        Raises
        ------
        KeyError
            If a key is found neither in ``model_state`` nor among the variables
            of the wrapped model.
        """
        position = {}
        for key in position_keys:
            # Only the lookup is guarded, so that the fallback is triggered by
            # a missing key and by nothing else.
            try:
                node_state = model_state[key]
            except KeyError:
                # ``key`` is not a node name in the model state: treat it as a
                # variable name and read the state of its value node instead.
                node_key = self._model.vars[key].value_node.name
                node_state = model_state[node_key]
            position[key] = node_state.value

        return Position(position)

    def update_state(self, position: Position, model_state: ModelState) -> ModelState:
        """
        Updates and returns a model state given a position.

        Parameters
        ----------
        position
            A dictionary of variable or node names and values.
        model_state
            A dictionary of node names and their corresponding \
            :class:`.NodeState`.

        Returns
        -------
        The state of the wrapped model after the values from ``position`` have
        been set and the model has been updated.

        Raises
        ------
        KeyError
            If a key in ``position`` is neither a node name nor a variable name
            of the wrapped model.

        Warnings
        --------
        The ``model_state`` must be up-to-date, i.e. it must *not* contain any
        outdated nodes. Updates can only be triggered through new variable or
        node values in the ``position``. If you supply a ``model_state`` with
        outdated nodes, these nodes and their outputs will not be updated.
        """
        # sets all outdated flags in the model state to false
        # this is required to make the function jittable
        self._model.state = model_state
        for node in self._model.nodes.values():
            node._outdated = False

        for key, value in position.items():
            # Resolve the target first and assign afterwards: a KeyError raised
            # while *setting* the value must not be mistaken for "``key`` is
            # not a node name" and silently redirected to a variable.
            try:
                target = self._model.nodes[key]
            except KeyError:
                target = self._model.vars[key]
            target.value = value  # type: ignore  # data node

        self._model.update()
        return self._model.state

    def log_prob(self, model_state: ModelState) -> float:
        """
        Returns the log-probability from a model state.

        The value is read from the ``"_model_log_prob"`` entry of the state.

        Parameters
        ----------
        model_state
            A dictionary of node names and their corresponding \
            :class:`.NodeState`.
        """
        return model_state["_model_log_prob"].value