Comparing samplers

In this tutorial, we compare two sampling schemes on the mcycle dataset, using a Gaussian location-scale regression model with one spline each for the mean and the standard deviation. The mcycle dataset is a “data frame giving a series of measurements of head acceleration in a simulated motorcycle accident, used to test crash helmets” (from the R help page of the dataset). It contains the following two variables:

  • times: in milliseconds after impact

  • accel: in g

We start off in R by loading the dataset and setting up the model with the rliesel::liesel() function.

library(MASS)
library(rliesel)

data(mcycle)
with(mcycle, plot(times, accel))

model <- liesel(
  response = mcycle$accel,
  distribution = "Normal",
  predictors = list(
    loc = predictor(~ s(times)),
    scale = predictor(~ s(times), inverse_link = "Exp")
  ),
  data = mcycle
)
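
In formulas, the model specified above can be sketched as follows. The symbols \(\beta_0\), \(\gamma_0\), \(f\) and \(g\) are notation assumed here for illustration and are not taken from the rliesel documentation. The exponential inverse link on the scale predictor is the one set in the code.

```latex
\begin{aligned}
\text{accel}_i &\sim \mathcal{N}\bigl(\mu_i, \sigma_i^2\bigr), \\
\mu_i &= \beta_0 + f(\text{times}_i), \\
\sigma_i &= \exp\bigl(\gamma_0 + g(\text{times}_i)\bigr).
\end{aligned}
```

Here, \(f\) and \(g\) are the two splines created by s(times). In the parameter summaries of this tutorial, \(\beta_0\) appears as loc_p0_beta and the coefficients of \(f\) as loc_np0_beta. By analogy, \(\gamma_0\) and the coefficients of \(g\) presumably correspond to scale_p0_beta and scale_np0_beta.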

Metropolis-in-Gibbs

First, we try a Metropolis-in-Gibbs sampling scheme with IWLS kernels for the regression coefficients (\(\boldsymbol{\beta}\)) and Gibbs kernels for the smoothing parameters (\(\tau^2\)) of the splines.
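
For intuition about the Gibbs step, here is a minimal sketch of a single draw of a smoothing parameter \(\tau^2\) given the current spline coefficients. It assumes an inverse-gamma prior \(\tau^2 \sim \text{IG}(a, b)\) and the usual penalized prior \(p(\boldsymbol{\beta} \mid \tau^2) \propto (\tau^2)^{-\text{rk}(K)/2} \exp(-\boldsymbol{\beta}' K \boldsymbol{\beta} / (2\tau^2))\), which gives the full conditional \(\text{IG}(a + \text{rk}(K)/2,\ b + \boldsymbol{\beta}' K \boldsymbol{\beta}/2)\). The hyperparameter values, the toy penalty matrix and all function names are assumptions made for this illustration. This is not Liesel's implementation.

```python
import random


def quadratic_form(beta, K):
    """Return beta' K beta for a list vector and a list-of-lists matrix."""
    n = len(beta)
    return sum(beta[i] * K[i][j] * beta[j] for i in range(n) for j in range(n))


def gibbs_draw_tau2(beta, K, rank_K, a=0.01, b=0.01, rng=random):
    """Draw tau2 from IG(a + rank_K / 2, b + beta' K beta / 2).

    a and b are hypothetical prior hyperparameters (assumption).
    """
    shape = a + 0.5 * rank_K
    scale = b + 0.5 * quadratic_form(beta, K)
    # If G ~ Gamma(shape, rate=scale), then 1 / G ~ InverseGamma(shape, scale).
    # random.gammavariate expects a SCALE as its second argument, hence 1 / scale.
    return 1.0 / rng.gammavariate(shape, 1.0 / scale)


# Toy example: first-order difference penalty for three coefficients (rank 2).
K = [[1, -1, 0], [-1, 2, -1], [0, -1, 1]]
beta = [0.0, 1.0, 3.0]

rng = random.Random(42)
print("beta' K beta:", quadratic_form(beta, K))
print("one Gibbs draw of tau2:", gibbs_draw_tau2(beta, K, rank_K=2, rng=rng))
```

Because the full conditional is available in closed form, this step needs no tuning and is always accepted. The IWLS kernels for \(\boldsymbol{\beta}\), in contrast, are Metropolis-Hastings steps with an acceptance probability.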

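import liesel.model as lsl

# `r` is provided by reticulate and gives access to objects in the R session
model = r.model

# set up the default sampling scheme: IWLS kernels for the regression
# coefficients and Gibbs kernels for the smoothing parameters
builder = lsl.dist_reg_mcmc(model, seed=42, num_chains=4)
builder.set_duration(warmup_duration=5000, posterior_duration=1000)
builder.show_progress = False

engine = builder.build()
engine.sample_all_epochs()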

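As the summary below shows, the performance of the sampler could be better, especially for the intercept of the mean (loc_p0_beta). The corresponding chains are very strongly autocorrelated, which results in a bulk effective sample size of only about 9 out of 4000 draws and an R-hat of 1.38. The spline coefficients and the smoothing parameter of the scale model also mix poorly, with R-hat values between roughly 1.07 and 1.22.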
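
To illustrate why strong autocorrelation is a problem, the following sketch simulates a stationary AR(1) chain and applies the textbook approximation ESS ≈ N(1 − ρ)/(1 + ρ), where ρ is the lag-1 autocorrelation. The AR(1) chain is a stand-in for an MCMC chain, and the approximation is not the estimator used for the ess_bulk column. All names and the value ρ = 0.99 are assumptions for this example.

```python
import random


def simulate_ar1(n, rho, seed=0):
    """Simulate a stationary AR(1) chain with unit marginal variance."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0)]
    for _ in range(n - 1):
        x.append(rho * x[-1] + (1 - rho**2) ** 0.5 * rng.gauss(0.0, 1.0))
    return x


def lag1_autocorrelation(x):
    """Empirical lag-1 autocorrelation of a sequence."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x)
    cov = sum((x[i] - mean) * (x[i + 1] - mean) for i in range(n - 1))
    return cov / var


def ar1_ess(n, rho):
    """Approximate effective sample size of an AR(1) chain of length n."""
    return n * (1 - rho) / (1 + rho)


chain = simulate_ar1(4000, rho=0.99)
rho_hat = lag1_autocorrelation(chain)
print(f"lag-1 autocorrelation: {rho_hat:.3f}")
print(f"approximate ESS: {ar1_ess(len(chain), rho_hat):.1f} of {len(chain)} draws")
```

With an autocorrelation this close to 1, thousands of draws carry the information of only a few dozen independent ones, which is the pattern seen for the intercept of the mean.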

import liesel.goose as gs

results = engine.get_results()
gs.Summary(results)

Parameter summary:

parameter       index  kernel           mean          sd      q_0.05       q_0.5       q_0.95  sample_size  ess_bulk  ess_tail   rhat
loc_np0_beta    (0,)   kernel_04     -89.783     243.714    -483.961     -93.170      317.212         4000    35.650   419.354  1.079
loc_np0_beta    (1,)   kernel_04   -1450.583     251.260   -1876.271   -1440.334    -1060.724         4000   292.621   649.406  1.018
loc_np0_beta    (2,)   kernel_04    -685.581     170.859    -963.474    -682.322     -406.089         4000   232.319   413.269  1.017
loc_np0_beta    (3,)   kernel_04    -566.567     113.858    -748.643    -569.450     -375.760         4000   200.408   587.108  1.019
loc_np0_beta    (4,)   kernel_04    1117.895      94.023     963.678    1115.170     1269.938         4000   294.993   818.162  1.014
loc_np0_beta    (5,)   kernel_04     -64.185      34.569    -121.355     -62.629       -8.844         4000    95.479   288.695  1.043
loc_np0_beta    (6,)   kernel_04    -211.062      21.122    -246.101    -211.329     -176.134         4000   108.539   444.519  1.039
loc_np0_beta    (7,)   kernel_04     116.051      69.206      17.010     108.444      240.738         4000   164.199   356.351  1.010
loc_np0_beta    (8,)   kernel_04      30.198      17.908       4.629      27.936       62.939         4000   143.481   301.179  1.017
loc_np0_tau2    ()     kernel_03  736267.500  576384.875  254046.856  586045.219  1712081.537         4000  1716.899  2508.099  1.001
loc_p0_beta     (0,)   kernel_05     -24.772       2.514     -28.690     -24.817      -20.496         4000     9.321    19.502  1.380
scale_np0_beta  (0,)   kernel_01       6.953       9.373      -5.137       5.156       24.396         4000    19.041   128.454  1.145
scale_np0_beta  (1,)   kernel_01      -2.119       6.417     -13.762      -1.500        7.521         4000    24.990    63.433  1.122
scale_np0_beta  (2,)   kernel_01     -16.403       9.940     -33.269     -16.295       -1.356         4000    15.048    38.506  1.220
scale_np0_beta  (3,)   kernel_01       9.526       4.653       2.741       9.110       17.955         4000    17.114   127.404  1.192
scale_np0_beta  (4,)   kernel_01      -2.031       3.875      -8.943      -1.594        3.610         4000    34.877    97.160  1.084
scale_np0_beta  (5,)   kernel_01       3.733       1.977       0.723       3.572        6.974         4000    22.106    43.125  1.144
scale_np0_beta  (6,)   kernel_01      -0.032       2.937      -5.476       0.459        4.034         4000    13.837    41.847  1.203
scale_np0_beta  (7,)   kernel_01      -0.828       3.341      -5.853      -1.107        4.954         4000    56.561   173.938  1.065
scale_np0_beta  (8,)   kernel_01      -0.929       1.772      -4.211      -0.641        1.491         4000    16.101    87.743  1.178
scale_np0_tau2  ()     kernel_00     121.644     171.814       7.737      72.676      393.397         4000    16.500   109.850  1.187
scale_p0_beta   (0,)   kernel_02       2.773       0.069       2.661       2.773        2.888         4000   258.903   639.815  1.033

Error summary:

kernel     error_code  error_msg            phase      count  relative
kernel_01  90          nan acceptance prob  warmup        11     0.001
kernel_01  90          nan acceptance prob  posterior      0     0.000
kernel_04  90          nan acceptance prob  warmup         1     0.000
kernel_04  90          nan acceptance prob  posterior      0     0.000

fig = gs.plot_trace(results, "loc_p0_beta")

fig = gs.plot_trace(results, "loc_np0_tau2")

fig = gs.plot_trace(results, "loc_np0_beta")

fig = gs.plot_trace(results, "scale_p0_beta")

fig = gs.plot_trace(results, "scale_np0_tau2")

fig = gs.plot_trace(results, "scale_np0_beta")

To confirm that the chains have converged to reasonable values, here is a plot of the estimated mean function:

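# Python: export the summary as a data frame
summary = gs.Summary(results).to_dataframe().reset_index()

# R: compute the posterior mean of the mean function and plot it
library(dplyr)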
library(ggplot2)
library(reticulate)
summary <- py$summary

beta <- summary %>%
  filter(variable == "loc_np0_beta") %>%
  group_by(var_index) %>%
  summarize(mean = mean(mean)) %>%
  ungroup()

beta <- beta$mean
X <- py_to_r(model$vars["loc_np0_X"]$value)
f <- X %*% beta

beta0 <- summary %>%
  filter(variable == "loc_p0_beta") %>%
  group_by(var_index) %>%
  summarize(mean = mean(mean)) %>%
  ungroup()

beta0 <- beta0$mean

ggplot(data.frame(times = mcycle$times, mean = beta0 + f)) +
  geom_line(aes(times, mean), color = palette()[2], linewidth = 1) +
  geom_point(aes(times, accel), data = mcycle) +
  ggtitle("Estimated mean function") +
  theme_minimal()

NUTS sampler

As an alternative, we try using NUTS kernels for all parameters. To do so, we first need to log-transform the smoothing parameters. This is the model graph before the transformation:

lsl.plot_vars(model)

To transform the smoothing parameters with the method Var.transform(), we first need to retrieve the nodes and vars from the model. This is necessary because the inputs and outputs of nodes and vars cannot be changed while they are part of a model. We retrieve them with Model.pop_nodes_and_vars(), which leaves the model empty.

After transformation, there are two additional nodes in the new model graph.
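
The reason for the transformation is that NUTS works on an unconstrained space, whereas \(\tau^2\) must be positive. Writing \(\tau^2 = \exp(\theta)\), the density of \(\theta\) picks up a Jacobian term: \(\log p_\theta(\theta) = \log p_{\tau^2}(\exp(\theta)) + \theta\). The following sketch checks this numerically for an inverse-gamma density. The hyperparameters a = 2 and b = 1 and all function names are assumptions chosen for illustration and are unrelated to the priors in the model above.

```python
import math


def inverse_gamma_logpdf(x, a, b):
    """Log density of an inverse-gamma distribution with shape a and scale b."""
    return a * math.log(b) - math.lgamma(a) - (a + 1) * math.log(x) - b / x


def transformed_logpdf(theta, a, b):
    """Log density of theta = log(tau2) when tau2 ~ IG(a, b)."""
    tau2 = math.exp(theta)
    log_jacobian = theta  # log |d tau2 / d theta| = log(exp(theta)) = theta
    return inverse_gamma_logpdf(tau2, a, b) + log_jacobian


def integrate(f, lower, upper, n=20000):
    """Midpoint rule on an equidistant grid."""
    h = (upper - lower) / n
    return h * sum(f(lower + (i + 0.5) * h) for i in range(n))


a, b = 2.0, 1.0  # hypothetical hyperparameters
total = integrate(lambda t: math.exp(transformed_logpdf(t, a, b)), -10.0, 30.0)
print(f"integral of the transformed density: {total:.4f}")
```

The integral should be close to 1, which confirms that the Jacobian term keeps the transformed density properly normalized. Without it, the sampler would target a different distribution for \(\tau^2\).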

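import tensorflow_probability.substrates.jax.bijectors as tfb

# remove all nodes and vars from the model so that they can be modified
nodes, _vars = model.pop_nodes_and_vars()

# reparametrize the smoothing parameters on the log scale
_vars["loc_np0_tau2"].transform(tfb.Exp())
# returns Var(name="loc_np0_tau2_transformed")
_vars["scale_np0_tau2"].transform(tfb.Exp())
# returns Var(name="scale_np0_tau2_transformed")

# rebuild the model starting from the response
gb = lsl.GraphBuilder().add(_vars["response"])

model = gb.build_model()
lsl.plot_vars(model)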

Now we can set up the NUTS sampler. In complex models like this one, it can be very beneficial to use individual NUTS kernels for blocks of parameters instead of one kernel for all parameters. The IWLS scheme above follows the same blocking strategy, with one kernel per parameter block.

builder = gs.EngineBuilder(seed=42, num_chains=4)

builder.set_model(gs.LieselInterface(model))

# add one NUTS kernel per parameter block
parameters = [name for name, var in model.vars.items() if var.parameter]

for parameter in parameters:
    builder.add_kernel(gs.NUTSKernel([parameter]))

builder.set_initial_values(model.state)

builder.set_epochs(
    gs.stan_epochs(
        warmup_duration=5000,
        posterior_duration=1000,
        init_duration=750,
        term_duration=500,
    )
)

builder.show_progress = False

engine = builder.build()
engine.sample_all_epochs()

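Overall, the NUTS sampler does a good job and yields considerably higher effective sample sizes than the IWLS sampler, especially for the spline coefficients of the scale model, where ess_bulk rises from roughly 14 to 57 under IWLS to roughly 300 to 750 under NUTS. The intercept of the mean (loc_p0_beta) remains the weakest parameter, with a bulk effective sample size of about 55. The error summary also reports divergent transitions and maximum-tree-depth errors in the posterior phase for the kernel of loc_np0_beta, so the results should still be read with some care.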
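
The gain for the scale model can be quantified with a few lines of Python. The sketch below uses the ess_bulk values of scale_np0_beta copied from the two parameter summaries in this tutorial. Comparing medians is a choice made here for illustration, not part of the original analysis.

```python
import statistics

# ess_bulk of scale_np0_beta, indices (0,) to (8,), copied from the summaries
ess_iwls = [19.041, 24.990, 15.048, 17.114, 34.877, 22.106, 13.837, 56.561, 16.101]
ess_nuts = [474.570, 746.909, 305.664, 303.817, 523.670, 385.347, 302.514, 493.763, 345.563]

median_iwls = statistics.median(ess_iwls)
median_nuts = statistics.median(ess_nuts)
ratio = median_nuts / median_iwls

print(f"median ess_bulk (IWLS): {median_iwls:.1f}")
print(f"median ess_bulk (NUTS): {median_nuts:.1f}")
print(f"ratio NUTS / IWLS:      {ratio:.1f}")
```

The median bulk effective sample size is roughly 20 times higher under NUTS. Note that this compares effective draws per iteration only. It ignores the computation time per iteration, which this tutorial does not report.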

results = engine.get_results()
gs.Summary(results)

Parameter summary:

parameter                   index  kernel          mean       sd     q_0.05      q_0.5     q_0.95  sample_size  ess_bulk  ess_tail   rhat
loc_np0_beta                (0,)   kernel_03    -88.104  252.529   -498.737    -94.419    337.766         4000   474.180  1377.730  1.008
loc_np0_beta                (1,)   kernel_03  -1453.895  245.675  -1843.600  -1453.718  -1053.058         4000   765.684  1813.545  1.006
loc_np0_beta                (2,)   kernel_03   -681.106  182.549   -985.511   -682.477   -377.593         4000  1104.403  1789.338  1.001
loc_np0_beta                (3,)   kernel_03   -554.689  113.558   -735.240   -553.886   -365.407         4000  1009.067  1669.084  1.005
loc_np0_beta                (4,)   kernel_03   1126.253   98.042    968.061   1126.222   1288.899         4000  1400.525  1695.291  1.003
loc_np0_beta                (5,)   kernel_03    -65.458   34.845   -121.377    -65.759     -6.803         4000   365.669   798.602  1.010
loc_np0_beta                (6,)   kernel_03   -210.146   23.005   -245.842   -210.666   -172.913         4000  1112.555  1175.961  1.003
loc_np0_beta                (7,)   kernel_03    122.930   80.350     12.235    113.462    269.568         4000   602.146   586.165  1.005
loc_np0_beta                (8,)   kernel_03     31.984   21.059      3.020     29.346     70.037         4000   556.233   544.773  1.005
loc_np0_tau2_transformed    ()     kernel_04     13.329    0.571     12.460     13.294     14.329         4000  1090.885  1274.243  1.003
loc_p0_beta                 (0,)   kernel_05    -24.860    1.858    -27.923    -24.778    -21.861         4000    54.740   105.458  1.069
scale_np0_beta              (0,)   kernel_00      6.745    9.369     -6.052      5.247     24.883         4000   474.570   946.737  1.003
scale_np0_beta              (1,)   kernel_00     -1.733    6.367    -13.032     -1.304      8.003         4000   746.909   894.980  1.005
scale_np0_beta              (2,)   kernel_00    -15.173    8.991    -31.192    -14.309     -2.102         4000   305.664  1123.548  1.008
scale_np0_beta              (3,)   kernel_00     10.141    5.134      2.598      9.723     19.043         4000   303.817   718.883  1.009
scale_np0_beta              (4,)   kernel_00     -1.745    4.021     -8.910     -1.490      4.313         4000   523.670  1183.290  1.003
scale_np0_beta              (5,)   kernel_00      4.106    2.021      1.005      4.023      7.558         4000   385.347   639.357  1.009
scale_np0_beta              (6,)   kernel_00      0.222    2.756     -4.833      0.566      4.138         4000   302.514   943.799  1.005
scale_np0_beta              (7,)   kernel_00      0.062    4.119     -6.222     -0.243      7.083         4000   493.763   595.323  1.006
scale_np0_beta              (8,)   kernel_00     -0.678    1.709     -3.859     -0.432      1.655         4000   345.563   728.305  1.004
scale_np0_tau2_transformed  ()     kernel_01      4.160    1.090      2.293      4.222      5.830         4000   226.354   481.191  1.012
scale_p0_beta               (0,)   kernel_02      2.772    0.068      2.664      2.770      2.886         4000   806.194  2050.150  1.001

Error summary:

kernel     error_code  error_msg             phase      count  relative
kernel_00  1           divergent transition  warmup      1764     0.088
kernel_00  1           divergent transition  posterior      1     0.000
kernel_00  2           maximum tree depth    warmup        75     0.004
kernel_00  2           maximum tree depth    posterior      0     0.000
kernel_01  1           divergent transition  warmup       107     0.005
kernel_01  1           divergent transition  posterior      0     0.000
kernel_02  1           divergent transition  warmup        63     0.003
kernel_02  1           divergent transition  posterior      0     0.000
kernel_03  1           divergent transition  warmup      1150     0.057
kernel_03  1           divergent transition  posterior     66     0.016
kernel_03  2           maximum tree depth    warmup      3441     0.172
kernel_03  2           maximum tree depth    posterior    119     0.030
kernel_04  1           divergent transition  warmup       117     0.006
kernel_04  1           divergent transition  posterior      0     0.000
kernel_05  1           divergent transition  warmup       145     0.007
kernel_05  1           divergent transition  posterior      1     0.000

fig = gs.plot_trace(results, "loc_p0_beta")

fig = gs.plot_trace(results, "loc_np0_tau2_transformed")

fig = gs.plot_trace(results, "loc_np0_beta")

fig = gs.plot_trace(results, "scale_p0_beta")

fig = gs.plot_trace(results, "scale_np0_tau2_transformed")

fig = gs.plot_trace(results, "scale_np0_beta")

Again, here is a plot of the estimated mean function:

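# Python: export the summary as a data frame
summary = gs.Summary(results).to_dataframe().reset_index()

# R: fetch the summary and the transformed model from the Python session
library(dplyr)
library(ggplot2)
library(reticulate)

summary <- py$summary
model <- py$model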

beta <- summary %>%
  filter(variable == "loc_np0_beta") %>%
  group_by(var_index) %>%
  summarize(mean = mean(mean)) %>%
  ungroup()

beta <- beta$mean
X <- model$vars["loc_np0_X"]$value
f <- X %*% beta

beta0 <- summary %>%
  filter(variable == "loc_p0_beta") %>%
  group_by(var_index) %>%
  summarize(mean = mean(mean)) %>%
  ungroup()

beta0 <- beta0$mean

ggplot(data.frame(times = mcycle$times, mean = beta0 + f)) +
  geom_line(aes(times, mean), color = palette()[2], linewidth = 1) +
  geom_point(aes(times, accel), data = mcycle) +
  ggtitle("Estimated mean function") +
  theme_minimal()