Gibbs Sampling
This tutorial extends the linear regression tutorial and shows how to sample a model parameter using a Gibbs kernel.
Because it continues the previous tutorials, it uses the same model and data.
Data and imports
import jax
import jax.numpy as jnp
import numpy as np
# We use distributions and bijectors from tensorflow probability
import tensorflow_probability.substrates.jax.distributions as tfd
import tensorflow_probability.substrates.jax.bijectors as tfb
import liesel.goose as gs
import liesel.model as lsl
import matplotlib.pyplot as plt
# Generate data
rng = np.random.default_rng(42)
# sample size and true parameters
n = 500
true_beta = np.array([1.0, 2.0])
true_sigma = 1.0
# data-generating process
x0 = rng.uniform(size=n)
X_mat = np.column_stack([np.ones(n), x0])
eps = rng.normal(scale=true_sigma, size=n)
y_vec = X_mat @ true_beta + eps
# define beta
beta_prior = lsl.Dist(tfd.Normal, loc=0.0, scale=100.0)
beta = lsl.Var.new_param(value=jnp.array([0.0, 0.0]), dist=beta_prior, name="beta")
# define the variance sigma_sq with an inverse-gamma prior (hyperparameters a and b)
a = lsl.Var.new_param(0.01, name="a")
b = lsl.Var.new_param(0.01, name="b")
sigma_sq_prior = lsl.Dist(tfd.InverseGamma, concentration=a, scale=b)
sigma_sq = lsl.Var.new_param(value=1.0, dist=sigma_sq_prior, name="sigma_sq")
# Define sigma as a transformation of sigma_sq for the likelihood
sigma = lsl.Var.new_calc(jnp.sqrt, sigma_sq, name="sigma")
# Define the design matrix and the linear predictor mu = X @ beta
X = lsl.Var.new_obs(X_mat, name="X")
mu = lsl.Var.new_calc(jnp.dot, X, beta, name="mu")
# Build response
y_dist = lsl.Dist(tfd.Normal, loc=mu, scale=sigma)
y = lsl.Var.new_obs(y_vec, dist=y_dist, name="y")
# Build and plot the model
model = lsl.Model([y])
model.plot()

MCMC inference
Using a Gibbs kernel
This time we want to sample the previously fixed sigma_sq with a Gibbs
sampler. Using a Gibbs kernel is a bit more complicated, because Goose
doesn’t automatically derive the full conditional from the model graph.
Hence, the user needs to provide a function to sample from the full
conditional. The function needs to accept a PRNG key and a model state
as arguments, and it needs to return a dictionary with the variable name
as the key and the new variable value as the value. We could also update
multiple parameters with one Gibbs kernel by returning a dictionary with
several entries.
For this normal-inverse-gamma model, the full conditional of \(\sigma^2\)
is again an inverse-gamma distribution. To retrieve the relevant values
from the model_state, we use Model.extract_position().
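The inverse-gamma form of the full conditional can be checked by multiplying the normal likelihood with the inverse-gamma prior. The following is a short derivation sketch, assuming the shape/scale parameterization of the inverse-gamma distribution with hyperparameters \(a\) and \(b\), and writing \(\mu_i\) for the i-th element of the linear predictor:

```latex
\begin{aligned}
p(\sigma^2 \mid y, \beta)
  &\propto (\sigma^2)^{-n/2}
     \exp\left(-\frac{1}{2\sigma^2} \sum_{i=1}^{n} (y_i - \mu_i)^2\right)
     \cdot (\sigma^2)^{-a-1} \exp\left(-\frac{b}{\sigma^2}\right) \\
  &= (\sigma^2)^{-(a + n/2) - 1}
     \exp\left(-\frac{1}{\sigma^2}
       \left[b + \frac{1}{2} \sum_{i=1}^{n} (y_i - \mu_i)^2\right]\right),
\end{aligned}
\qquad\text{so}\qquad
\sigma^2 \mid y, \beta \sim \mathrm{IG}\left(a + \frac{n}{2},\;
  b + \frac{1}{2} \sum_{i=1}^{n} (y_i - \mu_i)^2\right).
```

The two updated parameters correspond to a_gibbs and b_gibbs in the transition function.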
def draw_sigma_sq(prng_key, model_state):
    # extract relevant values from model state
    pos = model.extract_position(
        position_keys=["y", "mu", "sigma_sq", "a", "b"], model_state=model_state
    )
    # calculate relevant intermediate quantities
    n = len(pos["y"])
    resid = pos["y"] - pos["mu"]
    a_gibbs = pos["a"] + n / 2
    b_gibbs = pos["b"] + jnp.sum(resid**2) / 2
    # draw new value from full conditional
    draw = b_gibbs / jax.random.gamma(prng_key, a_gibbs)
    # return key-value pair of variable name and new value
    return {"sigma_sq": draw}
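The draw above relies on the fact that if G follows a gamma distribution with shape a and rate 1, then b / G follows an inverse-gamma distribution with shape a and scale b. The following standalone sketch illustrates this outside of Liesel. It is only an illustration under stated assumptions: NumPy's Generator.gamma stands in for jax.random.gamma, the data are regenerated as at the beginning of the tutorial, and beta is fixed at its true value purely so that residuals can be computed.

```python
import numpy as np

# Same data-generating process as at the beginning of the tutorial.
rng = np.random.default_rng(42)
n = 500
true_beta = np.array([1.0, 2.0])
x0 = rng.uniform(size=n)
X_mat = np.column_stack([np.ones(n), x0])
y_vec = X_mat @ true_beta + rng.normal(scale=1.0, size=n)

# Assumption for illustration only: condition on beta = true_beta.
resid = y_vec - X_mat @ true_beta

# Parameters of the inverse-gamma full conditional (prior: a = b = 0.01).
a_gibbs = 0.01 + n / 2
b_gibbs = 0.01 + np.sum(resid**2) / 2

# If G ~ Gamma(shape=a, rate=1), then b / G ~ InverseGamma(a, b).
draw_rng = np.random.default_rng(0)
draws = b_gibbs / draw_rng.gamma(shape=a_gibbs, scale=1.0, size=100_000)

# The inverse-gamma mean is b / (a - 1) for a > 1.
analytic_mean = b_gibbs / (a_gibbs - 1)
print("Monte Carlo mean:", draws.mean())
print("Analytic mean:   ", analytic_mean)
```

The Monte Carlo mean of the draws should agree closely with the analytic mean, which suggests that dividing b_gibbs by a standard gamma draw is a valid way to sample from the full conditional.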
The regression coefficients beta are still sampled with NUTS. For
sigma_sq, we attach an MCMCSpec with
with_transition_fn(), which turns our custom
transition function into a kernel factory that LieselMCMC can
use. The Gibbs kernel itself does not need adaptation, but the NUTS
kernel for beta does, so we still run an adaptation phase before
drawing posterior samples.
beta.inference = gs.MCMCSpec(gs.NUTSKernel)
sigma_sq.inference = gs.MCMCSpec(gs.GibbsKernel.with_transition_fn(draw_sigma_sq))
results = gs.LieselMCMC(model).run_for_epochs(
    seed=1, num_chains=4, adaptation=1000, posterior=1000
)
liesel.goose.mcmc_spec - WARNING - No inference specification defined for Var(name="b"). If you do not add a kernel for this parameter manually to an EngineBuilder, it will not be sampled.
liesel.goose.mcmc_spec - WARNING - No inference specification defined for Var(name="a"). If you do not add a kernel for this parameter manually to an EngineBuilder, it will not be sampled.
liesel.goose.builder - WARNING - No jitter functions provided for position keys 'sigma_sq', 'beta'. The initial values for these keys won't be jittered
liesel.goose.engine - INFO - Initializing kernels...
liesel.goose.engine - INFO - Done
liesel.goose.engine - INFO - Starting epoch: FAST_ADAPTATION, 100 transitions, 25 jitted together
100%|██████████████████████████████████████████| 4/4 [00:01<00:00, 2.42chunk/s]
liesel.goose.engine - WARNING - Errors per chain for kernel_01: 4, 5, 4, 2 / 100 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 25 transitions, 25 jitted together
100%|█████████████████████████████████████████| 1/1 [00:00<00:00, 928.35chunk/s]
liesel.goose.engine - WARNING - Errors per chain for kernel_01: 1, 1, 1, 1 / 25 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 50 transitions, 25 jitted together
100%|████████████████████████████████████████| 2/2 [00:00<00:00, 1821.23chunk/s]
liesel.goose.engine - WARNING - Errors per chain for kernel_01: 1, 2, 1, 1 / 50 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 100 transitions, 25 jitted together
100%|████████████████████████████████████████| 4/4 [00:00<00:00, 1747.44chunk/s]
liesel.goose.engine - WARNING - Errors per chain for kernel_01: 2, 1, 2, 1 / 100 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 525 transitions, 25 jitted together
100%|███████████████████████████████████████| 21/21 [00:00<00:00, 398.16chunk/s]
liesel.goose.engine - WARNING - Errors per chain for kernel_01: 3, 2, 2, 4 / 525 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: FAST_ADAPTATION, 200 transitions, 25 jitted together
100%|████████████████████████████████████████| 8/8 [00:00<00:00, 1529.02chunk/s]
liesel.goose.engine - WARNING - Errors per chain for kernel_01: 4, 3, 1, 2 / 200 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Finished warmup
liesel.goose.engine - INFO - Starting epoch: POSTERIOR, 1000 transitions, 25 jitted together
100%|███████████████████████████████████████| 40/40 [00:00<00:00, 390.75chunk/s]
liesel.goose.engine - INFO - Finished epoch
Finally, we can take a look at our results.
summary = gs.Summary(results)
summary
Parameter summary:
| parameter | index | kernel | mean | sd | q_0.05 | q_0.5 | q_0.95 | sample_size | ess_bulk | ess_tail | rhat |
|---|---|---|---|---|---|---|---|---|---|---|---|
| beta | (0,) | kernel_01 | 0.983 | 0.093 | 0.832 | 0.983 | 1.133 | 4000 | 1063.525 | 1171.567 | 1.005 |
| beta | (1,) | kernel_01 | 1.912 | 0.159 | 1.652 | 1.912 | 2.172 | 4000 | 1009.363 | 1118.012 | 1.007 |
| sigma_sq | () | kernel_00 | 1.043 | 0.066 | 0.939 | 1.041 | 1.154 | 4000 | 4090.345 | 4014.613 | 1.000 |
Acceptance probabilities:
| kernel | positions | phase | acceptance_probability | position_moved |
|---|---|---|---|---|
| kernel_00 | sigma_sq | posterior | 1.000 | 1.000 |
| kernel_00 | sigma_sq | warmup | 1.000 | 1.000 |
| kernel_01 | beta | posterior | 0.887 | NaN |
| kernel_01 | beta | warmup | 0.791 | NaN |
Error summary:
| kernel | positions | error_code | error_msg | phase | count | sample_size | sample_size_total | relative |
|---|---|---|---|---|---|---|---|---|
| kernel_01 | beta | 1 | divergent transition | warmup | 51 | 4000 | 4000 | 0.013 |
| kernel_01 | beta | 1 | divergent transition | posterior | 0 | 4000 | 4000 | 0.000 |
And plot them.
gs.plot_trace(results)
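As a rough plausibility check on the parameter summary, the posterior means can be compared with ordinary least squares estimates computed from the same simulated data. The sketch below is not part of the Liesel workflow. It assumes the data-generating code from the beginning of the tutorial and uses np.linalg.lstsq, and the names beta_hat and sigma_sq_hat are introduced here for illustration only.

```python
import numpy as np

# Regenerate the simulated data exactly as at the beginning of the tutorial.
rng = np.random.default_rng(42)
n = 500
true_beta = np.array([1.0, 2.0])
true_sigma = 1.0
x0 = rng.uniform(size=n)
X_mat = np.column_stack([np.ones(n), x0])
eps = rng.normal(scale=true_sigma, size=n)
y_vec = X_mat @ true_beta + eps

# Ordinary least squares estimate of the regression coefficients.
beta_hat, *_ = np.linalg.lstsq(X_mat, y_vec, rcond=None)

# Unbiased estimate of the residual variance.
resid = y_vec - X_mat @ beta_hat
sigma_sq_hat = resid @ resid / (n - X_mat.shape[1])

print("beta_hat:", beta_hat)
print("sigma_sq_hat:", sigma_sq_hat)
```

With priors as vague as the ones used here, these estimates would be expected to lie close to the posterior means of beta and sigma_sq reported in the summary.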
