Model.sample()

Model.sample(shape, seed, posterior_samples=None, fixed=(), newdata=None, dists=None)

Draws samples from the model.

Parameters:
  • shape (Sequence[int]) – Sample shape.

  • seed (Array) – The seed is split and distributed to the seed nodes of the model. Must be a typed JAX PRNG key array, i.e. one that satisfies jnp.issubdtype(key.dtype, jax.dtypes.prng_key). See jax.random and https://docs.jax.dev/en/latest/jep/9263-typed-keys.html for more details.

  • posterior_samples (dict[str, TypeAliasType] | None) – Dictionary of samples at which to evaluate predictions. All values of the dictionary are assumed to have two leading dimensions corresponding to (nchains, niteration). (default: None)

  • fixed (Sequence[str]) – The names of the nodes or variables to be excluded from the simulation. By default, no nodes or variables are skipped. (default: ())

  • newdata (dict[str, TypeAliasType] | None) – Dictionary of new data at which to produce samples. The keys should correspond to variable or node names in the model whose values should be set to the given values before sampling. If None (default), the current variable values are used. (default: None)

  • dists (dict[str, Dist] | None) – Dictionary mapping variable names to the Dist instances that should be used to sample them. If None (default), each variable is sampled from its Var.dist_node. (default: None)
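
The requirements on seed and posterior_samples can be checked with plain JAX before calling the method. The sketch below does not call the model itself. It shows that a key from jax.random.key passes the jnp.issubdtype check, while a legacy jax.random.PRNGKey key does not until it is wrapped. It also builds a posterior_samples dictionary with the assumed (nchains, niteration) leading dimensions. The variable names mu and beta and the array sizes are hypothetical.

```python
import jax
import jax.numpy as jnp

# A new-style typed key satisfies the `seed` requirement.
key = jax.random.key(42)
print(jnp.issubdtype(key.dtype, jax.dtypes.prng_key))  # True

# A legacy raw key is a plain uint32 array and fails the check ...
legacy = jax.random.PRNGKey(42)
print(jnp.issubdtype(legacy.dtype, jax.dtypes.prng_key))  # False

# ... but it can be wrapped into a typed key.
wrapped = jax.random.wrap_key_data(legacy)
print(jnp.issubdtype(wrapped.dtype, jax.dtypes.prng_key))  # True

# Hypothetical posterior draws: every value has the leading dimensions
# (nchains, niteration), followed by the variable's own shape.
nchains, niteration = 4, 100
posterior_samples = {
    "mu": jnp.zeros((nchains, niteration)),       # scalar parameter
    "beta": jnp.zeros((nchains, niteration, 3)),  # length-3 parameter
}
print({k: v.shape[:2] for k, v in posterior_samples.items()})
```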

Notes

When compiling this function with jax.jit, the arguments shape, fixed, and dists must be static.
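shape sets the size of the output arrays, and fixed controls which variables are drawn at all. Both must therefore be known at trace time. The sketch below illustrates this with a hypothetical stand-in function, sample_sketch, rather than with Model.sample itself. The variable names mu and sigma and the standard-normal draws are assumptions made for illustration. The sketch passes shape and fixed through static_argnames and splits the seed once per variable, loosely mirroring how the seed is distributed to seed nodes.

```python
import jax


def sample_sketch(shape, seed, fixed=()):
    """Hypothetical stand-in for Model.sample: draws standard-normal values
    for two made-up variables, skipping any name listed in `fixed`."""
    names = ("mu", "sigma")  # hypothetical variable names
    keys = jax.random.split(seed, len(names))  # one key per variable
    out = {}
    for i, name in enumerate(names):
        if name in fixed:  # plain Python check, valid because `fixed` is static
            continue
        out[name] = jax.random.normal(keys[i], tuple(shape))
    return out


# `shape` and `fixed` are marked static, so they must be hashable (tuples).
jitted = jax.jit(sample_sketch, static_argnames=("shape", "fixed"))

samples = jitted(shape=(3, 2), seed=jax.random.key(0), fixed=("sigma",))
print({k: v.shape for k, v in samples.items()})  # {'mu': (3, 2)}
```

Static arguments to jax.jit must be hashable, so pass shape and fixed as tuples rather than lists. A dict such as dists is not hashable. One option, not confirmed by this documentation, is to bind it beforehand with functools.partial so that it is closed over rather than passed as a jitted argument.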

Return type:

dict[str, TypeAliasType]

Returns:

  • A dictionary mapping variable and node names to their sampled values. Only sampled variables are included.