EngineBuilder.set_jitter_fns()

EngineBuilder.set_jitter_fns(jitter_fns)

Set the jittering functions.

A jittering function takes a key and a value as input and applies random jitter (noise) to the value based on the key. If no jitter function is provided, a warning is raised and the values are not jittered.
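
As a minimal sketch of this contract (not taken from the library itself), the function below adds zero-mean Gaussian noise to a value. The name `gaussian_jitter`, the noise scale of 0.1, and the `(4, 1)` input shape are illustrative assumptions; the only requirement stated above is that the function accepts a key and a value and returns the jittered value.

```python
import jax
import jax.numpy as jnp


def gaussian_jitter(key, val):
    """Hypothetical jitter function: add Normal(0, 0.1) noise to val."""
    noise = 0.1 * jax.random.normal(key, val.shape, val.dtype)
    return val + noise


key = jax.random.PRNGKey(0)
val = jnp.ones((4, 1))  # assumed layout: one row per chain
jittered = gaussian_jitter(key, val)  # same shape and dtype as val
```

Because the noise is derived from the key, calling the function twice with the same key and value gives the same result.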

Parameters:

jitter_fns (dict[str, Callable[[Any, Any], Any]] | None) – A dictionary where a jittering function is assigned to each position key.

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> import tensorflow_probability.substrates.jax.distributions as tfd
>>> import liesel.model as lsl
>>> import liesel.goose as gs
>>> key = jax.random.PRNGKey(42)

In this example, we show how to use the method EngineBuilder.set_jitter_fns() to apply jittering to the initial values of each chain.

First, we sample 500 data points from a normal distribution with mean 2.0 and standard deviation 1.0.

>>> n = 500
>>> true_mu = 2.0
>>> true_sigma = 1.0
>>> x_vec = tfd.Normal(loc=true_mu, scale=true_sigma).sample((n, ), key)

Then, we define the distribution we want to sample from, which is parametrized by a single parameter mu.

>>> mu = lsl.param(1.0, name="mu")
>>> x_dist = lsl.Dist(tfd.Normal, loc=mu, scale=true_sigma)
>>> x = lsl.Var(x_vec, distribution=x_dist, name="x")

Now, we can create the model with GraphBuilder.

>>> gb = lsl.GraphBuilder().add(x)
>>> model = gb.build_model()

Finally, we set up the sampling engine with EngineBuilder. We will use 4 parallel chains and sample our variable using a NUTSKernel.

>>> builder = gs.EngineBuilder(seed=1337, num_chains=4)
>>> builder.set_model(gs.LieselInterface(model))
>>> builder.set_initial_values(model.state)
>>> builder.add_kernel(gs.NUTSKernel(["mu"]))

A jitter function takes a key and a value as input and applies random jittering to the value using the key. In this case, we add uniform noise with a minimum value of -1.0 and a maximum value of 1.0. Notice that the shape of val is (4, 1), where the first dimension corresponds to the number of chains.

>>> def jitter_fn(key, val):
...     jitter = jax.random.uniform(key, val.shape, val.dtype, -1.0, 1.0)
...     return val + jitter
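
To see what this function does in isolation, the following sketch applies it to an array of the shape described above. The starting value of 1.0 and the seed 0 are illustrative assumptions, and the call is made by hand here; how the engine itself generates and passes keys is not shown.

```python
import jax
import jax.numpy as jnp


def jitter_fn(key, val):
    # Same function as in the example: uniform noise between -1.0 and 1.0.
    jitter = jax.random.uniform(key, val.shape, val.dtype, -1.0, 1.0)
    return val + jitter


# Assumption: the initial value 1.0, stacked across 4 chains.
val = jnp.full((4, 1), 1.0)
jittered = jitter_fn(jax.random.PRNGKey(0), val)

# Each chain's starting value has moved by at most 1.0.
max_shift = float(jnp.max(jnp.abs(jittered - val)))
```

Each of the four rows receives its own noise draw, so the chains start from different points.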

The method takes as input a dictionary where a jittering function is assigned to each position key.

>>> builder.set_jitter_fns({"mu": jitter_fn})
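
When several position keys need different amounts of noise, one possible pattern is a small factory that returns a jitter function for a given range. This is a sketch only: `make_uniform_jitter` is a hypothetical helper, and `"log_sigma"` is a hypothetical second position key that does not exist in the model above.

```python
import jax
import jax.numpy as jnp


def make_uniform_jitter(half_width):
    """Hypothetical helper: build a jitter function with a given noise range."""

    def jitter_fn(key, val):
        noise = jax.random.uniform(
            key, val.shape, val.dtype, minval=-half_width, maxval=half_width
        )
        return val + noise

    return jitter_fn


# One jitter function per position key; "log_sigma" is an assumed name.
jitter_fns = {
    "mu": make_uniform_jitter(1.0),
    "log_sigma": make_uniform_jitter(0.1),
}
```

A dictionary built this way has the form the method expects and could be passed as `builder.set_jitter_fns(jitter_fns)`, provided every key names a position in the model.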