Source code for liesel.model.goose
"""
Goose model interface.
"""
from __future__ import annotations
import warnings
from collections.abc import Iterable, Sequence
import jax
import jax.numpy as jnp
import tensorflow_probability.substrates.jax.distributions as tfd
from ..goose.gibbs import GibbsKernel
from ..goose.types import ModelState, Position
from .model import Model
[docs]
class GooseModel:
    """
    A :class:`.ModelInterface` for a Liesel :class:`.Model`.

    This is an alias for :class:`~.goose.LieselInterface` provided for backwards
    compatibility.

    .. deprecated:: v0.2.6
        Use :class:`~.goose.LieselInterface` instead. This alias will be removed in
        v0.4.0.

    Parameters
    ----------
    model
        A Liesel :class:`.Model`. The interface operates on the copy returned by
        ``model._copy_computational_model()`` rather than on ``model`` itself.
    """

    def __init__(self, model: Model):
        self._model = model._copy_computational_model()
        # stacklevel=2 attributes the warning to the code that constructs the
        # GooseModel instead of to this line inside the library.
        warnings.warn(
            "lsl.GooseModel is deprecated. Use gs.LieselInterface instead. "
            "This alias will be removed in v0.4.0.",
            FutureWarning,
            stacklevel=2,
        )

    def update_state(self, position: Position, model_state: ModelState) -> ModelState:
        """
        Updates and returns a model state given a position.

        Parameters
        ----------
        position
            A dictionary of variable or node names and values.
        model_state
            A dictionary of node names and their corresponding :class:`.NodeState`.

        Returns
        -------
        The model state after setting the values from ``position`` and updating
        the model.

        Warnings
        --------
        The ``model_state`` must be up-to-date, i.e. it must *not* contain any outdated
        nodes. Updates can only be triggered through new variable or node values in the
        ``position``. If you supply a ``model_state`` with outdated nodes, these nodes
        and their outputs will not be updated.
        """
        # Load the given state and set all outdated flags to False.
        # This is required to make the function jittable.
        self._model.state = model_state
        for node in self._model.nodes.values():
            node._outdated = False

        for key, value in position.items():
            # A key may name a node (e.g. a data node) or a variable; node names are
            # looked up first. Only the lookup is guarded, so that a KeyError raised
            # while setting a value is not mistaken for "key is not a node".
            try:
                target = self._model.nodes[key]
            except KeyError:
                target = self._model.vars[key]  # type: ignore[assignment]
            target.value = value  # type: ignore # data node

        self._model.update()
        return self._model.state

    def log_prob(self, model_state: ModelState) -> float:
        """
        Returns the log-probability from a model state.

        Parameters
        ----------
        model_state
            A dictionary of node names and their corresponding :class:`.NodeState`.

        Returns
        -------
        The value stored under the ``"_model_log_prob"`` entry of ``model_state``.
        """
        return model_state["_model_log_prob"].value
[docs]
def finite_discrete_gibbs_kernel(
name: str, model: Model, outcomes: Sequence | None = None
) -> GibbsKernel:
"""
Creates a Gibbs kernel for a parameter with a finite discrete (categorical) prior.
The prior distribution of the variable to sample must be a categorical distribution,
usually implemented via :class:`tfd.FiniteDiscrete`.
This kernel evaluates the full conditional log probability of the model for each
possible value of the variable to sample. It then draws a new value for the variable
from the categorical distribution defined by the full conditional log probabilities.
Parameters
----------
name
The name of the variable to sample.
model
The model to sample from.
outcomes
The possible outcomes of the variable to sample. If ``outcomes=None``, the \
possible outcomes are extracted from the prior distribution of the variable \
to sample. Note however, that this only works for some prior distributions. \
If the possible outcomes cannot be extracted from the prior distribution, \
you must specify them manually via this argument.
Examples
--------
In the following example, we create a categorical Gibbs kernel for a variable with
three possible values. The prior distribution of the variable is a finite discrete
(categorical) distribution with the probabilities ``[0.1, 0.2, 0.7]``.
You can then use the kernel to sample from the model:
>>> import tensorflow_probability.substrates.jax.distributions as tfd
>>> values = [0.0, 1.0, 2.0]
>>> prior_probs = [0.1, 0.2, 0.7]
>>> value_grid = lsl.Var(values, name="value_grid")
>>> prior = lsl.Dist(tfd.FiniteDiscrete, outcomes=value_grid, probs=prior_probs)
>>> categorical_var = lsl.Var(
... value=values[0],
... distribution=prior,
... name="categorical_var",
... )
>>> model = lsl.GraphBuilder().add(categorical_var).build_model()
>>> kernel = finite_discrete_gibbs_kernel("categorical_var", model)
>>> type(kernel)
<class 'liesel.goose.gibbs.GibbsKernel'>
Example for a variable with a Bernoulli prior distribution:
>>> prior = lsl.Dist(tfd.Bernoulli, probs=lsl.Data(0.7))
>>> dummy_var = lsl.Var(
... value=1,
... distribution=prior,
... name="dummy_var",
... )
>>> model = lsl.GraphBuilder().add(dummy_var).build_model()
>>> kernel = finite_discrete_gibbs_kernel("dummy_var", model, outcomes=[0, 1])
>>> type(kernel)
<class 'liesel.goose.gibbs.GibbsKernel'>
"""
if outcomes is not None:
outcomes = jnp.asarray(outcomes)
else:
dist = model.vars[name].dist_node.init_dist() # type: ignore
assert dist.batch_shape == ()
match dist:
case tfd.Bernoulli():
outcomes = jnp.array([0, 1], dtype=dist.dtype)
case tfd.FiniteDiscrete():
outcomes = dist.outcomes
case _:
raise ValueError(
"Cannot extract outcomes from the distribution of variable "
f"'{name}'. Please provide the argument 'outcomes'."
)
model = model._copy_computational_model()
model.auto_update = False
def transition_fn(prng_key, model_state):
model.state = model_state
for node in model.nodes.values():
node._outdated = False
def conditional_log_prob_fn(value):
"""
Evaluates the full conditional log probability of the model
given the input value.
"""
model.vars[name].value = value
model.update("_model_log_prob")
return model.log_prob
conditional_log_probs = jax.vmap(conditional_log_prob_fn)(outcomes)
draw_index = jax.random.categorical(prng_key, logits=conditional_log_probs)
draw = outcomes[draw_index]
return {name: draw}
return GibbsKernel([name], transition_fn)