List of additional position keys that should be tracked.
If a position key is tracked, the corresponding element of the model state is saved and included in the
SamplingResults and in the posterior samples obtained from it.
By default, only position keys associated with an MCMC kernel are tracked. You can easily add additional position keys by appending to this list.
For this example, we import the distributions module of tensorflow_probability as follows:
>>> import tensorflow_probability.substrates.jax.distributions as tfd
Consider the following simple model, in which we parameterize a normal distribution by the logarithm of its scale and apply the exponential function to obtain the actual scale:
>>> log_scale = lsl.param(0.0, name="log_scale")
>>> scale = lsl.Calc(jnp.exp, log_scale, _name="scale")
>>> dist = lsl.Dist(tfd.Normal, loc=0.0, scale=scale)
>>> y = lsl.obs(jnp.array([1.0, 2.0, 3.0]), dist, name="y")
>>> model = lsl.GraphBuilder().add(y).build_model()
Now we might want to set up an engine builder with a NUTS kernel for the parameter "log_scale":
>>> builder = gs.EngineBuilder(seed=1, num_chains=4)
>>> builder.set_model(gs.LieselInterface(model))
>>> builder.add_kernel(gs.NUTSKernel(["log_scale"]))
By default, only the position key "log_scale" is tracked and will be included in the results. Now, if you also want the value of "scale" to be included, you can add it to the list of included position keys:
>>> builder.position_keys.append("scale")
>>> builder.position_keys
['scale']
Beware, however, that including many intermediate position keys can lead to large results. In some cases it may be preferable to keep the tracked positions to a minimum and recompute the intermediate values from the posterior samples.
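As a rough illustration of the recomputation route, the sketch below applies the model's transformation, scale = exp(log_scale), to stored draws of "log_scale" after sampling. It is a minimal, library-free sketch: the dictionary `posterior_samples`, its `[chain][iteration]` layout, and the helper `recompute_scale` are hypothetical stand-ins for illustration, not part of the Liesel API.

```python
import math

# Hypothetical stand-in for posterior draws of "log_scale", laid out as
# [chain][iteration]. Real draws would come from the sampling results.
posterior_samples = {
    "log_scale": [
        [0.0, 0.5, 1.0],
        [-0.5, 0.25, 0.75],
    ]
}


def recompute_scale(samples):
    """Apply scale = exp(log_scale) to every stored draw of "log_scale"."""
    return [[math.exp(draw) for draw in chain] for chain in samples["log_scale"]]


# Same [chain][iteration] layout as the input, now on the scale level.
scale_samples = recompute_scale(posterior_samples)
```

If the samples are stored as arrays, the same step reduces to a single vectorized call such as jnp.exp on the "log_scale" array, so "scale" does not need to be tracked during sampling.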