Model.simulate()

Model.simulate(seed, skip=())

Updates the model state by simulating from the probability distributions in the model, using the provided random seed and optionally skipping specified nodes.

Parameters:
  • seed (Array) – The random seed, which is split and distributed to the seed nodes of the model. Must be a JAX RNG key array that satisfies jnp.issubdtype(key.dtype, jax.dtypes.prng_key). See jax.random and https://docs.jax.dev/en/latest/jep/9263-typed-keys.html for more details.

  • skip (Iterable[str]) – The names of the nodes or variables to exclude from the simulation. Defaults to (), in which case no nodes or variables are skipped.
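
The key requirement on seed can be checked directly in JAX. The following is a minimal sketch using only jax itself, not this library's internals: it builds a typed key, shows that a legacy uint32 key does not pass the dtype check, and splits the typed key into subkeys. The count n_seed_nodes is a hypothetical value chosen for illustration; how the model actually distributes the subkeys to its seed nodes is not shown here.

```python
import jax
import jax.numpy as jnp

# A new-style typed key, the kind of array the seed parameter expects.
typed_key = jax.random.key(42)
is_typed = bool(jnp.issubdtype(typed_key.dtype, jax.dtypes.prng_key))

# A legacy key is a raw uint32 array and does not pass the same check.
legacy_key = jax.random.PRNGKey(42)
is_legacy_typed = bool(jnp.issubdtype(legacy_key.dtype, jax.dtypes.prng_key))

# A legacy key can be wrapped into a typed key before it is passed on.
wrapped_key = jax.random.wrap_key_data(legacy_key)
is_wrapped_typed = bool(jnp.issubdtype(wrapped_key.dtype, jax.dtypes.prng_key))

# Splitting one key into several independent subkeys.
# n_seed_nodes is a hypothetical count, used only for illustration.
n_seed_nodes = 3
subkeys = jax.random.split(typed_key, n_seed_nodes)
```

Passing typed_key or wrapped_key as seed should satisfy the documented requirement, whereas legacy_key would need to be wrapped first.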

Return type:

Model

Returns:

The model instance itself after updating its state with the simulated values.

Raises:

AttributeError – If the value of the Dist.at node of a distribution node cannot be set.

Notes

The simulation is based on the shapes of the current values of the Dist.at nodes of the distribution nodes. If the Dist.at node of a distribution node is a VarValue node, the value of its input is updated.
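
The role of the current shapes can be illustrated with plain JAX. This is only a sketch of the idea, not the library's implementation: current_value is a hypothetical stand-in for the value currently held by a Dist.at node, and a standard normal stands in for whatever distribution the node actually uses.

```python
import jax
import jax.numpy as jnp

# Hypothetical current value of the node a distribution is attached to.
current_value = jnp.zeros((5, 2))

# Draw a replacement whose shape is taken from the current value.
key = jax.random.key(0)
simulated = jax.random.normal(key, shape=current_value.shape)
```

Under this reading, the simulated value has the same shape as the value it replaces, so the current value of each such node should already have the intended shape before simulate() is called.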