EngineBuilder.set_jitter_fns()
- EngineBuilder.set_jitter_fns(jitter_fns)
Set the jittering functions.
A jittering function takes a key and a value as input and applies random jittering (noise) to the value based on the given key. If no jittering functions are provided, a warning is raised and the values are not jittered.
- Parameters:
jitter_fns (dict[str, Callable[[Any, Any], Any]] | None) – A dictionary where a jittering function is assigned to each position key.
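The description above can be illustrated with a minimal sketch of how a dictionary of jittering functions might be applied to a position. It assumes that each jittered position key receives its own subkey split from a single JAX PRNG key. `apply_jitter` and `uniform_jitter` are hypothetical helpers written for illustration. This is not Liesel's internal implementation, and the warning text is an assumption.

```python
import warnings

import jax
import jax.numpy as jnp


def apply_jitter(key, position, jitter_fns):
    """Hypothetical helper: jitter each listed entry of `position` with its own subkey.

    Illustration only; this is not Liesel's internal implementation.
    """
    if jitter_fns is None:
        # Documented behaviour: warn and leave the values unchanged.
        warnings.warn("No jitter functions provided; values are not jittered.")
        return position

    jittered = dict(position)  # shallow copy; keys without a function stay untouched
    # Assumption: one independent subkey per jittered position key.
    subkeys = jax.random.split(key, len(jitter_fns))
    for subkey, (name, fn) in zip(subkeys, jitter_fns.items()):
        jittered[name] = fn(subkey, position[name])
    return jittered


def uniform_jitter(key, val):
    # Hypothetical jittering function: add uniform noise on [-1, 1).
    return val + jax.random.uniform(key, val.shape, val.dtype, minval=-1.0, maxval=1.0)


# Usage: stand-in initial values for 4 chains.
position = {"mu": jnp.ones((4, 1))}
jittered = apply_jitter(jax.random.PRNGKey(0), position, {"mu": uniform_jitter})
```

Splitting the key per position key keeps the random streams of different parameters independent, while a single input key still makes the whole jittering step reproducible.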
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> import tensorflow_probability.substrates.jax.distributions as tfd
>>> import liesel.model as lsl
>>> import liesel.goose as gs
>>> key = jax.random.PRNGKey(42)
In this example, we show how to use the method EngineBuilder.set_jitter_fns() to apply jittering to the initial values of each chain. First, we sample 500 data points from a normal distribution with mean 2.0 and standard deviation 1.0.
>>> n = 500
>>> true_mu = 2.0
>>> true_sigma = 1.0
>>> x_vec = tfd.Normal(loc=true_mu, scale=true_sigma).sample((n, ), key)
Then, we define the distribution we want to sample from, which is parametrized by a single parameter mu.
>>> mu = lsl.Var.new_param(1.0, name="mu")
>>> x_dist = lsl.Dist(tfd.Normal, loc=mu, scale=true_sigma)
>>> x = lsl.Var(x_vec, distribution=x_dist, name="x")
Now, we can create the model.
>>> model = lsl.Model([x])
Finally, we set up the sampling engine with EngineBuilder. We use 4 parallel chains and sample our variable with a NUTSKernel.
>>> builder = gs.EngineBuilder(seed=1337, num_chains=4)
>>> builder.set_model(gs.LieselInterface(model))
>>> builder.set_initial_values(model.state)
>>> builder.add_kernel(gs.NUTSKernel(["mu"]))
A jittering function takes a key and a value as input and applies random jittering to the value using the key. In this case, we add uniform noise with a minimum of -1.0 and a maximum of 1.0. Note that val has shape (4, 1), where the first dimension corresponds to the number of chains.
>>> def jitter_fn(key, val):
...     jitter = jax.random.uniform(key, val.shape, val.dtype, minval=-1.0, maxval=1.0)
...     return val + jitter
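To see what this jittering function does, one can call it directly on stand-in initial values of shape (4, 1), without any Liesel objects. This minimal check assumes that the initial value of mu is 1.0 in every chain. `init`, `offsets`, and `again` are illustrative names, not part of the API.

```python
import jax
import jax.numpy as jnp


def jitter_fn(key, val):
    # Same rule as above: add uniform noise on [-1, 1) drawn with the given key.
    jitter = jax.random.uniform(key, val.shape, val.dtype, minval=-1.0, maxval=1.0)
    return val + jitter


init = jnp.full((4, 1), 1.0)  # assumed stand-in for mu's initial value in 4 chains
key = jax.random.PRNGKey(42)

jittered = jitter_fn(key, init)
# A single call draws a separate offset for each chain (each row).
offsets = (jittered - init)[:, 0]
# Reusing the same key reproduces exactly the same jittered values.
again = jitter_fn(key, init)
```

Because the noise is drawn for the whole array at once, every chain starts from a different point, and for a fixed key those starting points are reproducible.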
The method takes as input a dictionary where a jittering function is assigned to each position key.
>>> builder.set_jitter_fns({"mu": jitter_fn})
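When a model has several position keys, the dictionary maps each key to its own function. The sketch below assumes a hypothetical second parameter `sigma` that must stay positive. It is not part of the model above, so its function is applied directly to stand-in arrays rather than through the builder. `additive_jitter` and `multiplicative_jitter` are illustrative names.

```python
import jax
import jax.numpy as jnp


def additive_jitter(key, val):
    # Uniform noise on [-1, 1), suitable for an unconstrained parameter like mu.
    return val + jax.random.uniform(key, val.shape, val.dtype, minval=-1.0, maxval=1.0)


def multiplicative_jitter(key, val):
    # Scale by exp(u) with u ~ U[-0.1, 0.1), so a positive value stays positive.
    u = jax.random.uniform(key, val.shape, val.dtype, minval=-0.1, maxval=0.1)
    return val * jnp.exp(u)


# Hypothetical position keys: "sigma" is NOT part of the model above.
jitter_fns = {"mu": additive_jitter, "sigma": multiplicative_jitter}

# Check the functions on stand-in initial values for 4 chains.
key_mu, key_sigma = jax.random.split(jax.random.PRNGKey(0))
init = {"mu": jnp.full((4, 1), 1.0), "sigma": jnp.full((4, 1), 1.0)}
jittered = {
    "mu": jitter_fns["mu"](key_mu, init["mu"]),
    "sigma": jitter_fns["sigma"](key_sigma, init["sigma"]),
}
```

For a model that actually contains both variables, the same dictionary would be passed as `builder.set_jitter_fns(jitter_fns)`. Jittering a positive parameter multiplicatively avoids proposing invalid starting values such as a negative scale.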