Comparing samplers


In this tutorial, we compare two different sampling schemes on the mcycle dataset, using a Gaussian location-scale regression model with one spline for the mean and one for the standard deviation. The mcycle dataset is a “data frame giving a series of measurements of head acceleration in a simulated motorcycle accident, used to test crash helmets” (from the help page). It contains the following two variables:

  • times: in milliseconds after impact

  • accel: in g

We set up the model in Python with Liesel-GAM, using liesel_gam.TermBuilder for the P-spline terms. See the Liesel-GAM documentation and examples for more information about additive terms and predictors. We load the data set from R with ryp and then continue with a pure Python model specification and sampling workflow.

from pathlib import Path

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd

import liesel.goose as gs
import liesel.model as lsl
import liesel_gam as gam
from ryp import r, to_py

We start by loading the data set from the R package MASS and converting it to a pandas data frame.

r("library(MASS)")
r("data(mcycle); mcycle <- as.data.frame(mcycle)")

mcycle = to_py("mcycle", format="pandas")
fig, ax = plt.subplots(figsize=(8, 4))
sns.scatterplot(data=mcycle, x="times", y="accel", color="0.25", s=35, ax=ax)
ax.set(xlabel="time after impact", ylabel="acceleration", title="mcycle data")
plt.show()

[Figure: mcycle data, acceleration against time after impact]

Next, we build the Gaussian location-scale model. Both distributional parameters use an additive predictor with an intercept and a P-spline in times. The scale predictor uses an exponential inverse link to keep the standard deviation positive. The TermBuilder is initialized with an IWLS MCMC specification, so the P-spline regression coefficients are sampled with IWLS kernels by default. The additive predictor intercepts also use their default IWLS inference specification. The smoothing variances of the P-splines receive Gibbs kernels automatically.
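Written out, the model reads as follows. This is our own sketch under assumptions: \(\mathbf{b}(t)\) denotes the vector of B-spline basis functions of a P-spline (both splines are built with tb.ps("times", k=20), so we write the same basis for both), \(\mathbf{K}\) its penalty matrix, and each smoothing variance is assumed to have an inverse gamma prior \(\mathrm{IG}(a_\tau, b_\tau)\); the hyperparameters Liesel-GAM uses by default are not stated here. The last two lines apply to each spline separately, with its own \(\tau^2\). The final line shows why a Gibbs kernel is possible: \(\tau^2\) enters the model only through the prior of the spline coefficients, so its full conditional is again inverse gamma.

\[
\begin{aligned}
y_i &\sim \mathcal{N}\big(\mu(t_i), \sigma(t_i)^2\big), \qquad i = 1, \dots, n, \\
\mu(t) &= \beta_{0,\mathrm{loc}} + \mathbf{b}(t)^\top \boldsymbol{\beta}_{\mathrm{loc}}, \\
\sigma(t) &= \exp\big(\beta_{0,\mathrm{scale}} + \mathbf{b}(t)^\top \boldsymbol{\beta}_{\mathrm{scale}}\big), \\
p(\boldsymbol{\beta} \mid \tau^2) &\propto (\tau^2)^{-\operatorname{rk}(\mathbf{K})/2} \exp\Big(-\frac{\boldsymbol{\beta}^\top \mathbf{K} \boldsymbol{\beta}}{2\tau^2}\Big), \qquad \tau^2 \sim \mathrm{IG}(a_\tau, b_\tau), \\
\tau^2 \mid \boldsymbol{\beta} &\sim \mathrm{IG}\Big(a_\tau + \frac{\operatorname{rk}(\mathbf{K})}{2},\; b_\tau + \frac{\boldsymbol{\beta}^\top \mathbf{K} \boldsymbol{\beta}}{2}\Big).
\end{aligned}
\]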

tb = gam.TermBuilder.from_df(mcycle, default_inference=gs.MCMCSpec(gs.IWLSKernel))

loc = gam.AdditivePredictor("loc")
scale = gam.AdditivePredictor("scale", inv_link=jnp.exp)

loc_smooth = tb.ps("times", k=20, prefix="loc.")
scale_smooth = tb.ps("times", k=20, prefix="scale.")

loc += loc_smooth
scale += scale_smooth

response_dist = lsl.Dist(tfd.Normal, loc=loc, scale=scale)
y = lsl.Var.new_obs(mcycle["accel"], response_dist, name="y")
model = lsl.Model(y)
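To make the P-spline terms more concrete, here is a minimal NumPy sketch of the two ingredients behind such a term: a cubic B-spline basis with 20 functions and a second-order difference penalty. It is an illustration under assumptions, not the code that tb.ps runs: equidistant knots, cubic degree and a second-order penalty are the classic P-spline choices and may differ from Liesel-GAM's settings, and the helpers bspline_basis and difference_penalty are hypothetical. Note also that the summaries below report 19 coefficients per spline rather than 20, which suggests that tb.ps removes one degree of freedom, for example through an identifiability constraint; the sketch does not reproduce that step.

```python
import numpy as np


def bspline_basis(x: np.ndarray, n_bases: int = 20, degree: int = 3) -> np.ndarray:
    """Evaluate n_bases B-splines of the given degree on equidistant knots (hypothetical helper)."""
    m = n_bases - degree  # number of knot intervals covering [min(x), max(x)]
    h = (x.max() - x.min()) / m
    # extend the knot grid by `degree` knots on each side: n_bases + degree + 1 knots
    knots = x.min() + h * np.arange(-degree, m + degree + 1)
    # degree 0: indicators of the half-open intervals [t_j, t_{j+1})
    B = ((x[:, None] >= knots[:-1]) & (x[:, None] < knots[1:])).astype(float)
    for k in range(1, degree + 1):
        # Cox-de Boor recursion; with equidistant knots both denominators equal k * h
        left = (x[:, None] - knots[: -(k + 1)]) / (k * h)
        right = (knots[k + 1 :] - x[:, None]) / (k * h)
        B = left * B[:, :-1] + right * B[:, 1:]
    return B  # shape (len(x), n_bases)


def difference_penalty(n_bases: int = 20, order: int = 2) -> np.ndarray:
    """Penalty matrix K = D^T D built from a difference matrix D of the given order."""
    D = np.diff(np.eye(n_bases), n=order, axis=0)
    return D.T @ D


x = np.linspace(0.0, 60.0, 200)  # synthetic stand-in for the observed times
B = bspline_basis(x)
K = difference_penalty()
```

With a second-order penalty, K does not penalize constant or linear trends in the coefficients, so its rank is two less than the number of basis functions.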

Metropolis-in-Gibbs

First, we run the model with the inference specifications attached during model construction. This gives a Metropolis-in-Gibbs sampling scheme with IWLS kernels for the regression coefficients (\(\boldsymbol{\beta}\)) and Gibbs kernels for the smoothing variances (\(\tau^2\)) of the splines.

iwls_results = gs.LieselMCMC(model).run_for_epochs(
    seed=1, num_chains=4, adaptation=1000, posterior=10000, posterior_thinning=10, show_progress=False
)
liesel.goose.builder - WARNING - No jitter functions provided for position keys '$\\beta_{loc.ps(times)}$', '$\\beta_{scale.ps(times)}$', '$\\tau_{loc.ps(times)}^2$', '$\\beta_{0,loc}$', '$\\tau_{scale.ps(times)}^2$', '$\\beta_{0,scale}$'. The initial values for these keys won't be jittered
liesel.goose.engine - INFO - Initializing kernels...
liesel.goose.engine - INFO - Done
liesel.goose.engine - INFO - Finished warmup

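The sampler leaves room for improvement, especially for the intercept of the mean: its chains are strongly autocorrelated, so the 4000 posterior draws correspond to a bulk effective sample size (ESS) of only about 120, with an R-hat of 1.028. The smoothing variance of the scale spline also mixes slowly (bulk ESS of about 210).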

gs.Summary(iwls_results)

Parameter summary:

parameter index kernel mean sd q_0.05 q_0.5 q_0.95 sample_size ess_bulk ess_tail rhat
$\beta_{0,loc}$ () kernel_03 -25.208 1.965 -28.423 -25.268 -21.903 4000 120.032 385.818 1.028
$\beta_{0,scale}$ () kernel_05 2.724 0.074 2.604 2.723 2.850 4000 949.963 2440.940 1.006
$\beta_{loc.ps(times)}$ (0,) kernel_00 -2.524 10.793 -20.055 -2.508 15.091 4000 3090.440 3276.193 1.000
(1,) kernel_00 -11.400 10.177 -29.560 -10.567 4.041 4000 2135.275 3303.146 1.001
(2,) kernel_00 1.511 9.359 -13.859 1.501 17.137 4000 2723.351 3657.486 1.000
(3,) kernel_00 -3.298 9.018 -17.980 -3.070 11.133 4000 3003.180 3224.344 1.001
(4,) kernel_00 -9.468 8.766 -24.269 -9.294 4.662 4000 2946.688 3561.010 1.002
(5,) kernel_00 -10.234 8.355 -24.368 -9.992 3.258 4000 3195.905 3687.709 1.001
(6,) kernel_00 -0.227 7.849 -13.745 -0.047 12.472 4000 3188.969 3367.629 1.002
(7,) kernel_00 0.419 7.437 -11.654 0.466 12.509 4000 3357.348 3797.740 1.000
(8,) kernel_00 10.635 6.672 -0.471 10.726 21.610 4000 3461.824 3833.617 1.001
(9,) kernel_00 -15.581 5.801 -25.258 -15.562 -6.206 4000 2838.619 3115.745 1.001
(10,) kernel_00 -7.541 4.845 -15.574 -7.503 0.180 4000 2585.065 3339.469 1.001
(11,) kernel_00 -23.790 4.381 -30.970 -23.769 -16.681 4000 3224.815 3611.463 1.000
(12,) kernel_00 9.344 3.253 4.176 9.314 14.688 4000 3408.073 3576.336 1.001
(13,) kernel_00 -10.210 2.572 -14.541 -10.185 -6.134 4000 3188.634 3555.224 1.001
(14,) kernel_00 12.244 1.879 9.131 12.266 15.183 4000 3281.223 3167.412 1.000
(15,) kernel_00 2.256 1.234 0.241 2.267 4.175 4000 2110.177 2477.466 1.002
(16,) kernel_00 -3.140 0.642 -4.201 -3.127 -2.140 4000 3402.169 3136.336 1.001
(17,) kernel_00 0.913 0.241 0.514 0.918 1.291 4000 1220.134 3013.988 1.002
(18,) kernel_00 2.982 0.922 1.511 3.001 4.396 4000 3292.914 2480.931 1.001
$\beta_{scale.ps(times)}$ (0,) kernel_01 0.005 0.141 -0.217 0.004 0.237 4000 1005.204 1235.971 1.005
(1,) kernel_01 0.026 0.147 -0.195 0.017 0.264 4000 978.142 1155.381 1.004
(2,) kernel_01 0.004 0.138 -0.221 -0.001 0.225 4000 909.014 1177.690 1.005
(3,) kernel_01 0.030 0.144 -0.189 0.022 0.269 4000 943.784 855.745 1.008
(4,) kernel_01 0.031 0.143 -0.191 0.023 0.272 4000 972.728 1371.227 1.005
(5,) kernel_01 -0.037 0.142 -0.277 -0.030 0.176 4000 775.126 997.943 1.008
(6,) kernel_01 0.058 0.145 -0.147 0.043 0.327 4000 711.239 1034.814 1.009
(7,) kernel_01 0.011 0.129 -0.202 0.008 0.225 4000 905.230 1052.320 1.007
(8,) kernel_01 0.102 0.143 -0.092 0.080 0.368 4000 401.649 851.800 1.009
(9,) kernel_01 -0.092 0.128 -0.326 -0.079 0.093 4000 743.720 1043.326 1.009
(10,) kernel_01 0.115 0.117 -0.055 0.107 0.317 4000 626.941 1213.472 1.007
(11,) kernel_01 0.010 0.112 -0.175 0.013 0.189 4000 818.491 1079.558 1.002
(12,) kernel_01 0.199 0.130 0.024 0.180 0.433 4000 309.252 739.200 1.017
(13,) kernel_01 0.136 0.101 -0.015 0.129 0.311 4000 502.829 1151.172 1.005
(14,) kernel_01 -0.079 0.086 -0.231 -0.068 0.043 4000 378.195 797.000 1.011
(15,) kernel_01 0.042 0.056 -0.041 0.038 0.140 4000 579.289 1196.250 1.008
(16,) kernel_01 0.026 0.034 -0.035 0.029 0.074 4000 392.207 889.776 1.013
(17,) kernel_01 -0.063 0.013 -0.082 -0.064 -0.042 4000 624.398 1136.041 1.005
(18,) kernel_01 0.111 0.050 0.027 0.111 0.191 4000 518.138 983.889 1.010
$\tau_{loc.ps(times)}^2$ () kernel_02 138.693 68.000 62.523 122.988 270.299 4000 2808.853 3448.898 1.001
$\tau_{scale.ps(times)}^2$ () kernel_04 0.021 0.020 0.003 0.015 0.060 4000 209.585 560.265 1.022

Acceptance probabilities:

kernel positions phase acceptance_probability position_moved
kernel_00 $\beta_{loc.ps(times)}$ posterior 0.865 0.865
warmup 0.794 0.796
kernel_01 $\beta_{scale.ps(times)}$ posterior 0.868 0.868
warmup 0.793 0.795
kernel_02 $\tau_{loc.ps(times)}^2$ posterior 1.000 1.000
warmup 1.000 1.000
kernel_03 $\beta_{0,loc}$ posterior 0.922 0.922
warmup 0.922 0.919
kernel_04 $\tau_{scale.ps(times)}^2$ posterior 1.000 1.000
warmup 1.000 1.000
kernel_05 $\beta_{0,scale}$ posterior 0.906 0.905
warmup 0.912 0.912
gs.plot_trace(iwls_results)

[Figure: trace plots of the IWLS/Gibbs chains]
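Strong autocorrelation, as in the intercept chain, directly lowers the effective sample size. As a rough illustration, the NumPy sketch below computes the empirical autocorrelation of a single chain and the simple estimate \(n / (1 + 2\sum_k \rho_k)\), truncating the sum at the first non-positive autocorrelation. This is a textbook estimator and not necessarily the one behind the ess_bulk column of gs.Summary; the helper names are hypothetical. We apply it to a simulated AR(1) chain with coefficient 0.95 as a stand-in for a sticky MCMC chain; its theoretical ESS per draw is \((1 - 0.95)/(1 + 0.95) \approx 0.026\).

```python
import numpy as np


def autocorrelation(chain: np.ndarray, max_lag: int) -> np.ndarray:
    """Empirical autocorrelation of a 1-D chain for lags 0..max_lag."""
    x = chain - chain.mean()
    var = x @ x / len(x)
    return np.array([x[: len(x) - k] @ x[k:] / len(x) / var for k in range(max_lag + 1)])


def crude_ess(chain: np.ndarray, max_lag: int = 1000) -> float:
    """n / (1 + 2 * sum of autocorrelations), truncated at the first non-positive lag."""
    rho = autocorrelation(chain, max_lag)
    total = 0.0
    for r in rho[1:]:
        if r <= 0:
            break
        total += r
    return len(chain) / (1 + 2 * total)


# AR(1) chain x_t = phi * x_{t-1} + e_t, started in its stationary distribution
rng = np.random.default_rng(1)
phi, n = 0.95, 20_000
chain = np.empty(n)
chain[0] = rng.normal() / np.sqrt(1 - phi**2)
for t in range(1, n):
    chain[t] = phi * chain[t - 1] + rng.normal()

rho = autocorrelation(chain, max_lag=5)
ess = crude_ess(chain)  # roughly n * 0.026, i.e. a few hundred effective draws
```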

To check that the fitted curve is plausible, we plot the posterior mean of the location predictor together with a 90% credible interval:

def plot_loc_estimate(results, model, title):
    samples = results.get_posterior_samples()
    loc_samples = model.vars["loc"].predict(samples)
    loc_summary = gs.SamplesSummary.from_array(
        loc_samples,
        name="loc",
        which=["mean", "quantiles"],
    )
    loc_summary_df = loc_summary.to_dataframe().reset_index()

    loc_summary_df["times"] = mcycle["times"].to_numpy()
    plot_data = (
        loc_summary_df[["times", "mean", "q_0.05", "q_0.95"]]
        .groupby("times", as_index=False)
        .mean()
        .sort_values("times")
    )

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.fill_between(
        plot_data["times"],
        plot_data["q_0.05"],
        plot_data["q_0.95"],
        color=sns.color_palette()[1],
        alpha=0.25,
        label="90% credible interval",
    )
    sns.lineplot(
        data=plot_data,
        x="times",
        y="mean",
        color=sns.color_palette()[1],
        linewidth=2,
        label="posterior mean",
        ax=ax,
    )
    sns.scatterplot(
        data=mcycle,
        x="times",
        y="accel",
        color="0.25",
        s=25,
        ax=ax,
        label="observed data",
    )
    ax.set(xlabel="time after impact", ylabel="acceleration", title=title)
    plt.show()


plot_loc_estimate(iwls_results, model, "Estimated mean function (IWLS/Gibbs)")

[Figure: Estimated mean function (IWLS/Gibbs), posterior mean with 90% credible interval and observed data]
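The helper above leaves the summary of the predicted samples to gs.SamplesSummary. For reference, the same posterior mean and 90% equal-tailed credible interval can be computed with plain NumPy. This sketch assumes that the predicted location samples form an array of shape (chains, draws, n); since it does not rebuild the model, a random array stands in for the real samples.

```python
import numpy as np

# random stand-in for posterior samples of the location predictor:
# 4 chains, 1000 draws, 50 observations (assumed layout: chains, draws, n)
rng = np.random.default_rng(2)
loc_samples = rng.normal(size=(4, 1000, 50))

# pool the chains, then summarize each observation across all draws
pooled = loc_samples.reshape(-1, loc_samples.shape[-1])  # shape (4000, 50)
post_mean = pooled.mean(axis=0)
# the 5% and 95% quantiles bound the 90% equal-tailed credible interval
q05, q95 = np.quantile(pooled, [0.05, 0.95], axis=0)
```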

NUTS sampler

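As an alternative, we sample all parameters with NUTS. The helper below copies the model, reparameterizes each smoothing variance as \(\tau^2 = \exp(h)\) with an exponential bijector so that the sampler works with the unconstrained \(h = \log \tau^2\), and assigns one kernel group per additive term, so that the coefficients and the log-variance of a spline are sampled jointly. Each intercept receives a kernel of its own. The kernel class is passed as an argument; here we use gs.NUTSKernel.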

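def strategy_term_blocked(
    model: lsl.Model, predictors: list[str], kernel_constructor, **kwargs
):
    # work on a copy so that the IWLS/Gibbs model above stays unchanged
    model = model.copy()

    # sample each smoothing variance on the log scale, tau^2 = exp(h);
    # the Gibbs spec of the original tau^2 is not carried over (inference="drop")
    for k, v in model.parameters.items():
        if "tau" in k:
            v.biject(tfb.Exp(), inference="drop")

    for predictor_name in predictors:
        predictor = model.vars[predictor_name]
        # each intercept gets its own kernel (no kernel_group)
        if predictor.intercept:
            predictor.intercept.inference = gs.MCMCSpec(
                kernel_constructor, kernel_kwargs=kwargs
            )

        # all parameters feeding into a term (coefficients and log-variance)
        # share one kernel, grouped under the term's name
        for term in predictor.terms.values():
            for param in model.parental_submodel(term).parameters.values():
                model.parameters[param.name].inference = gs.MCMCSpec(
                    kernel_constructor, kernel_group=term.name, kernel_kwargs=kwargs
                )

    return model


nuts_model = strategy_term_blocked(model, ["loc", "scale"], gs.NUTSKernel)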

The resulting model contains transformed smoothing variances on the unconstrained log scale. Here is the transformed model graph:

nuts_model.plot()

[Figure: transformed model graph]
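Moving to the log scale changes the density that the sampler works with. If \(\tau^2 = \exp(h)\), the density of \(h\) picks up the log-Jacobian \(\log|\mathrm{d}\exp(h)/\mathrm{d}h| = h\), so \(\log p(h) = \log p_{\tau^2}(\exp(h)) + h\); for TFP's exponential bijector, tfb.Exp().forward_log_det_jacobian(h, event_ndims=0) returns exactly this term. The NumPy sketch below checks numerically that the transformed density still integrates to one, using an inverse gamma density as an example. The hyperparameters are arbitrary and not Liesel-GAM's defaults, and the helper names are hypothetical.

```python
import numpy as np
from math import lgamma


def inv_gamma_logpdf(x, a, b):
    """Log-density of the inverse gamma distribution IG(a, b) at x > 0."""
    return a * np.log(b) - lgamma(a) - (a + 1) * np.log(x) - b / x


def log_density_h(h, a, b):
    """Log-density of h = log(tau2) for tau2 ~ IG(a, b), via tau2 = exp(h)."""
    return inv_gamma_logpdf(np.exp(h), a, b) + h  # + log-Jacobian of exp


a, b = 2.0, 1.0  # arbitrary example values
h = np.linspace(-15.0, 15.0, 200_001)
dh = h[1] - h[0]

total = np.exp(log_density_h(h, a, b)).sum() * dh  # Riemann sum, close to 1
# dropping the Jacobian term gives E[1 / tau2] = a / b instead of 1
without_jacobian = np.exp(inv_gamma_logpdf(np.exp(h), a, b)).sum() * dh
```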

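Now we can run the sampler, which is set up from the MCMCSpec objects stored in the model. In complex models like this one, sampling the parameters of each additive term in a separate NUTS block can be beneficial. We run 1000 posterior iterations per chain without thinning, which gives the same number of stored draws (4000) as the thinned IWLS/Gibbs run.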

nuts_results = gs.LieselMCMC(nuts_model).run_for_epochs(
    seed=1, num_chains=4, adaptation=1000, posterior=1000, show_progress=False
)
liesel.goose.builder - WARNING - No jitter functions provided for position keys '$\\beta_{loc.ps(times)}$', 'h($\\tau_{loc.ps(times)}^2$)', '$\\beta_{scale.ps(times)}$', 'h($\\tau_{scale.ps(times)}^2$)', '$\\beta_{0,loc}$', '$\\beta_{0,scale}$'. The initial values for these keys won't be jittered
liesel.goose.engine - INFO - Initializing kernels...
liesel.goose.engine - INFO - Done
liesel.goose.engine - INFO - Finished warmup

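The blocked NUTS strategy yields considerably higher effective sample sizes for the spline coefficients of the scale model (bulk ESS between roughly 900 and 5000, compared with roughly 300 to 1000 under IWLS) and for the (log) smoothing variance of the scale spline (about 390 versus about 210). It is not uniformly better, however: the location spline coefficients mostly have lower bulk ESS than under IWLS, the intercept of the mean mixes even worse (bulk ESS of about 16, R-hat 1.19), and the kernel for the location spline reports 26 divergent transitions after warmup.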

gs.Summary(nuts_results)

Parameter summary:

parameter index kernel mean sd q_0.05 q_0.5 q_0.95 sample_size ess_bulk ess_tail rhat
$\beta_{0,loc}$ () kernel_02 -25.059 2.030 -28.648 -24.870 -22.147 4000 15.520 32.866 1.187
$\beta_{0,scale}$ () kernel_03 2.723 0.075 2.605 2.720 2.851 4000 724.755 1476.686 1.010
$\beta_{loc.ps(times)}$ (0,) kernel_00 -2.691 10.873 -21.098 -2.417 14.901 4000 3172.592 2247.690 1.002
(1,) kernel_00 -11.200 10.141 -28.765 -10.610 4.500 4000 3030.145 2713.784 1.001
(2,) kernel_00 1.270 9.206 -14.030 1.372 16.822 4000 2862.872 2738.921 1.000
(3,) kernel_00 -3.303 9.319 -18.766 -3.385 11.490 4000 3138.678 2631.457 1.004
(4,) kernel_00 -9.211 9.045 -24.604 -9.011 5.188 4000 2791.808 2890.167 1.000
(5,) kernel_00 -10.299 8.634 -24.580 -10.163 3.489 4000 2653.144 2468.135 1.001
(6,) kernel_00 -0.358 7.721 -13.335 -0.316 12.184 4000 2836.645 2753.970 1.000
(7,) kernel_00 0.310 7.274 -11.583 0.183 12.060 4000 2184.684 2109.061 1.003
(8,) kernel_00 10.700 6.754 -0.415 10.721 21.872 4000 2166.268 2440.807 1.002
(9,) kernel_00 -15.616 5.868 -25.287 -15.578 -6.413 4000 1839.594 2270.726 1.002
(10,) kernel_00 -7.430 4.891 -15.548 -7.313 0.655 4000 1608.507 1894.589 1.003
(11,) kernel_00 -23.869 4.395 -31.197 -23.822 -16.657 4000 1557.058 2153.137 1.002
(12,) kernel_00 9.163 3.302 3.914 9.056 14.694 4000 1368.387 1572.625 1.002
(13,) kernel_00 -10.188 2.546 -14.477 -10.144 -6.160 4000 1493.362 1663.695 1.001
(14,) kernel_00 12.304 1.850 9.305 12.302 15.380 4000 1466.403 1652.025 1.001
(15,) kernel_00 2.301 1.207 0.335 2.313 4.246 4000 525.186 1066.279 1.011
(16,) kernel_00 -3.127 0.635 -4.149 -3.109 -2.140 4000 1243.732 1208.124 1.002
(17,) kernel_00 0.909 0.237 0.517 0.920 1.275 4000 358.475 1228.922 1.019
(18,) kernel_00 3.025 0.909 1.563 3.063 4.408 4000 1136.638 1239.738 1.003
$\beta_{scale.ps(times)}$ (0,) kernel_01 -0.003 0.136 -0.225 -0.002 0.214 4000 4443.099 2264.366 1.004
(1,) kernel_01 0.027 0.140 -0.187 0.017 0.265 4000 4525.058 2201.980 1.006
(2,) kernel_01 0.009 0.142 -0.216 0.007 0.237 4000 4704.091 2311.857 1.003
(3,) kernel_01 0.029 0.139 -0.182 0.021 0.260 4000 4996.028 2397.078 1.004
(4,) kernel_01 0.029 0.142 -0.185 0.021 0.255 4000 3931.107 1988.719 1.004
(5,) kernel_01 -0.031 0.144 -0.283 -0.025 0.193 4000 4520.747 2062.157 1.002
(6,) kernel_01 0.056 0.142 -0.150 0.043 0.302 4000 3017.528 1983.617 1.002
(7,) kernel_01 0.011 0.128 -0.193 0.010 0.223 4000 4661.542 2274.334 1.003
(8,) kernel_01 0.091 0.143 -0.106 0.073 0.346 4000 1973.721 2407.005 1.004
(9,) kernel_01 -0.089 0.130 -0.317 -0.078 0.102 4000 2292.575 2297.788 1.003
(10,) kernel_01 0.115 0.119 -0.059 0.106 0.326 4000 2682.038 2554.996 1.001
(11,) kernel_01 0.008 0.113 -0.176 0.012 0.182 4000 3215.945 2081.179 1.001
(12,) kernel_01 0.204 0.124 0.028 0.191 0.427 4000 887.550 1979.028 1.007
(13,) kernel_01 0.136 0.099 -0.009 0.128 0.311 4000 1220.217 2118.627 1.006
(14,) kernel_01 -0.083 0.085 -0.235 -0.075 0.042 4000 1018.994 1940.152 1.004
(15,) kernel_01 0.041 0.058 -0.047 0.036 0.147 4000 1154.360 1768.703 1.004
(16,) kernel_01 0.024 0.033 -0.032 0.026 0.076 4000 1211.201 1874.812 1.004
(17,) kernel_01 -0.063 0.013 -0.083 -0.064 -0.041 4000 1241.131 1596.560 1.003
(18,) kernel_01 0.109 0.050 0.032 0.107 0.193 4000 1358.989 1609.352 1.002
h($\tau_{loc.ps(times)}^2$) () kernel_00 4.832 0.427 4.156 4.819 5.555 4000 1907.703 2540.741 1.001
h($\tau_{scale.ps(times)}^2$) () kernel_01 -4.208 0.847 -5.668 -4.171 -2.850 4000 385.967 730.356 1.014

Acceptance probabilities:

kernel positions phase acceptance_probability position_moved
kernel_00 $\beta_{loc.ps(times)}$, h($\tau_{loc.ps(times)}^2$) posterior 0.882 NaN
warmup 0.793 NaN
kernel_01 $\beta_{scale.ps(times)}$, h($\tau_{scale.ps(times)}^2$) posterior 0.877 NaN
warmup 0.795 NaN
kernel_02 $\beta_{0,loc}$ posterior 0.858 NaN
warmup 0.792 NaN
kernel_03 $\beta_{0,scale}$ posterior 0.878 NaN
warmup 0.793 NaN

Error summary:

kernel positions error_code error_msg phase count sample_size sample_size_total relative
kernel_00 $\beta_{loc.ps(times)}$, h($\tau_{loc.ps(times)}^2$) 1 divergent transition warmup 369 4000 4000 0.092
posterior 26 4000 4000 0.007
2 maximum tree depth warmup 389 4000 4000 0.097
posterior 0 4000 4000 0.000
kernel_01 $\beta_{scale.ps(times)}$, h($\tau_{scale.ps(times)}^2$) 1 divergent transition warmup 277 4000 4000 0.069
posterior 0 4000 4000 0.000
2 maximum tree depth warmup 1 4000 4000 0.000
posterior 0 4000 4000 0.000
kernel_02 $\beta_{0,loc}$ 1 divergent transition warmup 59 4000 4000 0.015
posterior 0 4000 4000 0.000
kernel_03 $\beta_{0,scale}$ 1 divergent transition warmup 47 4000 4000 0.012
posterior 0 4000 4000 0.000
gs.plot_trace(nuts_results)

[Figure: trace plots of the NUTS chains]

Again, here is the posterior mean function with a 90% credible interval:

plot_loc_estimate(nuts_results, nuts_model, "Estimated mean function (NUTS)")

[Figure: Estimated mean function (NUTS), posterior mean with 90% credible interval and observed data]