Comparing samplers#

In this tutorial, we compare two different sampling schemes on the mcycle dataset, using a Gaussian location-scale regression model with one spline each for the mean and the standard deviation. The mcycle dataset is a “data frame giving a series of measurements of head acceleration in a simulated motorcycle accident, used to test crash helmets” (from its help page). It contains the following two variables:

  • times: time in milliseconds after impact

  • accel: head acceleration in g

We start off in R by loading the dataset and setting up the model with the rliesel::liesel() function.

library(MASS)
library(rliesel)
data(mcycle)
with(mcycle, plot(times, accel))

model <- liesel(
  response = mcycle$accel,
  distribution = "Normal",
  predictors = list(
    loc = predictor(~ s(times)),
    scale = predictor(~ s(times), inverse_link = "Exp")
  ),
  data = mcycle
)
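
The parameter names generated by rliesel reappear in all summaries below: loc_p0_beta is the parametric intercept of the mean, loc_np0_beta contains the spline coefficients, and loc_np0_tau2 is the corresponding smoothing parameter (and analogously for the scale with the scale_ prefix). To list them from a Python chunk, a quick loop suffices; this sketch assumes the usual reticulate setup, where the R model is available as r.model:

import liesel.model as lsl

model = r.model  # the model built with rliesel::liesel() above

# Print all parameter names, e.g. loc_np0_beta, loc_np0_tau2, loc_p0_beta, ...
for name, var in model.vars.items():
    if var.parameter:
        print(name)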

Metropolis-in-Gibbs#

First, we try a Metropolis-in-Gibbs sampling scheme with IWLS kernels for the regression coefficients (\(\boldsymbol{\beta}\)) and Gibbs kernels for the smoothing parameters (\(\tau^2\)) of the splines.
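
Goose assembles this scheme for us via lsl.dist_reg_mcmc() below. To make the alternating logic concrete first, here is a self-contained toy sketch of Metropolis-in-Gibbs for an iid normal model, written in plain NumPy. It only illustrates the idea, not Goose's implementation, and all names in it are made up:

# Toy Metropolis-in-Gibbs sampler for an iid Normal(mu, sigma2) model:
# a random-walk Metropolis step for mu, then an exact Gibbs draw for
# sigma2 from its inverse gamma full conditional.
import numpy as np

rng = np.random.default_rng(42)
y = rng.normal(loc=2.0, scale=1.5, size=100)  # synthetic data

def log_post_mu(mu, sigma2):
    # log-posterior of mu given sigma2 (flat prior on mu), up to a constant
    return -0.5 * np.sum((y - mu) ** 2) / sigma2

mu, sigma2 = 0.0, 1.0
a0, b0 = 1.0, 1.0  # inverse gamma prior IG(a0, b0) on sigma2
samples = []

for _ in range(5000):
    # Metropolis step for mu | sigma2 with a random-walk proposal
    prop = mu + 0.3 * rng.normal()
    if np.log(rng.uniform()) < log_post_mu(prop, sigma2) - log_post_mu(mu, sigma2):
        mu = prop

    # Gibbs step for sigma2 | mu: IG(a0 + n/2, b0 + 0.5 * sum((y - mu)^2))
    a = a0 + len(y) / 2
    b = b0 + 0.5 * np.sum((y - mu) ** 2)
    sigma2 = 1.0 / rng.gamma(a, 1.0 / b)  # inverse gamma via reciprocal gamma

    samples.append((mu, sigma2))

The IWLS kernels below replace the random-walk proposal with a Gaussian approximation based on iteratively weighted least squares, and the Gibbs kernels draw the smoothing parameters directly from their full conditionals. With Liesel, we do not have to implement any of this by hand: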

import liesel.model as lsl

model = r.model

builder = lsl.dist_reg_mcmc(model, seed=42, num_chains=4)
builder.set_duration(warmup_duration=5000, posterior_duration=1000)

engine = builder.build()
engine.sample_all_epochs()
liesel.goose.engine - INFO - Starting epoch: FAST_ADAPTATION, 75 transitions, 25 jitted together
liesel.goose.engine - WARNING - Errors per chain for kernel_01: 0, 1, 1, 0 / 75 transitions
liesel.goose.engine - WARNING - Errors per chain for kernel_02: 1, 1, 1, 1 / 75 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 25 transitions, 25 jitted together
liesel.goose.engine - WARNING - Errors per chain for kernel_01: 1, 0, 0, 1 / 25 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 50 transitions, 25 jitted together
liesel.goose.engine - WARNING - Errors per chain for kernel_01: 0, 0, 1, 0 / 50 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 100 transitions, 25 jitted together
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 200 transitions, 25 jitted together
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 400 transitions, 25 jitted together
liesel.goose.engine - WARNING - Errors per chain for kernel_01: 0, 0, 1, 1 / 400 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 800 transitions, 25 jitted together
liesel.goose.engine - WARNING - Errors per chain for kernel_01: 2, 0, 1, 1 / 800 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 3300 transitions, 25 jitted together
liesel.goose.engine - WARNING - Errors per chain for kernel_01: 1, 0, 1, 0 / 3300 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: FAST_ADAPTATION, 50 transitions, 25 jitted together
liesel.goose.engine - WARNING - Errors per chain for kernel_01: 0, 1, 0, 0 / 50 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Finished warmup
liesel.goose.engine - INFO - Starting epoch: POSTERIOR, 1000 transitions, 25 jitted together
liesel.goose.engine - INFO - Finished epoch

Judging from the summary below, the performance of the sampler could clearly be better, especially for the intercept of the mean (loc_p0_beta): its chain exhibits very strong autocorrelation, with a bulk effective sample size of about 11 out of 4000 draws and an R-hat of 1.3.

import liesel.goose as gs

results = engine.get_results()
gs.Summary(results)

Parameter summary:

| parameter | index | kernel | mean | sd | q_0.05 | q_0.5 | q_0.95 | sample_size | ess_bulk | ess_tail | rhat |
|---|---|---|---|---|---|---|---|---|---|---|---|
| loc_np0_beta | (0,) | kernel_04 | -124.212 | 251.580 | -529.753 | -138.365 | 308.680 | 4000 | 32.987 | 307.642 | 1.091 |
| loc_np0_beta | (1,) | kernel_04 | -1435.746 | 240.898 | -1851.249 | -1428.723 | -1052.497 | 4000 | 159.682 | 321.332 | 1.020 |
| loc_np0_beta | (2,) | kernel_04 | -713.389 | 172.243 | -1000.340 | -708.812 | -429.617 | 4000 | 143.301 | 717.462 | 1.029 |
| loc_np0_beta | (3,) | kernel_04 | -565.488 | 105.964 | -739.408 | -564.218 | -394.887 | 4000 | 339.420 | 714.099 | 1.022 |
| loc_np0_beta | (4,) | kernel_04 | -1127.937 | 93.543 | -1283.683 | -1125.710 | -977.021 | 4000 | 290.353 | 497.903 | 1.017 |
| loc_np0_beta | (5,) | kernel_04 | -58.789 | 32.544 | -111.688 | -59.484 | -5.553 | 4000 | 159.610 | 350.236 | 1.023 |
| loc_np0_beta | (6,) | kernel_04 | -213.567 | 21.002 | -247.189 | -213.971 | -179.385 | 4000 | 107.176 | 607.907 | 1.045 |
| loc_np0_beta | (7,) | kernel_04 | 114.820 | 70.576 | 14.816 | 108.480 | 236.032 | 4000 | 150.660 | 182.166 | 1.025 |
| loc_np0_beta | (8,) | kernel_04 | 30.100 | 18.331 | 4.179 | 28.688 | 60.796 | 4000 | 134.458 | 177.735 | 1.037 |
| loc_np0_tau2 | () | kernel_03 | 731229.875 | 558309.375 | 252463.623 | 588973.625 | 1663368.350 | 4000 | 1343.143 | 2789.861 | 1.002 |
| loc_p0_beta | (0,) | kernel_05 | -23.916 | 1.696 | -26.908 | -23.864 | -21.175 | 4000 | 11.016 | 17.064 | 1.301 |
| scale_np0_beta | (0,) | kernel_01 | 8.319 | 10.861 | -6.473 | 5.872 | 28.038 | 4000 | 27.849 | 162.492 | 1.132 |
| scale_np0_beta | (1,) | kernel_01 | -1.327 | 5.916 | -11.243 | -1.130 | 8.340 | 4000 | 187.760 | 414.387 | 1.014 |
| scale_np0_beta | (2,) | kernel_01 | -17.740 | 9.713 | -33.825 | -17.967 | -2.879 | 4000 | 22.427 | 118.458 | 1.150 |
| scale_np0_beta | (3,) | kernel_01 | 11.025 | 5.177 | 2.943 | 11.128 | 19.322 | 4000 | 36.072 | 197.788 | 1.095 |
| scale_np0_beta | (4,) | kernel_01 | 2.737 | 4.282 | -4.267 | 2.683 | 10.006 | 4000 | 52.878 | 270.644 | 1.083 |
| scale_np0_beta | (5,) | kernel_01 | 4.073 | 1.908 | 1.057 | 3.961 | 7.200 | 4000 | 22.005 | 89.694 | 1.132 |
| scale_np0_beta | (6,) | kernel_01 | -0.655 | 3.196 | -6.204 | -0.547 | 4.116 | 4000 | 19.208 | 117.648 | 1.162 |
| scale_np0_beta | (7,) | kernel_01 | -0.371 | 3.620 | -5.948 | -0.707 | 5.960 | 4000 | 95.354 | 139.741 | 1.033 |
| scale_np0_beta | (8,) | kernel_01 | -1.254 | 1.946 | -4.583 | -1.183 | 1.724 | 4000 | 21.570 | 133.197 | 1.149 |
| scale_np0_tau2 | () | kernel_00 | 141.122 | 164.236 | 11.544 | 88.833 | 435.435 | 4000 | 24.059 | 180.704 | 1.130 |
| scale_p0_beta | (0,) | kernel_02 | 2.762 | 0.070 | 2.652 | 2.761 | 2.878 | 4000 | 194.521 | 942.988 | 1.024 |

Error summary:

| kernel | error_code | error_msg | phase | count | relative |
|---|---|---|---|---|---|
| kernel_01 | 90 | nan acceptance prob | warmup | 14 | 0.001 |
| kernel_01 | 90 | nan acceptance prob | posterior | 0 | 0.000 |
| kernel_02 | 90 | nan acceptance prob | warmup | 4 | 0.000 |
| kernel_02 | 90 | nan acceptance prob | posterior | 0 | 0.000 |

fig = gs.plot_trace(results, "loc_p0_beta")
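
The autocorrelation can also be inspected directly. Goose provides further plotting helpers alongside plot_trace(); assuming your version includes gs.plot_cor() for autocorrelation plots:

fig = gs.plot_cor(results, "loc_p0_beta")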

fig = gs.plot_trace(results, "loc_np0_tau2")

fig = gs.plot_trace(results, "loc_np0_beta")

fig = gs.plot_trace(results, "scale_p0_beta")

fig = gs.plot_trace(results, "scale_np0_tau2")

fig = gs.plot_trace(results, "scale_np0_beta")

To confirm that the chains have converged to reasonable values, here is a plot of the estimated mean function:

summary = gs.Summary(results).to_dataframe().reset_index()
library(dplyr)
library(ggplot2)
library(reticulate)

summary <- py$summary

beta <- summary %>%
  filter(variable == "loc_np0_beta") %>%
  group_by(var_index) %>%
  summarize(mean = mean(mean)) %>%
  ungroup()

beta <- beta$mean
X <- py_to_r(model$vars["loc_np0_X"]$value)
f <- X %*% beta

beta0 <- summary %>%
  filter(variable == "loc_p0_beta") %>%
  group_by(var_index) %>%
  summarize(mean = mean(mean)) %>%
  ungroup()

beta0 <- beta0$mean

ggplot(data.frame(times = mcycle$times, mean = beta0 + f)) +
  geom_line(aes(times, mean), color = palette()[2], linewidth = 1) +
  geom_point(aes(times, accel), data = mcycle) +
  ggtitle("Estimated mean function") +
  theme_minimal()
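
The same quantities can also be computed without the detour through R. This is a sketch that assumes the posterior draws are available as a dictionary keyed by parameter name via results.get_posterior_samples(); check the method name against your Goose version:

import numpy as np

samples = results.get_posterior_samples()

# Posterior means: average over the chain and iteration axes
beta = np.asarray(samples["loc_np0_beta"]).mean(axis=(0, 1))
beta0 = np.asarray(samples["loc_p0_beta"]).mean(axis=(0, 1))

X = np.asarray(model.vars["loc_np0_X"].value)
f = X @ beta + beta0  # estimated mean function at the observed time points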

NUTS sampler#

As an alternative, we try a NUTS kernel which samples all model parameters (regression coefficients and smoothing parameters) in one block. To do so, we first need to log-transform the smoothing parameters. This is the model graph before the transformation:

lsl.plot_vars(model)

Before we can transform the smoothing parameters, we need to pop the nodes and variables out of the existing model. We then add the response to a new lsl.GraphBuilder, transform the two smoothing parameters with the GraphBuilder.transform() method and the exponential bijector from TensorFlow Probability, and rebuild the model. The new graph contains two additional nodes for the transformed (log-)smoothing parameters.

import tensorflow_probability.substrates.jax.bijectors as tfb

nodes, _vars = model.pop_nodes_and_vars()

gb = lsl.GraphBuilder()
gb.add(_vars["response"])
GraphBuilder(0 nodes, 1 vars)
_ = gb.transform(_vars["loc_np0_tau2"], tfb.Exp)
_ = gb.transform(_vars["scale_np0_tau2"], tfb.Exp)
model = gb.build_model()
lsl.plot_vars(model)
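
What did the transformation change? The NUTS kernel now operates on u = log(τ²), and the exponential bijector supplies both the back-transformation to the original scale and the Jacobian correction of the log-density. A small standalone check, using only TensorFlow Probability and nothing model-specific:

import tensorflow_probability.substrates.jax.bijectors as tfb

bij = tfb.Exp()
u = 2.0  # unconstrained value, u = log(tau2)
tau2 = bij.forward(u)  # back-transformation: exp(2.0)
# log|d exp(u)/du| = u is the correction added to the log-density
ldj = bij.forward_log_det_jacobian(u, event_ndims=0)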

Now we can set up the NUTS sampler, which is straightforward because we are using only one kernel.

parameters = [name for name, var in model.vars.items() if var.parameter]

builder = gs.EngineBuilder(seed=42, num_chains=4)

builder.set_model(lsl.GooseModel(model))
builder.add_kernel(gs.NUTSKernel(parameters))
builder.set_initial_values(model.state)

builder.set_duration(warmup_duration=5000, posterior_duration=1000)

engine = builder.build()
engine.sample_all_epochs()
liesel.goose.engine - INFO - Starting epoch: FAST_ADAPTATION, 75 transitions, 25 jitted together
liesel.goose.engine - WARNING - Errors per chain for kernel_00: 45, 31, 51, 61 / 75 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 25 transitions, 25 jitted together
liesel.goose.engine - WARNING - Errors per chain for kernel_00: 9, 21, 13, 8 / 25 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 50 transitions, 25 jitted together
liesel.goose.engine - WARNING - Errors per chain for kernel_00: 40, 50, 48, 46 / 50 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 100 transitions, 25 jitted together
liesel.goose.engine - WARNING - Errors per chain for kernel_00: 93, 93, 91, 91 / 100 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 200 transitions, 25 jitted together
liesel.goose.engine - WARNING - Errors per chain for kernel_00: 196, 196, 194, 193 / 200 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 400 transitions, 25 jitted together
liesel.goose.engine - WARNING - Errors per chain for kernel_00: 384, 373, 385, 388 / 400 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 800 transitions, 25 jitted together
liesel.goose.engine - WARNING - Errors per chain for kernel_00: 780, 750, 759, 780 / 800 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 3300 transitions, 25 jitted together
liesel.goose.engine - WARNING - Errors per chain for kernel_00: 3178, 3143, 3233, 3218 / 3300 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: FAST_ADAPTATION, 50 transitions, 25 jitted together
liesel.goose.engine - WARNING - Errors per chain for kernel_00: 47, 47, 50, 47 / 50 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Finished warmup
liesel.goose.engine - INFO - Starting epoch: POSTERIOR, 1000 transitions, 25 jitted together
liesel.goose.engine - WARNING - Errors per chain for kernel_00: 996, 997, 997, 999 / 1000 transitions
liesel.goose.engine - INFO - Finished epoch

The results are mixed. On the one hand, the NUTS sampler performs much better on the intercepts (for both the mean and the standard deviation); on the other hand, the Metropolis-in-Gibbs sampler with the IWLS kernels works better for the spline coefficients, whose R-hat values are far from 1 here. The error summary below also shows that almost all posterior transitions hit the maximum tree depth.
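
A possible remedy would be to allow deeper trees when adding the kernel. Whether NUTSKernel accepts a max_treedepth argument depends on your Goose version, so treat the following as an assumption to verify against the documentation:

# Hypothetical: allow deeper NUTS trees (verify that NUTSKernel accepts
# a max_treedepth argument in your Goose version before relying on this)
builder.add_kernel(gs.NUTSKernel(parameters, max_treedepth=12))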

results = engine.get_results()
gs.Summary(results)

Parameter summary:

| parameter | index | kernel | mean | sd | q_0.05 | q_0.5 | q_0.95 | sample_size | ess_bulk | ess_tail | rhat |
|---|---|---|---|---|---|---|---|---|---|---|---|
| loc_np0_beta | (0,) | kernel_00 | 73.404 | 97.621 | -10.697 | 30.962 | 245.625 | 4000 | 4.510 | 16.811 | 3.303 |
| loc_np0_beta | (1,) | kernel_00 | -48.960 | 63.900 | -153.584 | -37.778 | 27.514 | 4000 | 4.356 | 10.920 | 3.838 |
| loc_np0_beta | (2,) | kernel_00 | -194.597 | 129.507 | -380.861 | -202.929 | -16.552 | 4000 | 4.477 | 11.572 | 3.309 |
| loc_np0_beta | (3,) | kernel_00 | -20.095 | 113.805 | -192.847 | -10.972 | 216.406 | 4000 | 4.874 | 11.445 | 2.427 |
| loc_np0_beta | (4,) | kernel_00 | -643.818 | 95.355 | -817.823 | -639.419 | -508.709 | 4000 | 5.810 | 12.467 | 1.843 |
| loc_np0_beta | (5,) | kernel_00 | -87.386 | 47.478 | -169.050 | -74.672 | -23.117 | 4000 | 5.690 | 34.393 | 1.871 |
| loc_np0_beta | (6,) | kernel_00 | -119.486 | 22.295 | -157.998 | -116.731 | -85.191 | 4000 | 8.723 | 20.529 | 1.385 |
| loc_np0_beta | (7,) | kernel_00 | 23.230 | 41.924 | -32.190 | 20.523 | 111.299 | 4000 | 4.672 | 11.520 | 2.773 |
| loc_np0_beta | (8,) | kernel_00 | 17.792 | 10.230 | 4.953 | 15.089 | 39.484 | 4000 | 5.319 | 12.053 | 2.141 |
| loc_np0_tau2_transformed | () | kernel_00 | 11.193 | 0.619 | 10.228 | 11.166 | 12.263 | 4000 | 19.113 | 53.685 | 1.135 |
| loc_p0_beta | (0,) | kernel_00 | -17.161 | 3.709 | -22.990 | -17.380 | -10.733 | 4000 | 8.171 | 29.605 | 1.438 |
| scale_np0_beta | (0,) | kernel_00 | -9.185 | 2.563 | -12.943 | -9.488 | -5.041 | 4000 | 4.603 | 11.384 | 2.882 |
| scale_np0_beta | (1,) | kernel_00 | 4.982 | 7.403 | -7.035 | 4.680 | 17.179 | 4000 | 22.607 | 129.759 | 1.118 |
| scale_np0_beta | (2,) | kernel_00 | -15.595 | 7.542 | -28.746 | -15.147 | -3.759 | 4000 | 21.209 | 281.072 | 1.121 |
| scale_np0_beta | (3,) | kernel_00 | 15.539 | 5.549 | 6.477 | 15.549 | 24.623 | 4000 | 33.278 | 267.248 | 1.078 |
| scale_np0_beta | (4,) | kernel_00 | 7.314 | 4.333 | 0.200 | 7.303 | 14.505 | 4000 | 36.979 | 259.509 | 1.074 |
| scale_np0_beta | (5,) | kernel_00 | 6.554 | 2.209 | 2.885 | 6.585 | 10.175 | 4000 | 19.456 | 98.332 | 1.138 |
| scale_np0_beta | (6,) | kernel_00 | 0.914 | 2.236 | -2.933 | 1.003 | 4.415 | 4000 | 31.984 | 1112.830 | 1.082 |
| scale_np0_beta | (7,) | kernel_00 | 2.397 | 4.373 | -4.488 | 2.261 | 9.769 | 4000 | 42.758 | 1938.438 | 1.062 |
| scale_np0_beta | (8,) | kernel_00 | -0.319 | 1.254 | -2.496 | -0.259 | 1.639 | 4000 | 70.290 | 1509.826 | 1.040 |
| scale_np0_tau2_transformed | () | kernel_00 | 4.703 | 0.779 | 3.404 | 4.708 | 5.981 | 4000 | 32.376 | 45.661 | 1.080 |
| scale_p0_beta | (0,) | kernel_00 | 3.004 | 0.073 | 2.887 | 3.003 | 3.128 | 4000 | 24.727 | 349.860 | 1.105 |

Error summary:

| kernel | error_code | error_msg | phase | count | relative |
|---|---|---|---|---|---|
| kernel_00 | 1 | divergent transition | warmup | 2595 | 0.130 |
| kernel_00 | 1 | divergent transition | posterior | 0 | 0.000 |
| kernel_00 | 2 | maximum tree depth | warmup | 15741 | 0.787 |
| kernel_00 | 2 | maximum tree depth | posterior | 3989 | 0.997 |
| kernel_00 | 3 | divergent transition + maximum tree depth | warmup | 796 | 0.040 |
| kernel_00 | 3 | divergent transition + maximum tree depth | posterior | 0 | 0.000 |

fig = gs.plot_trace(results, "loc_p0_beta")

fig = gs.plot_trace(results, "loc_np0_tau2_transformed")

fig = gs.plot_trace(results, "loc_np0_beta")

fig = gs.plot_trace(results, "scale_p0_beta")

fig = gs.plot_trace(results, "scale_np0_tau2_transformed")

fig = gs.plot_trace(results, "scale_np0_beta")

Again, here is a plot of the estimated mean function:

summary = gs.Summary(results).to_dataframe().reset_index()
library(dplyr)
library(ggplot2)
library(reticulate)

summary <- py$summary
model <- py$model

beta <- summary %>%
  filter(variable == "loc_np0_beta") %>%
  group_by(var_index) %>%
  summarize(mean = mean(mean)) %>%
  ungroup()

beta <- beta$mean
X <- model$vars["loc_np0_X"]$value
f <- X %*% beta

beta0 <- summary %>%
  filter(variable == "loc_p0_beta") %>%
  group_by(var_index) %>%
  summarize(mean = mean(mean)) %>%
  ungroup()

beta0 <- beta0$mean

ggplot(data.frame(times = mcycle$times, mean = beta0 + f)) +
  geom_line(aes(times, mean), color = palette()[2], linewidth = 1) +
  geom_point(aes(times, accel), data = mcycle) +
  ggtitle("Estimated mean function") +
  theme_minimal()