Comparing samplers#

In this tutorial, we compare two different sampling schemes on the mcycle dataset, using a Gaussian location-scale regression model with one spline each for the mean and the standard deviation. The mcycle dataset is a “data frame giving a series of measurements of head acceleration in a simulated motorcycle accident, used to test crash helmets” (from its help page). It contains the following two variables:

  • times: time in milliseconds after impact

  • accel: head acceleration in g
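
In other words, writing \(y_i\) for the acceleration and \(t_i\) for the time of observation \(i\), the model we fit below can be sketched as

\[
y_i \sim \mathcal{N}\bigl(\mu(t_i), \sigma^2(t_i)\bigr), \qquad
\mu(t_i) = \beta_0 + f_{\mu}(t_i), \qquad
\log \sigma(t_i) = \gamma_0 + f_{\sigma}(t_i),
\]

where \(f_{\mu}\) and \(f_{\sigma}\) are penalized splines with smoothing parameters \(\tau^2_{\mu}\) and \(\tau^2_{\sigma}\) (the exact spline bases and priors are set up by rliesel).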

We start off in R by loading the dataset and setting up the model with the rliesel::liesel() function.

library(MASS)
library(rliesel)
Please make sure you are using a virtual or conda environment with Liesel installed, e.g. using `reticulate::use_virtualenv()` or `reticulate::use_condaenv()`. See `vignette("versions", "reticulate")`.

After setting the environment, check if the installed versions of RLiesel and Liesel are compatible with `check_liesel_version()`.
data(mcycle)
with(mcycle, plot(times, accel))

model <- liesel(
  response = mcycle$accel,
  distribution = "Normal",
  predictors = list(
    loc = predictor(~ s(times)),
    scale = predictor(~ s(times), inverse_link = "Exp")
  ),
  data = mcycle
)
Installed Liesel version 0.2.10-dev is compatible, continuing to set up model

Metropolis-in-Gibbs#

First, we try a Metropolis-in-Gibbs sampling scheme with IWLS kernels for the regression coefficients (\(\boldsymbol{\beta}\)) and Gibbs kernels for the smoothing parameters (\(\tau^2\)) of the splines.
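This Gibbs step is available because, assuming an inverse gamma prior \(\tau^2 \sim \text{IG}(a, b)\) and a Gaussian smoothing prior with penalty matrix \(\boldsymbol{K}\) on the spline coefficients (the standard P-spline setup; the concrete hyperparameters are those chosen by rliesel), the full conditional of each smoothing parameter is again inverse gamma:

\[
\tau^2 \mid \boldsymbol{\beta} \sim \text{IG}\left(a + \frac{\operatorname{rk}(\boldsymbol{K})}{2},\; b + \frac{\boldsymbol{\beta}' \boldsymbol{K} \boldsymbol{\beta}}{2}\right).
\]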

import liesel.model as lsl

model = r.model

builder = lsl.dist_reg_mcmc(model, seed=42, num_chains=4)
builder.set_duration(warmup_duration=5000, posterior_duration=1000)

engine = builder.build()
engine.sample_all_epochs()

Clearly, the performance of the sampler could be better, especially for the intercept of the mean (loc_p0_beta), whose chain exhibits very strong autocorrelation and a low effective sample size. A quick way to identify the worst-mixing parameters programmatically is sketched after the error summary below.

import liesel.goose as gs

results = engine.get_results()
gs.Summary(results)

Parameter summary:

parameter    index    kernel    mean    sd    q_0.05    q_0.5    q_0.95    sample_size    ess_bulk    ess_tail    rhat
loc_np0_beta (0,) kernel_04 -22.073 242.590 -420.845 -22.455 391.772 4000 51.359 328.125 1.084
(1,) kernel_04 -1470.396 227.801 -1853.595 -1466.905 -1101.078 4000 227.695 567.559 1.016
(2,) kernel_04 -672.804 170.204 -954.507 -672.363 -394.065 4000 163.062 574.742 1.020
(3,) kernel_04 -578.864 111.028 -759.696 -578.867 -400.150 4000 230.468 629.516 1.009
(4,) kernel_04 1130.755 86.333 986.248 1132.358 1273.147 4000 372.603 727.101 1.004
(5,) kernel_04 -68.810 34.394 -126.301 -68.735 -11.545 4000 83.891 237.951 1.051
(6,) kernel_04 -213.497 20.291 -246.551 -213.732 -180.546 4000 271.528 468.387 1.021
(7,) kernel_04 114.058 65.370 16.265 110.296 228.006 4000 181.350 257.754 1.028
(8,) kernel_04 28.875 17.138 3.737 27.905 58.993 4000 172.313 276.003 1.024
loc_np0_tau2 () kernel_03 747036.375 528680.062 270306.500 608441.594 1687256.681 4000 1748.916 2601.569 1.001
loc_p0_beta (0,) kernel_05 -25.543 1.652 -28.474 -25.445 -22.769 4000 10.236 38.237 1.322
scale_np0_beta (0,) kernel_01 6.445 8.440 -3.077 3.844 23.779 4000 18.906 150.106 1.159
(1,) kernel_01 -0.798 5.463 -9.721 -0.742 7.866 4000 51.066 85.230 1.081
(2,) kernel_01 -11.898 9.353 -30.580 -10.475 -0.123 4000 13.333 25.570 1.270
(3,) kernel_01 8.310 4.650 1.507 7.647 16.789 4000 26.348 120.082 1.144
(4,) kernel_01 -1.273 3.804 -8.567 -0.712 4.079 4000 24.836 70.130 1.128
(5,) kernel_01 3.199 1.715 0.500 3.162 6.098 4000 46.873 59.511 1.071
(6,) kernel_01 0.893 2.923 -4.867 1.280 4.816 4000 14.427 59.989 1.223
(7,) kernel_01 -1.347 3.302 -6.438 -1.574 4.416 4000 75.501 128.461 1.046
(8,) kernel_01 -0.414 1.831 -3.999 0.034 1.880 4000 12.663 39.420 1.263
scale_np0_tau2 () kernel_00 86.308 135.919 6.482 41.332 308.744 4000 13.655 90.453 1.239
scale_p0_beta (0,) kernel_02 2.779 0.072 2.662 2.777 2.900 4000 44.596 608.667 1.065

Error summary:

kernel    error_code    error_msg    phase    count    relative
kernel_01 90 nan acceptance prob warmup 15 0.001
posterior 0 0.000
kernel_02 90 nan acceptance prob warmup 2 0.000
posterior 0 0.000
kernel_04 90 nan acceptance prob warmup 18 0.001
posterior 0 0.000
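To pinpoint which parameters mix poorly, it can help to pull the summary into a pandas DataFrame and sort it by the bulk effective sample size. A minimal sketch, assuming the column names produced by Summary.to_dataframe() match those used later in this tutorial (variable, var_index, ess_bulk, ess_tail, rhat):

summary_df = gs.Summary(results).to_dataframe().reset_index()

# the five worst-mixing parameter elements according to the bulk ESS
worst = summary_df.sort_values("ess_bulk").head(5)
print(worst[["variable", "var_index", "ess_bulk", "ess_tail", "rhat"]])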
fig = gs.plot_trace(results, "loc_p0_beta")

fig = gs.plot_trace(results, "loc_np0_tau2")

fig = gs.plot_trace(results, "loc_np0_beta")

fig = gs.plot_trace(results, "scale_p0_beta")

fig = gs.plot_trace(results, "scale_np0_tau2")

fig = gs.plot_trace(results, "scale_np0_beta")

To confirm that the chains have converged to reasonable values, here is a plot of the estimated mean function:

summary = gs.Summary(results).to_dataframe().reset_index()
library(dplyr)
library(ggplot2)
library(reticulate)

summary <- py$summary

beta <- summary %>%
  filter(variable == "loc_np0_beta") %>%
  group_by(var_index) %>%
  summarize(mean = mean(mean)) %>%
  ungroup()

beta <- beta$mean
X <- py_to_r(model$vars["loc_np0_X"]$value)
f <- X %*% beta

beta0 <- summary %>%
  filter(variable == "loc_p0_beta") %>%
  group_by(var_index) %>%
  summarize(mean = mean(mean)) %>%
  ungroup()

beta0 <- beta0$mean

ggplot(data.frame(times = mcycle$times, mean = beta0 + f)) +
  geom_line(aes(times, mean), color = palette()[2], linewidth = 1) +
  geom_point(aes(times, accel), data = mcycle) +
  ggtitle("Estimated mean function") +
  theme_minimal()
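
The estimated standard deviation function can be obtained in the same way. A minimal sketch in Python, assuming the design matrix of the scale spline is stored under the name scale_np0_X (mirroring loc_np0_X above) and recalling that the scale predictor uses the Exp inverse link:

import numpy as np

# posterior means of the scale spline coefficients and the scale intercept
sc = summary[summary["variable"] == "scale_np0_beta"]
beta_scale = sc.groupby("var_index")["mean"].mean().to_numpy()
beta0_scale = summary.loc[summary["variable"] == "scale_p0_beta", "mean"].mean()

# estimated standard deviation: sigma(times) = exp(intercept + X @ beta)
X_scale = np.asarray(model.vars["scale_np0_X"].value)  # assumed variable name
sd_hat = np.exp(beta0_scale + X_scale @ beta_scale)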

NUTS sampler#

As an alternative, we try NUTS kernels for all model parameters (regression coefficients and smoothing parameters). Because NUTS operates on an unconstrained space, we first need to log-transform the smoothing parameters. This is the model graph before the transformation:

lsl.plot_vars(model)

To transform the smoothing parameters, we first pop the nodes and variables off the existing model, add the response to a new lsl.GraphBuilder(), and transform the two smoothing parameters with the GraphBuilder.transform() method and the exponential bijector. Afterwards, we rebuild the model. There are two additional nodes in the new model graph.

import tensorflow_probability.substrates.jax.bijectors as tfb

nodes, _vars = model.pop_nodes_and_vars()

gb = lsl.GraphBuilder()
gb.add(_vars["response"])
GraphBuilder(0 nodes, 1 vars)
_ = gb.transform(_vars["loc_np0_tau2"], tfb.Exp)
_ = gb.transform(_vars["scale_np0_tau2"], tfb.Exp)
model = gb.build_model()
lsl.plot_vars(model)
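
To confirm that the transformation worked, one can list the variables whose names contain tau2. Given the naming convention visible in the summary below, the transformed smoothing parameters should now show up with a _transformed suffix:

print([name for name in model.vars if "tau2" in name])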

Now we can set up the NUTS sampler. In complex models like this one, it can be very beneficial to use an individual NUTS kernel for each block of parameters. This is essentially the same blocking strategy that the IWLS-based scheme above uses.


builder = gs.EngineBuilder(seed=42, num_chains=4)

builder.set_model(gs.LieselInterface(model))

# add NUTS kernels
parameters = [name for name, var in model.vars.items() if var.parameter]
for parameter in parameters:
  builder.add_kernel(gs.NUTSKernel([parameter]))


builder.set_initial_values(model.state)

builder.set_duration(warmup_duration=5000, posterior_duration=1000)

engine = builder.build()
engine.sample_all_epochs()

Overall, the NUTS sampler seems to do a good job and even yields higher effective sample sizes than the IWLS sampler, especially for the spline coefficients of the scale model (a direct comparison is sketched after the summary tables below).

results = engine.get_results()
gs.Summary(results)

Parameter summary:

parameter    index    kernel    mean    sd    q_0.05    q_0.5    q_0.95    sample_size    ess_bulk    ess_tail    rhat
loc_np0_beta (0,) kernel_04 -58.968 244.904 -446.256 -63.482 350.972 4000 303.105 1106.189 1.010
(1,) kernel_04 -1452.305 240.587 -1854.228 -1445.285 -1062.321 4000 1237.965 1670.364 1.001
(2,) kernel_04 -679.038 174.487 -969.307 -678.779 -395.784 4000 880.418 1521.646 1.007
(3,) kernel_04 -565.712 111.768 -750.466 -567.981 -377.534 4000 1313.834 1494.371 1.003
(4,) kernel_04 1122.383 95.030 969.316 1120.986 1276.253 4000 1599.451 1671.529 1.003
(5,) kernel_04 -69.899 34.855 -124.631 -70.025 -13.924 4000 223.843 541.307 1.005
(6,) kernel_04 -210.811 22.367 -246.532 -211.490 -172.968 4000 1119.211 1752.065 1.003
(7,) kernel_04 113.849 72.297 9.920 106.721 244.043 4000 1017.327 849.913 1.003
(8,) kernel_04 29.644 19.091 2.267 27.771 63.949 4000 940.276 875.634 1.003
loc_np0_tau2_transformed () kernel_01 13.317 0.559 12.454 13.283 14.302 4000 1157.609 1434.248 1.005
loc_p0_beta (0,) kernel_05 -25.334 2.172 -28.720 -25.286 -21.895 4000 45.604 57.632 1.130
scale_np0_beta (0,) kernel_02 6.210 9.305 -5.335 4.210 24.434 4000 327.892 702.445 1.004
(1,) kernel_02 -1.433 6.164 -12.271 -1.055 7.977 4000 1002.500 938.664 1.003
(2,) kernel_02 -14.282 8.983 -30.375 -13.225 -1.864 4000 175.802 621.941 1.008
(3,) kernel_02 9.383 5.158 1.775 8.788 18.633 4000 260.234 1068.942 1.006
(4,) kernel_02 -1.532 3.911 -8.335 -1.210 4.324 4000 420.107 809.669 1.007
(5,) kernel_02 3.789 2.003 0.679 3.612 7.370 4000 311.801 1155.484 1.004
(6,) kernel_02 0.432 2.745 -4.628 0.886 4.188 4000 203.742 615.753 1.008
(7,) kernel_02 -0.531 3.899 -6.398 -0.800 6.485 4000 765.810 1120.062 1.004
(8,) kernel_02 -0.596 1.679 -3.793 -0.300 1.621 4000 221.605 614.913 1.009
scale_np0_tau2_transformed () kernel_00 4.016 1.157 2.158 4.046 5.854 4000 147.315 425.801 1.009
scale_p0_beta (0,) kernel_03 2.774 0.069 2.660 2.774 2.884 4000 475.142 1499.446 1.006

Error summary:

kernel    error_code    error_msg    phase    count    relative
kernel_00 1 divergent transition warmup 83 0.004
posterior 0 0.000
kernel_01 1 divergent transition warmup 107 0.005
posterior 0 0.000
kernel_02 1 divergent transition warmup 1441 0.072
posterior 0 0.000
2 maximum tree depth warmup 47 0.002
posterior 0 0.000
kernel_03 1 divergent transition warmup 60 0.003
posterior 0 0.000
kernel_04 1 divergent transition warmup 1036 0.052
posterior 1 0.000
2 maximum tree depth warmup 1783 0.089
posterior 458 0.115
kernel_05 1 divergent transition warmup 139 0.007
posterior 0 0.000
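Since the summary DataFrame of the Metropolis-in-Gibbs run is still available as summary, the bulk effective sample sizes of the two schemes can be compared directly. A minimal sketch, assuming both DataFrames share the column names used above (the smoothing parameters are named differently after the transformation and therefore drop out of the merge):

nuts_summary = gs.Summary(results).to_dataframe().reset_index()

# match the two runs on the parameter name and element index
comparison = summary.merge(
    nuts_summary,
    on=["variable", "var_index"],
    suffixes=("_iwls", "_nuts"),
)
print(comparison[["variable", "var_index", "ess_bulk_iwls", "ess_bulk_nuts"]])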
fig = gs.plot_trace(results, "loc_p0_beta")

fig = gs.plot_trace(results, "loc_np0_tau2_transformed")

fig = gs.plot_trace(results, "loc_np0_beta")

fig = gs.plot_trace(results, "scale_p0_beta")

fig = gs.plot_trace(results, "scale_np0_tau2_transformed")

fig = gs.plot_trace(results, "scale_np0_beta")

Again, here is a plot of the estimated mean function:

summary = gs.Summary(results).to_dataframe().reset_index()
library(dplyr)
library(ggplot2)
library(reticulate)

summary <- py$summary
model <- py$model

beta <- summary %>%
  filter(variable == "loc_np0_beta") %>%
  group_by(var_index) %>%
  summarize(mean = mean(mean)) %>%
  ungroup()

beta <- beta$mean
X <- model$vars["loc_np0_X"]$value
f <- X %*% beta

beta0 <- summary %>%
  filter(variable == "loc_p0_beta") %>%
  group_by(var_index) %>%
  summarize(mean = mean(mean)) %>%
  ungroup()

beta0 <- beta0$mean

ggplot(data.frame(times = mcycle$times, mean = beta0 + f)) +
  geom_line(aes(times, mean), color = palette()[2], linewidth = 1) +
  geom_point(aes(times, accel), data = mcycle) +
  ggtitle("Estimated mean function") +
  theme_minimal()