Comparing samplers#
In this tutorial, we compare two different sampling schemes on the `mcycle` dataset with a Gaussian location-scale regression model and two splines for the mean and the standard deviation. The `mcycle` dataset is a “data frame giving a series of measurements of head acceleration in a simulated motorcycle accident, used to test crash helmets” (from the help page). It contains the following two variables:

times
: in milliseconds after impact

accel
: in g
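In other words, the model assumes that the acceleration follows a normal distribution whose mean and standard deviation are both smooth functions of time, with the exponential inverse link keeping the standard deviation positive:

\[
\text{accel}_i \sim \mathcal{N}\bigl(f_{\text{loc}}(\text{times}_i),\, \sigma_i^2\bigr), \qquad \sigma_i = \exp\bigl(f_{\text{scale}}(\text{times}_i)\bigr),
\]

where \(f_{\text{loc}}\) and \(f_{\text{scale}}\) denote the two splines (each including an intercept).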
We start off in R by loading the dataset and setting up the model with the `rliesel::liesel()` function.
library(MASS)
library(rliesel)
Please make sure you are using a virtual or conda environment with Liesel installed, e.g. using `reticulate::use_virtualenv()` or `reticulate::use_condaenv()`. See `vignette("versions", "reticulate")`.
After setting the environment, check if the installed versions of RLiesel and Liesel are compatible with `check_liesel_version()`.
data(mcycle)
with(mcycle, plot(times, accel))
model <- liesel(
  response = mcycle$accel,
  distribution = "Normal",
  predictors = list(
    loc = predictor(~ s(times)),
    scale = predictor(~ s(times), inverse_link = "Exp")
  ),
  data = mcycle
)
Installed Liesel version 0.2.10-dev is compatible, continuing to set up model
Metropolis-in-Gibbs#
First, we try a Metropolis-in-Gibbs sampling scheme with IWLS kernels for the regression coefficients (\(\boldsymbol{\beta}\)) and Gibbs kernels for the smoothing parameters (\(\tau^2\)) of the splines.
import liesel.model as lsl
model = r.model
builder = lsl.dist_reg_mcmc(model, seed=42, num_chains=4)
builder.set_duration(warmup_duration=5000, posterior_duration=1000)
engine = builder.build()
engine.sample_all_epochs()
Clearly, the performance of the sampler could be better, especially for the intercept of the mean: the corresponding chain exhibits very strong autocorrelation.
import liesel.goose as gs
results = engine.get_results()
gs.Summary(results)
Parameter summary:
| parameter | index | kernel | mean | sd | q_0.05 | q_0.5 | q_0.95 | sample_size | ess_bulk | ess_tail | rhat |
|---|---|---|---|---|---|---|---|---|---|---|---|
| loc_np0_beta | (0,) | kernel_04 | -22.073 | 242.590 | -420.845 | -22.455 | 391.772 | 4000 | 51.359 | 328.125 | 1.084 |
| | (1,) | kernel_04 | -1470.396 | 227.801 | -1853.595 | -1466.905 | -1101.078 | 4000 | 227.695 | 567.559 | 1.016 |
| | (2,) | kernel_04 | -672.804 | 170.204 | -954.507 | -672.363 | -394.065 | 4000 | 163.062 | 574.742 | 1.020 |
| | (3,) | kernel_04 | -578.864 | 111.028 | -759.696 | -578.867 | -400.150 | 4000 | 230.468 | 629.516 | 1.009 |
| | (4,) | kernel_04 | 1130.755 | 86.333 | 986.248 | 1132.358 | 1273.147 | 4000 | 372.603 | 727.101 | 1.004 |
| | (5,) | kernel_04 | -68.810 | 34.394 | -126.301 | -68.735 | -11.545 | 4000 | 83.891 | 237.951 | 1.051 |
| | (6,) | kernel_04 | -213.497 | 20.291 | -246.551 | -213.732 | -180.546 | 4000 | 271.528 | 468.387 | 1.021 |
| | (7,) | kernel_04 | 114.058 | 65.370 | 16.265 | 110.296 | 228.006 | 4000 | 181.350 | 257.754 | 1.028 |
| | (8,) | kernel_04 | 28.875 | 17.138 | 3.737 | 27.905 | 58.993 | 4000 | 172.313 | 276.003 | 1.024 |
| loc_np0_tau2 | () | kernel_03 | 747036.375 | 528680.062 | 270306.500 | 608441.594 | 1687256.681 | 4000 | 1748.916 | 2601.569 | 1.001 |
| loc_p0_beta | (0,) | kernel_05 | -25.543 | 1.652 | -28.474 | -25.445 | -22.769 | 4000 | 10.236 | 38.237 | 1.322 |
| scale_np0_beta | (0,) | kernel_01 | 6.445 | 8.440 | -3.077 | 3.844 | 23.779 | 4000 | 18.906 | 150.106 | 1.159 |
| | (1,) | kernel_01 | -0.798 | 5.463 | -9.721 | -0.742 | 7.866 | 4000 | 51.066 | 85.230 | 1.081 |
| | (2,) | kernel_01 | -11.898 | 9.353 | -30.580 | -10.475 | -0.123 | 4000 | 13.333 | 25.570 | 1.270 |
| | (3,) | kernel_01 | 8.310 | 4.650 | 1.507 | 7.647 | 16.789 | 4000 | 26.348 | 120.082 | 1.144 |
| | (4,) | kernel_01 | -1.273 | 3.804 | -8.567 | -0.712 | 4.079 | 4000 | 24.836 | 70.130 | 1.128 |
| | (5,) | kernel_01 | 3.199 | 1.715 | 0.500 | 3.162 | 6.098 | 4000 | 46.873 | 59.511 | 1.071 |
| | (6,) | kernel_01 | 0.893 | 2.923 | -4.867 | 1.280 | 4.816 | 4000 | 14.427 | 59.989 | 1.223 |
| | (7,) | kernel_01 | -1.347 | 3.302 | -6.438 | -1.574 | 4.416 | 4000 | 75.501 | 128.461 | 1.046 |
| | (8,) | kernel_01 | -0.414 | 1.831 | -3.999 | 0.034 | 1.880 | 4000 | 12.663 | 39.420 | 1.263 |
| scale_np0_tau2 | () | kernel_00 | 86.308 | 135.919 | 6.482 | 41.332 | 308.744 | 4000 | 13.655 | 90.453 | 1.239 |
| scale_p0_beta | (0,) | kernel_02 | 2.779 | 0.072 | 2.662 | 2.777 | 2.900 | 4000 | 44.596 | 608.667 | 1.065 |
Error summary:
| kernel | error_code | error_msg | phase | count | relative |
|---|---|---|---|---|---|
| kernel_01 | 90 | nan acceptance prob | warmup | 15 | 0.001 |
| | | | posterior | 0 | 0.000 |
| kernel_02 | 90 | nan acceptance prob | warmup | 2 | 0.000 |
| | | | posterior | 0 | 0.000 |
| kernel_04 | 90 | nan acceptance prob | warmup | 18 | 0.001 |
| | | | posterior | 0 | 0.000 |
fig = gs.plot_trace(results, "loc_p0_beta")
fig = gs.plot_trace(results, "loc_np0_tau2")
fig = gs.plot_trace(results, "loc_np0_beta")
fig = gs.plot_trace(results, "scale_p0_beta")
fig = gs.plot_trace(results, "scale_np0_tau2")
fig = gs.plot_trace(results, "scale_np0_beta")
To confirm that the chains have converged to reasonable values, here is a plot of the estimated mean function:
summary = gs.Summary(results).to_dataframe().reset_index()
library(dplyr)
Attaching package: 'dplyr'
The following object is masked from 'package:MASS':
select
The following objects are masked from 'package:stats':
filter, lag
The following objects are masked from 'package:base':
intersect, setdiff, setequal, union
library(ggplot2)
library(reticulate)
summary <- py$summary
beta <- summary %>%
  filter(variable == "loc_np0_beta") %>%
  group_by(var_index) %>%
  summarize(mean = mean(mean)) %>%
  ungroup()

beta <- beta$mean
X <- py_to_r(model$vars["loc_np0_X"]$value)
f <- X %*% beta
beta0 <- summary %>%
  filter(variable == "loc_p0_beta") %>%
  group_by(var_index) %>%
  summarize(mean = mean(mean)) %>%
  ungroup()

beta0 <- beta0$mean
ggplot(data.frame(times = mcycle$times, mean = beta0 + f)) +
  geom_line(aes(times, mean), color = palette()[2], linewidth = 1) +
  geom_point(aes(times, accel), data = mcycle) +
  ggtitle("Estimated mean function") +
  theme_minimal()
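The same posterior mean function can also be computed directly in Python instead of R. A minimal sketch, assuming the `model` and `results` objects from above:

import numpy as np

samples = results.get_posterior_samples()

X = np.asarray(model.vars["loc_np0_X"].value)  # spline design matrix
beta = np.asarray(samples["loc_np0_beta"])     # shape: (chains, draws, coefs)
beta0 = np.asarray(samples["loc_p0_beta"])     # shape: (chains, draws, 1)

# posterior mean of the estimated mean function
f = beta0.mean() + X @ beta.mean(axis=(0, 1))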
NUTS sampler#
As an alternative, we try NUTS kernels, which sample the model parameters (regression coefficients and smoothing parameters) based on gradients. Since the smoothing parameters are constrained to be positive, we first need to log-transform them. This is the model graph before the transformation:
lsl.plot_vars(model)
Before transforming the smoothing parameters with the `GraphBuilder.transform()` method, we first need to pop the nodes and variables from the model. The transformation takes care of updating the output nodes of the smoothing parameters, so we can then rebuild the model. There are two additional nodes in the new model graph.
import tensorflow_probability.substrates.jax.bijectors as tfb
nodes, _vars = model.pop_nodes_and_vars()
gb = lsl.GraphBuilder()
gb.add(_vars["response"])
GraphBuilder(0 nodes, 1 vars)
_ = gb.transform(_vars["loc_np0_tau2"], tfb.Exp)
_ = gb.transform(_vars["scale_np0_tau2"], tfb.Exp)
model = gb.build_model()
lsl.plot_vars(model)
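To verify, we can check which variables of the rebuilt model refer to the smoothing parameters. A quick sketch; the `_transformed` names should also appear in the summary below:

# list all variables of the rebuilt model related to the smoothing parameters;
# the transformed versions are expected to appear alongside the originals
print([name for name in model.vars if "tau2" in name])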
Now we can set up the NUTS sampler. In complex models like this one, it can be very beneficial to use individual NUTS kernels for blocks of parameters. This is essentially the same blocking strategy that we applied with the IWLS sampler.
builder = gs.EngineBuilder(seed=42, num_chains=4)

builder.set_model(gs.LieselInterface(model))

# add one NUTS kernel per parameter block
parameters = [name for name, var in model.vars.items() if var.parameter]

for parameter in parameters:
    builder.add_kernel(gs.NUTSKernel([parameter]))

builder.set_initial_values(model.state)
builder.set_duration(warmup_duration=5000, posterior_duration=1000)

engine = builder.build()
engine.sample_all_epochs()
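As an aside, sampling all parameters in one joint block would amount to replacing the loop above with a single kernel that receives the whole list of position keys. A sketch under the same builder setup:

# alternative: one NUTS kernel for all parameters jointly
builder.add_kernel(gs.NUTSKernel(parameters))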
Overall, the NUTS sampler seems to do a good job and even yields higher effective sample sizes than the IWLS sampler, especially for the spline coefficients of the scale model.
results = engine.get_results()
gs.Summary(results)
Parameter summary:
| parameter | index | kernel | mean | sd | q_0.05 | q_0.5 | q_0.95 | sample_size | ess_bulk | ess_tail | rhat |
|---|---|---|---|---|---|---|---|---|---|---|---|
| loc_np0_beta | (0,) | kernel_04 | -58.968 | 244.904 | -446.256 | -63.482 | 350.972 | 4000 | 303.105 | 1106.189 | 1.010 |
| | (1,) | kernel_04 | -1452.305 | 240.587 | -1854.228 | -1445.285 | -1062.321 | 4000 | 1237.965 | 1670.364 | 1.001 |
| | (2,) | kernel_04 | -679.038 | 174.487 | -969.307 | -678.779 | -395.784 | 4000 | 880.418 | 1521.646 | 1.007 |
| | (3,) | kernel_04 | -565.712 | 111.768 | -750.466 | -567.981 | -377.534 | 4000 | 1313.834 | 1494.371 | 1.003 |
| | (4,) | kernel_04 | 1122.383 | 95.030 | 969.316 | 1120.986 | 1276.253 | 4000 | 1599.451 | 1671.529 | 1.003 |
| | (5,) | kernel_04 | -69.899 | 34.855 | -124.631 | -70.025 | -13.924 | 4000 | 223.843 | 541.307 | 1.005 |
| | (6,) | kernel_04 | -210.811 | 22.367 | -246.532 | -211.490 | -172.968 | 4000 | 1119.211 | 1752.065 | 1.003 |
| | (7,) | kernel_04 | 113.849 | 72.297 | 9.920 | 106.721 | 244.043 | 4000 | 1017.327 | 849.913 | 1.003 |
| | (8,) | kernel_04 | 29.644 | 19.091 | 2.267 | 27.771 | 63.949 | 4000 | 940.276 | 875.634 | 1.003 |
| loc_np0_tau2_transformed | () | kernel_01 | 13.317 | 0.559 | 12.454 | 13.283 | 14.302 | 4000 | 1157.609 | 1434.248 | 1.005 |
| loc_p0_beta | (0,) | kernel_05 | -25.334 | 2.172 | -28.720 | -25.286 | -21.895 | 4000 | 45.604 | 57.632 | 1.130 |
| scale_np0_beta | (0,) | kernel_02 | 6.210 | 9.305 | -5.335 | 4.210 | 24.434 | 4000 | 327.892 | 702.445 | 1.004 |
| | (1,) | kernel_02 | -1.433 | 6.164 | -12.271 | -1.055 | 7.977 | 4000 | 1002.500 | 938.664 | 1.003 |
| | (2,) | kernel_02 | -14.282 | 8.983 | -30.375 | -13.225 | -1.864 | 4000 | 175.802 | 621.941 | 1.008 |
| | (3,) | kernel_02 | 9.383 | 5.158 | 1.775 | 8.788 | 18.633 | 4000 | 260.234 | 1068.942 | 1.006 |
| | (4,) | kernel_02 | -1.532 | 3.911 | -8.335 | -1.210 | 4.324 | 4000 | 420.107 | 809.669 | 1.007 |
| | (5,) | kernel_02 | 3.789 | 2.003 | 0.679 | 3.612 | 7.370 | 4000 | 311.801 | 1155.484 | 1.004 |
| | (6,) | kernel_02 | 0.432 | 2.745 | -4.628 | 0.886 | 4.188 | 4000 | 203.742 | 615.753 | 1.008 |
| | (7,) | kernel_02 | -0.531 | 3.899 | -6.398 | -0.800 | 6.485 | 4000 | 765.810 | 1120.062 | 1.004 |
| | (8,) | kernel_02 | -0.596 | 1.679 | -3.793 | -0.300 | 1.621 | 4000 | 221.605 | 614.913 | 1.009 |
| scale_np0_tau2_transformed | () | kernel_00 | 4.016 | 1.157 | 2.158 | 4.046 | 5.854 | 4000 | 147.315 | 425.801 | 1.009 |
| scale_p0_beta | (0,) | kernel_03 | 2.774 | 0.069 | 2.660 | 2.774 | 2.884 | 4000 | 475.142 | 1499.446 | 1.006 |
Error summary:
| kernel | error_code | error_msg | phase | count | relative |
|---|---|---|---|---|---|
| kernel_00 | 1 | divergent transition | warmup | 83 | 0.004 |
| | | | posterior | 0 | 0.000 |
| kernel_01 | 1 | divergent transition | warmup | 107 | 0.005 |
| | | | posterior | 0 | 0.000 |
| kernel_02 | 1 | divergent transition | warmup | 1441 | 0.072 |
| | | | posterior | 0 | 0.000 |
| | 2 | maximum tree depth | warmup | 47 | 0.002 |
| | | | posterior | 0 | 0.000 |
| kernel_03 | 1 | divergent transition | warmup | 60 | 0.003 |
| | | | posterior | 0 | 0.000 |
| kernel_04 | 1 | divergent transition | warmup | 1036 | 0.052 |
| | | | posterior | 1 | 0.000 |
| | 2 | maximum tree depth | warmup | 1783 | 0.089 |
| | | | posterior | 458 | 0.115 |
| kernel_05 | 1 | divergent transition | warmup | 139 | 0.007 |
| | | | posterior | 0 | 0.000 |
fig = gs.plot_trace(results, "loc_p0_beta")
fig = gs.plot_trace(results, "loc_np0_tau2_transformed")
fig = gs.plot_trace(results, "loc_np0_beta")
fig = gs.plot_trace(results, "scale_p0_beta")
fig = gs.plot_trace(results, "scale_np0_tau2_transformed")
fig = gs.plot_trace(results, "scale_np0_beta")
Again, here is a plot of the estimated mean function:
summary = gs.Summary(results).to_dataframe().reset_index()
library(dplyr)
library(ggplot2)
library(reticulate)
summary <- py$summary
model <- py$model
beta <- summary %>%
  filter(variable == "loc_np0_beta") %>%
  group_by(var_index) %>%
  summarize(mean = mean(mean)) %>%
  ungroup()

beta <- beta$mean
X <- model$vars["loc_np0_X"]$value
f <- X %*% beta
beta0 <- summary %>%
  filter(variable == "loc_p0_beta") %>%
  group_by(var_index) %>%
  summarize(mean = mean(mean)) %>%
  ungroup()

beta0 <- beta0$mean
ggplot(data.frame(times = mcycle$times, mean = beta0 + f)) +
  geom_line(aes(times, mean), color = palette()[2], linewidth = 1) +
  geom_point(aes(times, accel), data = mcycle) +
  ggtitle("Estimated mean function") +
  theme_minimal()