Gibbs Sampling
This tutorial builds on the linear regression tutorial and shows how to sample a model parameter with a Gibbs kernel.
Because it continues the previous tutorials, we use the same model and data as before.
Data and imports
import jax
import jax.numpy as jnp
import numpy as np
# We use distributions and bijectors from tensorflow probability
import tensorflow_probability.substrates.jax.distributions as tfd
import tensorflow_probability.substrates.jax.bijectors as tfb
import liesel.goose as gs
import liesel.model as lsl
import matplotlib.pyplot as plt
# Generate data
rng = np.random.default_rng(42)
# sample size and true parameters
n = 500
true_beta = np.array([1.0, 2.0])
true_sigma = 1.0
# data-generating process
x0 = rng.uniform(size=n)
X_mat = np.column_stack([np.ones(n), x0])
eps = rng.normal(scale=true_sigma, size=n)
y_vec = X_mat @ true_beta + eps
# define beta
beta_prior = lsl.Dist(tfd.Normal, loc=0.0, scale=100.0)
beta = lsl.Var.new_param(
    value=jnp.array([0.0, 0.0]), distribution=beta_prior, name="beta"
)
# define the variance and the scale
a = lsl.Var.new_param(0.01, name="a")
b = lsl.Var.new_param(0.01, name="b")
sigma_sq_prior = lsl.Dist(tfd.InverseGamma, concentration=a, scale=b)
sigma_sq = lsl.Var.new_param(value=1.0, distribution=sigma_sq_prior, name="sigma_sq")
# Define sigma as a transformation of sigma_sq for the likelihood
sigma = lsl.Var.new_calc(jnp.sqrt, sigma_sq, name="sigma")
# Define the design matrix and the linear predictor
X = lsl.Var.new_obs(X_mat, name="X")
mu = lsl.Var.new_calc(jnp.dot, X, beta, name="mu")
# Build response
y_dist = lsl.Dist(tfd.Normal, loc=mu, scale=sigma)
y = lsl.Var.new_obs(y_vec, distribution=y_dist, name="y")
# Build the model and plot the model graph
model = lsl.Model([y])
lsl.plot_vars(model)

MCMC inference
Using a Gibbs kernel
This time, we want to sample the previously fixed sigma_sq with a Gibbs
sampler. Using a Gibbs kernel takes a bit more work, because Goose
does not derive the full conditional from the model graph automatically.
The user therefore has to provide a function that draws from the full
conditional. This function must accept a PRNG state and a model state
as arguments, and it must return a dictionary that maps the node name to
the new node value. A single Gibbs kernel can also update several
parameters at once if the function returns a dictionary with two or more
entries. To retrieve the relevant node values from the model_state, we
use the extract_position() method of the LieselInterface.
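For reference, the full conditional that the function below draws from can be derived from the model defined above. This is a sketch under the assumptions that a and b are fixed hyperparameters, that the inverse gamma prior is parameterized by concentration a and scale b, and that \mu_i denotes the i-th element of the linear predictor mu:

```latex
\begin{aligned}
p(\sigma^2 \mid y, \beta)
  &\propto p(y \mid \beta, \sigma^2)\, p(\sigma^2) \\
  &\propto (\sigma^2)^{-n/2}
     \exp\!\left(-\frac{1}{2\sigma^2} \sum_{i=1}^{n} (y_i - \mu_i)^2\right)
     \cdot (\sigma^2)^{-(a+1)} \exp\!\left(-\frac{b}{\sigma^2}\right) \\
  &= (\sigma^2)^{-(a + n/2 + 1)}
     \exp\!\left(-\frac{1}{\sigma^2}
       \left[\, b + \frac{1}{2} \sum_{i=1}^{n} (y_i - \mu_i)^2 \right]\right),
\end{aligned}
\qquad\Longrightarrow\qquad
\sigma^2 \mid y, \beta \sim \mathrm{IG}\!\left(a + \frac{n}{2},\;
  b + \frac{1}{2} \sum_{i=1}^{n} (y_i - \mu_i)^2\right).
```

The two parameters of this distribution correspond to a_gibbs and b_gibbs in the code. A draw can be obtained from a standard gamma variate: if G follows a Gamma(a_gibbs, 1) distribution, then b_gibbs / G follows an IG(a_gibbs, b_gibbs) distribution.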
def draw_sigma_sq(prng_key, model_state):
    # `interface` is the gs.LieselInterface created below, before sampling starts
    # extract relevant values from model state
    pos = interface.extract_position(
        position_keys=["y", "mu", "sigma_sq", "a", "b"],
        model_state=model_state
    )
    # calculate relevant intermediate quantities
    n = len(pos["y"])
    resid = pos["y"] - pos["mu"]
    a_gibbs = pos["a"] + n / 2
    b_gibbs = pos["b"] + jnp.sum(resid**2) / 2
    # draw new value from full conditional:
    # b_gibbs / Gamma(a_gibbs, 1) is an InverseGamma(a_gibbs, b_gibbs) draw
    draw = b_gibbs / jax.random.gamma(prng_key, a_gibbs)
    # return key-value pair of variable name and new value
    return {"sigma_sq": draw}
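The draw in this function relies on the fact that dividing the scale by a Gamma(concentration, 1) variate gives an inverse gamma variate. The following is a minimal sketch that checks this numerically with the Python standard library only, without Liesel or JAX. The helper names draw_inverse_gamma and sigma_sq_full_conditional and the simulated residuals are hypothetical and serve only as an illustration:

```python
import random


def draw_inverse_gamma(rng, concentration, scale):
    """Draw from InverseGamma(concentration, scale) via a Gamma(concentration, 1) draw."""
    gamma_draw = rng.gammavariate(concentration, 1.0)  # shape=concentration, scale=1
    return scale / gamma_draw


def sigma_sq_full_conditional(resid, a, b):
    """Return (concentration, scale) of the full conditional of sigma_sq."""
    n = len(resid)
    a_gibbs = a + n / 2
    b_gibbs = b + sum(r**2 for r in resid) / 2
    return a_gibbs, b_gibbs


rng = random.Random(0)

# Hypothetical residuals y - mu, simulated with a true variance of 1.0
resid = [rng.gauss(0.0, 1.0) for _ in range(500)]

a_gibbs, b_gibbs = sigma_sq_full_conditional(resid, a=0.01, b=0.01)
draws = [draw_inverse_gamma(rng, a_gibbs, b_gibbs) for _ in range(5000)]

# The mean of InverseGamma(alpha, beta) is beta / (alpha - 1) for alpha > 1
sample_mean = sum(draws) / len(draws)
exact_mean = b_gibbs / (a_gibbs - 1)
print(sample_mean, exact_mean)
```

The sample mean of the draws should lie close to the analytical mean of the full conditional, which in turn is close to the mean squared residual when n is large.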
After defining the transition function for the Gibbs kernel, we can build our engine.
interface = gs.LieselInterface(model)
builder = gs.EngineBuilder(seed=1338, num_chains=4)
builder.set_model(interface)
builder.set_initial_values(model.state)
builder.add_kernel(gs.NUTSKernel(["beta"]))
builder.add_kernel(gs.GibbsKernel(["sigma_sq"], draw_sigma_sq)) # add gibbs sampler
builder.set_duration(warmup_duration=1000, posterior_duration=1000)
engine = builder.build()
engine.sample_all_epochs()
Finally, we can take a look at our results:
results = engine.get_results()
summary = gs.Summary(results)
summary
Parameter summary:
| parameter | index | kernel | mean | sd | q_0.05 | q_0.5 | q_0.95 | sample_size | ess_bulk | ess_tail | rhat |
|---|---|---|---|---|---|---|---|---|---|---|---|
| beta | (0,) | kernel_00 | 0.984 | 0.091 | 0.837 | 0.984 | 1.135 | 4000 | 1101.340 | 1252.702 | 1.005 |
| beta | (1,) | kernel_00 | 1.911 | 0.161 | 1.649 | 1.912 | 2.171 | 4000 | 1130.029 | 1237.464 | 1.006 |
| sigma_sq | () | kernel_01 | 1.044 | 0.067 | 0.939 | 1.040 | 1.161 | 4000 | 3859.015 | 3732.136 | 1.000 |
Error summary:
| kernel | error_code | error_msg | phase | count | sample_size | sample_size_total | relative |
|---|---|---|---|---|---|---|---|
| kernel_00 | 1 | divergent transition | warmup | 62 | 4000 | 4000 | 0.016 |
| kernel_00 | 1 | divergent transition | posterior | 0 | 4000 | 4000 | 0.000 |
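The sample sizes and the relative error frequency in these tables are consistent with the engine settings used above. The following sketch checks the arithmetic. It assumes that sample_size counts draws pooled over all chains and that the relative column is the count divided by sample_size_total; both are inferences from the table values, not definitions taken from the Goose documentation:

```python
# Sampler settings taken from the engine setup in this tutorial
num_chains = 4
warmup_duration = 1000
posterior_duration = 1000

# Draws pooled over all chains (assumed meaning of sample_size)
posterior_sample_size = num_chains * posterior_duration
warmup_sample_size = num_chains * warmup_duration

# Divergent transitions reported for kernel_00 during warmup
divergent_warmup = 62

# Assumed definition of the "relative" column
relative_warmup = divergent_warmup / warmup_sample_size

print(posterior_sample_size)
print(f"{relative_warmup:.4f}")
```

Under these assumptions, all divergent transitions occurred during warmup, and none of the 4000 posterior draws is affected.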
We can also plot the results:
g = gs.plot_trace(results)

gs.plot_param(results, param="sigma_sq", param_index=0)

With that we end the tutorial on Gibbs sampling.