Var.sample()
- Var.sample(shape, seed, posterior_samples=None, fixed=(), newdata=None, dists=None)
Draws samples from the parental model for this variable.
- Parameters:
  - seed (Array) – The seed is split and distributed to the seed nodes of the model. Must be a JAX RNG key array that satisfies jnp.issubdtype(key.dtype, jax.dtypes.prng_key). See jax.random and https://docs.jax.dev/en/latest/jep/9263-typed-keys.html for more details.
  - posterior_samples (dict[str, TypeAliasType] | None) – Dictionary of samples at which to evaluate predictions. All values of the dictionary are assumed to have two leading dimensions corresponding to (nchains, niteration). (default: None)
  - fixed (Sequence[str]) – The names of the nodes or variables to be excluded from the simulation. By default, no nodes or variables are skipped. (default: ())
  - newdata (dict[str, TypeAliasType] | None) – Dictionary of new data at which to produce samples. The keys should be names of variables or nodes in the model; their values are set to the given values before sampling. If None, the current variable values are used. (default: None)
  - dists (dict[str, Dist] | None) – Dictionary mapping variable names to Dist instances to use in sampling. If None, samples are drawn for each variable using its Var.dist_node. (default: None)
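The requirements on seed and posterior_samples can be checked with plain JAX before calling Var.sample. This is a minimal sketch, assuming a recent JAX version that provides jax.random.key and jax.random.wrap_key_data. The variable names "mu" and "beta" are hypothetical, and splitting into three keys only illustrates key splitting in general, not how Var.sample distributes the seed internally.

```python
import jax
import jax.numpy as jnp

# A typed PRNG key (created with jax.random.key) passes the check required for `seed`.
seed = jax.random.key(42)

# A legacy uint32 key from jax.random.PRNGKey does NOT pass the check ...
legacy = jax.random.PRNGKey(42)

# ... but it can be converted into a typed key.
wrapped = jax.random.wrap_key_data(legacy)

# Splitting a typed key yields an array of typed keys (here of shape (3,)).
keys = jax.random.split(seed, 3)

# posterior_samples: every value carries the leading (nchains, niteration) axes.
# "mu" and "beta" are hypothetical variable names.
nchains, niteration = 4, 100
posterior_samples = {
    "mu": jnp.zeros((nchains, niteration)),        # scalar parameter
    "beta": jnp.zeros((nchains, niteration, 3)),   # vector parameter of length 3
}

print(jnp.issubdtype(seed.dtype, jax.dtypes.prng_key))    # True
print(jnp.issubdtype(legacy.dtype, jax.dtypes.prng_key))  # False
```

A typed key such as seed (or wrapped) is the kind of value the seed argument expects; the trailing axes of each posterior_samples entry follow the shape of the corresponding variable.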
Notes
When compiling this function with jax.jit, the arguments shape, fixed, and dists must be static.
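A minimal sketch of two common ways to satisfy the static-argument requirement. It uses a hypothetical stand-in function, sample_like, instead of Var.sample, so that it runs without a model. Option 1 marks the arguments static via static_argnames, which requires hashable values (tuples rather than lists; JAX rejects unhashable static arguments such as a dict). Option 2 binds the static arguments with functools.partial so that only the seed is traced, which also works for a dict passed as dists.

```python
import functools

import jax
import jax.numpy as jnp


def sample_like(shape, seed, fixed=(), dists=None):
    """Hypothetical stand-in with the same argument roles as Var.sample.

    `shape` fixes the output shape, so it must be known at trace time.
    `fixed` and `dists` are unused here; they only illustrate static arguments.
    """
    return jax.random.normal(seed, shape)


# Option 1: declare shape, fixed, and dists static. Static values must be
# hashable, so shape and fixed are passed as tuples.
sample_jit = jax.jit(sample_like, static_argnames=("shape", "fixed", "dists"))
draws = sample_jit(shape=(2, 3), seed=jax.random.key(0), fixed=("x",))

# Option 2: bind the static arguments up front; only the seed is traced.
# The bound values are captured at trace time, so a dict for dists is fine.
sample_bound = jax.jit(functools.partial(sample_like, (2, 3), dists={"y": None}))
draws_bound = sample_bound(jax.random.key(0))
```

Both compiled functions use the same key and shape, so they produce the same draws. With an actual Var, the second pattern would presumably bind shape, fixed, and dists on the bound method in the same way, leaving only the seed as a traced argument.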