
# Comparing samplers

In this tutorial, we compare two sampling schemes on the `mcycle`
dataset, using a Gaussian location-scale regression model with one
spline each for the mean and the standard deviation. The `mcycle` dataset is
a “data frame giving a series of measurements of head acceleration in a
simulated motorcycle accident, used to test crash helmets” (from the
help page). It contains the following two variables:

- `times`: time in milliseconds after impact
- `accel`: head acceleration in g

We start off in R by loading the dataset and setting up the model with
the `rliesel::liesel()` function.

``` r
library(MASS)
library(rliesel)
```

    Please make sure you are using a virtual or conda environment with Liesel installed, e.g. using `reticulate::use_virtualenv()` or `reticulate::use_condaenv()`. See `vignette("versions", "reticulate")`.

    After setting the environment, check if the installed versions of RLiesel and Liesel are compatible with `check_liesel_version()`.

``` r
data(mcycle)
with(mcycle, plot(times, accel))
```

![](04-mcycle_files/figure-commonmark/model-1.png)

``` r
model <- liesel(
  response = mcycle$accel,
  distribution = "Normal",
  predictors = list(
    loc = predictor(~ s(times)),
    scale = predictor(~ s(times), inverse_link = "Exp")
  ),
  data = mcycle
)
```

    Installed Liesel version 0.2.4 is compatible.

## Metropolis-in-Gibbs

First, we try a Metropolis-in-Gibbs sampling scheme with IWLS kernels
for the regression coefficients ($\boldsymbol{\beta}$) and Gibbs kernels
for the smoothing parameters ($\tau^2$) of the splines.

``` python
import liesel.model as lsl

model = r.model

builder = lsl.dist_reg_mcmc(model, seed=42, num_chains=4)
builder.set_duration(warmup_duration=5000, posterior_duration=1000)

engine = builder.build()
```

    liesel.goose.engine - INFO - Initializing kernels...
    liesel.goose.engine - INFO - Done
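
As a rough illustration of what a Metropolis-in-Gibbs scheme does, the following toy sketch alternates a Metropolis step for a mean parameter with a Gibbs draw for a variance parameter. It is not Liesel code: the model (i.i.d. normal data with a flat prior on `mu` and an inverse gamma prior on `sigma2`), the function names, and the hyperparameters are assumptions made for this example, and a plain random-walk proposal stands in for the IWLS proposal.

```python
import math
import random


def log_lik(mu, sigma2, y):
    """Gaussian log-likelihood up to an additive constant."""
    return -0.5 * len(y) * math.log(sigma2) - 0.5 * sum((v - mu) ** 2 for v in y) / sigma2


def metropolis_in_gibbs(y, num_draws, seed=0, proposal_sd=0.5, a=2.0, b=2.0):
    """Toy sampler: random-walk Metropolis for mu, Gibbs for sigma2."""
    rng = random.Random(seed)
    mu, sigma2 = 0.0, 1.0
    draws = []
    for _ in range(num_draws):
        # Metropolis step for mu (flat prior), holding sigma2 fixed
        proposal = mu + rng.gauss(0.0, proposal_sd)
        log_ratio = log_lik(proposal, sigma2, y) - log_lik(mu, sigma2, y)
        if math.log(1.0 - rng.random()) < log_ratio:
            mu = proposal
        # Gibbs step for sigma2: the full conditional is inverse gamma
        shape = a + 0.5 * len(y)
        rate = b + 0.5 * sum((v - mu) ** 2 for v in y)
        sigma2 = 1.0 / rng.gammavariate(shape, 1.0 / rate)
        draws.append((mu, sigma2))
    return draws


# simulated data, purely for illustration
data_rng = random.Random(1)
y = [data_rng.gauss(3.0, 2.0) for _ in range(50)]

draws = metropolis_in_gibbs(y, num_draws=3000)
mu_draws = [mu for mu, _ in draws[500:]]  # discard the first 500 draws
posterior_mean_mu = sum(mu_draws) / len(mu_draws)
```

The same blockwise pattern applies to the engine built above: each kernel updates its own block of parameters while the other blocks are held fixed.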

``` python
engine.sample_all_epochs()
```

    liesel.goose.engine - INFO - Starting epoch: FAST_ADAPTATION, 75 transitions, 25 jitted together
    liesel.goose.engine - WARNING - Errors per chain for kernel_01: 0, 1, 1, 0 / 75 transitions
    liesel.goose.engine - WARNING - Errors per chain for kernel_02: 1, 1, 1, 1 / 75 transitions
    liesel.goose.engine - INFO - Finished epoch
    liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 25 transitions, 25 jitted together
    liesel.goose.engine - WARNING - Errors per chain for kernel_01: 1, 1, 0, 0 / 25 transitions
    liesel.goose.engine - INFO - Finished epoch
    liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 50 transitions, 25 jitted together
    liesel.goose.engine - WARNING - Errors per chain for kernel_01: 0, 0, 1, 0 / 50 transitions
    liesel.goose.engine - INFO - Finished epoch
    liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 100 transitions, 25 jitted together
    liesel.goose.engine - WARNING - Errors per chain for kernel_01: 0, 0, 0, 1 / 100 transitions
    liesel.goose.engine - INFO - Finished epoch
    liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 200 transitions, 25 jitted together
    liesel.goose.engine - WARNING - Errors per chain for kernel_01: 0, 0, 0, 1 / 200 transitions
    liesel.goose.engine - INFO - Finished epoch
    liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 400 transitions, 25 jitted together
    liesel.goose.engine - WARNING - Errors per chain for kernel_01: 0, 1, 1, 1 / 400 transitions
    liesel.goose.engine - INFO - Finished epoch
    liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 800 transitions, 25 jitted together
    liesel.goose.engine - WARNING - Errors per chain for kernel_01: 1, 1, 1, 1 / 800 transitions
    liesel.goose.engine - INFO - Finished epoch
    liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 3300 transitions, 25 jitted together
    liesel.goose.engine - WARNING - Errors per chain for kernel_01: 1, 0, 1, 0 / 3300 transitions
    liesel.goose.engine - INFO - Finished epoch
    liesel.goose.engine - INFO - Starting epoch: FAST_ADAPTATION, 50 transitions, 25 jitted together
    liesel.goose.engine - WARNING - Errors per chain for kernel_01: 0, 1, 0, 0 / 50 transitions
    liesel.goose.engine - INFO - Finished epoch
    liesel.goose.engine - INFO - Finished warmup
    liesel.goose.engine - INFO - Starting epoch: POSTERIOR, 1000 transitions, 25 jitted together
    liesel.goose.engine - INFO - Finished epoch
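
The epoch lengths in the log above add up to the requested warmup duration of 5000: a fast adaptation epoch of 75 transitions, slow adaptation epochs that double in length starting from 25, and a final fast adaptation epoch of 50. The following sketch reproduces these numbers under the assumption of a Stan-style doubling schedule in which the last slow window absorbs the remaining transitions. The function `warmup_schedule` and its default values are hypothetical, inferred from the log only, and not taken from the Liesel source.

```python
def warmup_schedule(warmup_duration, init_fast=75, final_fast=50, base_window=25):
    """Split a warmup duration into fast and doubling slow adaptation epochs."""
    remaining = warmup_duration - init_fast - final_fast
    window = base_window
    slow_windows = []
    while remaining > 0:
        # If the next (doubled) window would no longer fit after this one,
        # let the current window absorb all remaining transitions.
        if remaining - window < 2 * window:
            window = remaining
        slow_windows.append(window)
        remaining -= window
        window *= 2
    return [init_fast] + slow_windows + [final_fast]


schedule = warmup_schedule(5000)
print(schedule)  # [75, 25, 50, 100, 200, 400, 800, 3300, 50]
```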

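Clearly, the performance of the sampler could be better, especially for
the intercept of the mean (`loc_p0_beta`). The corresponding chain
exhibits very strong autocorrelation: with 4000 posterior draws, the
summary below reports a bulk effective sample size of only about 13 and
an `rhat` of 1.246.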

``` python
import liesel.goose as gs

results = engine.get_results()
gs.Summary(results)
```

<p>
<strong>Parameter summary:</strong>
</p>
<table border="0" class="dataframe">
<thead>
<tr style="text-align: right;">
<th>
</th>
<th>
</th>
<th>
kernel
</th>
<th>
mean
</th>
<th>
sd
</th>
<th>
q_0.05
</th>
<th>
q_0.5
</th>
<th>
q_0.95
</th>
<th>
sample_size
</th>
<th>
ess_bulk
</th>
<th>
ess_tail
</th>
<th>
rhat
</th>
</tr>
<tr>
<th>
parameter
</th>
<th>
index
</th>
<th>
</th>
<th>
</th>
<th>
</th>
<th>
</th>
<th>
</th>
<th>
</th>
<th>
</th>
<th>
</th>
<th>
</th>
<th>
</th>
</tr>
</thead>
<tbody>
<tr>
<th rowspan="9" valign="top">
loc_np0_beta
</th>
<th>
(0,)
</th>
<td>
kernel_04
</td>
<td>
-123.925
</td>
<td>
252.403
</td>
<td>
-529.979
</td>
<td>
-136.899
</td>
<td>
310.361
</td>
<td>
4000
</td>
<td>
32.770
</td>
<td>
283.680
</td>
<td>
1.093
</td>
</tr>
<tr>
<th>
(1,)
</th>
<td>
kernel_04
</td>
<td>
-1443.933
</td>
<td>
242.040
</td>
<td>
-1857.528
</td>
<td>
-1438.059
</td>
<td>
-1052.284
</td>
<td>
4000
</td>
<td>
154.026
</td>
<td>
326.176
</td>
<td>
1.025
</td>
</tr>
<tr>
<th>
(2,)
</th>
<td>
kernel_04
</td>
<td>
-713.112
</td>
<td>
170.964
</td>
<td>
-996.168
</td>
<td>
-709.703
</td>
<td>
-427.398
</td>
<td>
4000
</td>
<td>
157.377
</td>
<td>
733.196
</td>
<td>
1.027
</td>
</tr>
<tr>
<th>
(3,)
</th>
<td>
kernel_04
</td>
<td>
-566.761
</td>
<td>
106.593
</td>
<td>
-742.514
</td>
<td>
-566.403
</td>
<td>
-392.259
</td>
<td>
4000
</td>
<td>
345.613
</td>
<td>
699.904
</td>
<td>
1.021
</td>
</tr>
<tr>
<th>
(4,)
</th>
<td>
kernel_04
</td>
<td>
-1127.952
</td>
<td>
93.165
</td>
<td>
-1283.961
</td>
<td>
-1125.748
</td>
<td>
-978.007
</td>
<td>
4000
</td>
<td>
287.206
</td>
<td>
523.644
</td>
<td>
1.017
</td>
</tr>
<tr>
<th>
(5,)
</th>
<td>
kernel_04
</td>
<td>
-59.236
</td>
<td>
32.689
</td>
<td>
-112.265
</td>
<td>
-59.756
</td>
<td>
-5.968
</td>
<td>
4000
</td>
<td>
174.249
</td>
<td>
355.672
</td>
<td>
1.018
</td>
</tr>
<tr>
<th>
(6,)
</th>
<td>
kernel_04
</td>
<td>
-213.544
</td>
<td>
20.962
</td>
<td>
-246.796
</td>
<td>
-214.359
</td>
<td>
-179.102
</td>
<td>
4000
</td>
<td>
143.348
</td>
<td>
633.100
</td>
<td>
1.038
</td>
</tr>
<tr>
<th>
(7,)
</th>
<td>
kernel_04
</td>
<td>
115.530
</td>
<td>
71.315
</td>
<td>
12.799
</td>
<td>
108.984
</td>
<td>
239.090
</td>
<td>
4000
</td>
<td>
154.495
</td>
<td>
199.531
</td>
<td>
1.024
</td>
</tr>
<tr>
<th>
(8,)
</th>
<td>
kernel_04
</td>
<td>
30.199
</td>
<td>
18.463
</td>
<td>
3.722
</td>
<td>
28.654
</td>
<td>
61.787
</td>
<td>
4000
</td>
<td>
143.278
</td>
<td>
197.396
</td>
<td>
1.036
</td>
</tr>
<tr>
<th>
loc_np0_tau2
</th>
<th>
()
</th>
<td>
kernel_03
</td>
<td>
735262.000
</td>
<td>
559525.062
</td>
<td>
253877.688
</td>
<td>
592855.656
</td>
<td>
1666969.462
</td>
<td>
4000
</td>
<td>
1352.310
</td>
<td>
2737.206
</td>
<td>
1.003
</td>
</tr>
<tr>
<th>
loc_p0_beta
</th>
<th>
(0,)
</th>
<td>
kernel_05
</td>
<td>
-23.949
</td>
<td>
1.557
</td>
<td>
-26.545
</td>
<td>
-23.892
</td>
<td>
-21.367
</td>
<td>
4000
</td>
<td>
13.121
</td>
<td>
20.016
</td>
<td>
1.246
</td>
</tr>
<tr>
<th rowspan="9" valign="top">
scale_np0_beta
</th>
<th>
(0,)
</th>
<td>
kernel_01
</td>
<td>
8.056
</td>
<td>
10.744
</td>
<td>
-6.509
</td>
<td>
5.540
</td>
<td>
27.736
</td>
<td>
4000
</td>
<td>
37.591
</td>
<td>
150.933
</td>
<td>
1.103
</td>
</tr>
<tr>
<th>
(1,)
</th>
<td>
kernel_01
</td>
<td>
-1.407
</td>
<td>
5.931
</td>
<td>
-11.598
</td>
<td>
-1.060
</td>
<td>
8.103
</td>
<td>
4000
</td>
<td>
210.887
</td>
<td>
401.406
</td>
<td>
1.014
</td>
</tr>
<tr>
<th>
(2,)
</th>
<td>
kernel_01
</td>
<td>
-17.430
</td>
<td>
9.659
</td>
<td>
-33.542
</td>
<td>
-17.559
</td>
<td>
-2.789
</td>
<td>
4000
</td>
<td>
28.120
</td>
<td>
129.904
</td>
<td>
1.123
</td>
</tr>
<tr>
<th>
(3,)
</th>
<td>
kernel_01
</td>
<td>
10.911
</td>
<td>
5.240
</td>
<td>
2.846
</td>
<td>
10.915
</td>
<td>
19.429
</td>
<td>
4000
</td>
<td>
39.408
</td>
<td>
147.850
</td>
<td>
1.088
</td>
</tr>
<tr>
<th>
(4,)
</th>
<td>
kernel_01
</td>
<td>
2.692
</td>
<td>
4.240
</td>
<td>
-4.188
</td>
<td>
2.627
</td>
<td>
9.817
</td>
<td>
4000
</td>
<td>
56.953
</td>
<td>
294.239
</td>
<td>
1.073
</td>
</tr>
<tr>
<th>
(5,)
</th>
<td>
kernel_01
</td>
<td>
4.076
</td>
<td>
1.937
</td>
<td>
1.095
</td>
<td>
3.922
</td>
<td>
7.346
</td>
<td>
4000
</td>
<td>
24.444
</td>
<td>
117.743
</td>
<td>
1.118
</td>
</tr>
<tr>
<th>
(6,)
</th>
<td>
kernel_01
</td>
<td>
-0.588
</td>
<td>
3.199
</td>
<td>
-6.105
</td>
<td>
-0.386
</td>
<td>
4.074
</td>
<td>
4000
</td>
<td>
21.601
</td>
<td>
122.160
</td>
<td>
1.138
</td>
</tr>
<tr>
<th>
(7,)
</th>
<td>
kernel_01
</td>
<td>
-0.311
</td>
<td>
3.659
</td>
<td>
-5.933
</td>
<td>
-0.648
</td>
<td>
6.267
</td>
<td>
4000
</td>
<td>
92.055
</td>
<td>
165.632
</td>
<td>
1.033
</td>
</tr>
<tr>
<th>
(8,)
</th>
<td>
kernel_01
</td>
<td>
-1.197
</td>
<td>
1.925
</td>
<td>
-4.500
</td>
<td>
-1.071
</td>
<td>
1.700
</td>
<td>
4000
</td>
<td>
26.872
</td>
<td>
131.637
</td>
<td>
1.122
</td>
</tr>
<tr>
<th>
scale_np0_tau2
</th>
<th>
()
</th>
<td>
kernel_00
</td>
<td>
137.908
</td>
<td>
162.827
</td>
<td>
10.762
</td>
<td>
85.152
</td>
<td>
428.775
</td>
<td>
4000
</td>
<td>
27.860
</td>
<td>
179.005
</td>
<td>
1.113
</td>
</tr>
<tr>
<th>
scale_p0_beta
</th>
<th>
(0,)
</th>
<td>
kernel_02
</td>
<td>
2.763
</td>
<td>
0.070
</td>
<td>
2.653
</td>
<td>
2.762
</td>
<td>
2.879
</td>
<td>
4000
</td>
<td>
163.264
</td>
<td>
926.926
</td>
<td>
1.024
</td>
</tr>
</tbody>
</table>
<p>
<strong>Error summary:</strong>
</p>
<table border="0" class="dataframe">
<thead>
<tr style="text-align: right;">
<th>
</th>
<th>
</th>
<th>
</th>
<th>
</th>
<th>
count
</th>
<th>
relative
</th>
</tr>
<tr>
<th>
kernel
</th>
<th>
error_code
</th>
<th>
error_msg
</th>
<th>
phase
</th>
<th>
</th>
<th>
</th>
</tr>
</thead>
<tbody>
<tr>
<th rowspan="2" valign="top">
kernel_01
</th>
<th rowspan="2" valign="top">
90
</th>
<th rowspan="2" valign="top">
nan acceptance prob
</th>
<th>
warmup
</th>
<td>
17
</td>
<td>
0.001
</td>
</tr>
<tr>
<th>
posterior
</th>
<td>
0
</td>
<td>
0.000
</td>
</tr>
<tr>
<th rowspan="2" valign="top">
kernel_02
</th>
<th rowspan="2" valign="top">
90
</th>
<th rowspan="2" valign="top">
nan acceptance prob
</th>
<th>
warmup
</th>
<td>
4
</td>
<td>
0.000
</td>
</tr>
<tr>
<th>
posterior
</th>
<td>
0
</td>
<td>
0.000
</td>
</tr>
</tbody>
</table>
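
The `rhat` column compares the variation between chains with the variation within chains, so values well above 1 indicate that the chains disagree. As a sketch of the idea, the following code computes the classic Gelman-Rubin statistic on simulated chains. This is an assumption-laden simplification: the summary above presumably uses a more refined variant (for example with split chains and rank normalization), and the function `gelman_rubin` is written for this example only.

```python
import random


def gelman_rubin(chains):
    """Classic potential scale reduction factor for equally long chains."""
    m = len(chains)
    n = len(chains[0])
    means = [sum(chain) / n for chain in chains]
    grand_mean = sum(means) / m
    # between-chain variance and average within-chain variance
    between = n / (m - 1) * sum((mean - grand_mean) ** 2 for mean in means)
    within = sum(
        sum((x - mean) ** 2 for x in chain) / (n - 1)
        for chain, mean in zip(chains, means)
    ) / m
    pooled = (n - 1) / n * within + between / n
    return (pooled / within) ** 0.5


rng = random.Random(0)
# four chains exploring the same distribution
mixed = [[rng.gauss(0.0, 1.0) for _ in range(1000)] for _ in range(4)]
# four chains stuck around two different locations
stuck = [[rng.gauss(shift, 1.0) for _ in range(1000)] for shift in (0.0, 0.0, 3.0, 3.0)]

rhat_mixed = gelman_rubin(mixed)  # close to 1
rhat_stuck = gelman_rubin(stuck)  # clearly above 1
```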

``` python
fig = gs.plot_trace(results, "loc_p0_beta")
```

    /opt/hostedtoolcache/Python/3.10.12/x64/lib/python3.10/site-packages/seaborn/axisgrid.py:118: UserWarning: The figure layout has changed to tight
      self._figure.tight_layout(*args, **kwargs)

![](04-mcycle_files/figure-commonmark/iwls-traces-1.png)
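
To see why strong autocorrelation reduces the effective sample size, consider a chain that behaves like an AR(1) process with lag-1 autocorrelation $\rho$. In that special case, the effective sample size is approximately $N (1 - \rho) / (1 + \rho)$. The sketch below simulates such a chain and applies this formula. The AR(1) assumption and the helper names are illustrative and do not describe how the `ess_bulk` column is computed.

```python
import random


def lag1_autocorrelation(chain):
    """Empirical lag-1 autocorrelation of a sequence."""
    n = len(chain)
    mean = sum(chain) / n
    var = sum((x - mean) ** 2 for x in chain)
    cov = sum((chain[i] - mean) * (chain[i + 1] - mean) for i in range(n - 1))
    return cov / var


def ar1_ess(chain):
    """Effective sample size under an AR(1) approximation."""
    rho = lag1_autocorrelation(chain)
    return len(chain) * (1.0 - rho) / (1.0 + rho)


rng = random.Random(0)

# strongly autocorrelated chain with true rho = 0.95
sticky = [0.0]
for _ in range(3999):
    sticky.append(0.95 * sticky[-1] + rng.gauss(0.0, 1.0))

# independent draws for comparison
independent = [rng.gauss(0.0, 1.0) for _ in range(4000)]

ess_sticky = ar1_ess(sticky)            # a small fraction of 4000
ess_independent = ar1_ess(independent)  # roughly 4000
```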

``` python
fig = gs.plot_trace(results, "loc_np0_tau2")
```

![](04-mcycle_files/figure-commonmark/iwls-traces-2.png)

``` python
fig = gs.plot_trace(results, "loc_np0_beta")
```

![](04-mcycle_files/figure-commonmark/iwls-traces-3.png)

``` python
fig = gs.plot_trace(results, "scale_p0_beta")
```

![](04-mcycle_files/figure-commonmark/iwls-traces-4.png)

``` python
fig = gs.plot_trace(results, "scale_np0_tau2")
```

![](04-mcycle_files/figure-commonmark/iwls-traces-5.png)

``` python
fig = gs.plot_trace(results, "scale_np0_beta")
```

![](04-mcycle_files/figure-commonmark/iwls-traces-6.png)

As a rough check that the chains have settled on reasonable values, here
is a plot of the estimated mean function:

``` python
summary = gs.Summary(results).to_dataframe().reset_index()
```

``` r
library(dplyr)
```


    Attaching package: 'dplyr'

    The following object is masked from 'package:MASS':

        select

    The following objects are masked from 'package:stats':

        filter, lag

    The following objects are masked from 'package:base':

        intersect, setdiff, setequal, union

``` r
library(ggplot2)
library(reticulate)

summary <- py$summary

beta <- summary %>%
  filter(variable == "loc_np0_beta") %>%
  group_by(var_index) %>%
  summarize(mean = mean(mean)) %>%
  ungroup()

beta <- beta$mean
X <- py_to_r(model$vars["loc_np0_X"]$value)
f <- X %*% beta

beta0 <- summary %>%
  filter(variable == "loc_p0_beta") %>%
  group_by(var_index) %>%
  summarize(mean = mean(mean)) %>%
  ungroup()

beta0 <- beta0$mean

ggplot(data.frame(times = mcycle$times, mean = beta0 + f)) +
  geom_line(aes(times, mean), color = palette()[2], size = 1) +
  geom_point(aes(times, accel), data = mcycle) +
  ggtitle("Estimated mean function") +
  theme_minimal()
```

![](04-mcycle_files/figure-commonmark/iwls-spline-13.png)
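
The R code above evaluates the estimated mean function as the intercept plus the spline design matrix times the posterior mean coefficients, $\hat{f} = \hat{\beta}_0 + X \hat{\boldsymbol{\beta}}$. The following sketch shows the same computation in plain Python. The matrix and coefficients are made-up numbers for illustration, not values from the fitted model.

```python
def mean_function(X, beta, beta0):
    """Return beta0 + X @ beta for a design matrix given as a list of rows."""
    return [beta0 + sum(x_ij * b_j for x_ij, b_j in zip(row, beta)) for row in X]


# illustrative values only
X = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]
beta = [2.0, -4.0]
beta0 = -1.0

f = mean_function(X, beta, beta0)
print(f)  # [1.0, -2.0, -5.0]
```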

## NUTS sampler

As an alternative, we try a NUTS kernel which samples all model
parameters (regression coefficients and smoothing parameters) in one
block. NUTS operates on an unconstrained parameter space, but the
smoothing parameters are restricted to positive values, so we first need
to log-transform them. This is the model graph before the
transformation:

``` python
lsl.plot_vars(model)
```

![](04-mcycle_files/figure-commonmark/untransformed-graph-1.png)

Before transforming the smoothing parameters with the `transform()`
method of `lsl.GraphBuilder`, we first need to retrieve all nodes and
variables from the model with `pop_nodes_and_vars()`. We then add the
response variable to a new graph builder, transform the two smoothing
parameters with the `tfb.Exp` bijector, and rebuild the model. There are
two additional nodes in the new model graph.

``` python
import tensorflow_probability.substrates.jax.bijectors as tfb

nodes, _vars = model.pop_nodes_and_vars()

gb = lsl.GraphBuilder()
gb.add(_vars["response"])
```

    GraphBuilder(0 nodes, 1 vars)

``` python
_ = gb.transform(_vars["loc_np0_tau2"], tfb.Exp)
_ = gb.transform(_vars["scale_np0_tau2"], tfb.Exp)
model = gb.build_model()
lsl.plot_vars(model)
```

![](04-mcycle_files/figure-commonmark/transformed-graph-3.png)
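
Sampling $\eta = \log \tau^2$ instead of $\tau^2$ changes the density by a Jacobian term: $\log p_\eta(\eta) = \log p_{\tau^2}(e^\eta) + \eta$. The sketch below checks this numerically for an inverse gamma prior by confirming that the transformed density still integrates to one. The inverse gamma prior with hyperparameters `a = 2` and `b = 1` is an assumption chosen for numerical convenience, and the code only illustrates the math without using Liesel.

```python
import math


def inv_gamma_logpdf(tau2, a, b):
    """Log-density of an inverse gamma distribution with shape a and scale b."""
    return a * math.log(b) - math.lgamma(a) - (a + 1.0) * math.log(tau2) - b / tau2


def log_tau2_logpdf(eta, a, b):
    """Log-density of eta = log(tau2), including the log-Jacobian term eta."""
    return inv_gamma_logpdf(math.exp(eta), a, b) + eta


a, b = 2.0, 1.0  # illustrative hyperparameters
step = 0.01
grid = [-20.0 + step * i for i in range(4001)]  # eta from -20 to 20

# Riemann sum of the transformed density over the unconstrained scale
total = sum(math.exp(log_tau2_logpdf(eta, a, b)) * step for eta in grid)
```

Without the `+ eta` term, the sum would not be close to one, which is why the Jacobian has to be part of the log-density that the sampler sees.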

Now we can set up the NUTS sampler, which is straightforward because we
are using only one kernel.

``` python
parameters = [name for name, var in model.vars.items() if var.parameter]

builder = gs.EngineBuilder(seed=42, num_chains=4)

builder.set_model(lsl.GooseModel(model))
builder.add_kernel(gs.NUTSKernel(parameters))
builder.set_initial_values(model.state)

builder.set_duration(warmup_duration=5000, posterior_duration=1000)

engine = builder.build()
```

    liesel.goose.engine - INFO - Initializing kernels...
    liesel.goose.engine - INFO - Done

``` python
engine.sample_all_epochs()
```

    liesel.goose.engine - INFO - Starting epoch: FAST_ADAPTATION, 75 transitions, 25 jitted together
    liesel.goose.engine - WARNING - Errors per chain for kernel_00: 45, 31, 51, 61 / 75 transitions
    liesel.goose.engine - INFO - Finished epoch
    liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 25 transitions, 25 jitted together
    liesel.goose.engine - WARNING - Errors per chain for kernel_00: 9, 21, 13, 8 / 25 transitions
    liesel.goose.engine - INFO - Finished epoch
    liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 50 transitions, 25 jitted together
    liesel.goose.engine - WARNING - Errors per chain for kernel_00: 40, 50, 49, 46 / 50 transitions
    liesel.goose.engine - INFO - Finished epoch
    liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 100 transitions, 25 jitted together
    liesel.goose.engine - WARNING - Errors per chain for kernel_00: 93, 93, 97, 91 / 100 transitions
    liesel.goose.engine - INFO - Finished epoch
    liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 200 transitions, 25 jitted together
    liesel.goose.engine - WARNING - Errors per chain for kernel_00: 196, 196, 192, 191 / 200 transitions
    liesel.goose.engine - INFO - Finished epoch
    liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 400 transitions, 25 jitted together
    liesel.goose.engine - WARNING - Errors per chain for kernel_00: 388, 371, 382, 381 / 400 transitions
    liesel.goose.engine - INFO - Finished epoch
    liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 800 transitions, 25 jitted together
    liesel.goose.engine - WARNING - Errors per chain for kernel_00: 789, 760, 772, 768 / 800 transitions
    liesel.goose.engine - INFO - Finished epoch
    liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 3300 transitions, 25 jitted together
    liesel.goose.engine - WARNING - Errors per chain for kernel_00: 3138, 3198, 3213, 3213 / 3300 transitions
    liesel.goose.engine - INFO - Finished epoch
    liesel.goose.engine - INFO - Starting epoch: FAST_ADAPTATION, 50 transitions, 25 jitted together
    liesel.goose.engine - WARNING - Errors per chain for kernel_00: 46, 47, 48, 49 / 50 transitions
    liesel.goose.engine - INFO - Finished epoch
    liesel.goose.engine - INFO - Finished warmup
    liesel.goose.engine - INFO - Starting epoch: POSTERIOR, 1000 transitions, 25 jitted together
    liesel.goose.engine - WARNING - Errors per chain for kernel_00: 999, 999, 998, 998 / 1000 transitions
    liesel.goose.engine - INFO - Finished epoch

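The results of this run do not favor the NUTS sampler. Nearly all
posterior transitions (3994 of 4000) hit the maximum tree depth, the
bulk effective sample sizes of the spline coefficients of the mean drop
to about 4 to 6 with `rhat` values between 1.9 and 3.9, and even the
intercepts mix no better than before (`ess_bulk` of 9.9 vs. 13.1 for
`loc_p0_beta` and 15.6 vs. 163.3 for `scale_p0_beta`). The
Metropolis-in-Gibbs sampler with the IWLS kernels works clearly better
for the spline coefficients.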

``` python
results = engine.get_results()
gs.Summary(results)
```

<p>
<strong>Parameter summary:</strong>
</p>
<table border="0" class="dataframe">
<thead>
<tr style="text-align: right;">
<th>
</th>
<th>
</th>
<th>
kernel
</th>
<th>
mean
</th>
<th>
sd
</th>
<th>
q_0.05
</th>
<th>
q_0.5
</th>
<th>
q_0.95
</th>
<th>
sample_size
</th>
<th>
ess_bulk
</th>
<th>
ess_tail
</th>
<th>
rhat
</th>
</tr>
<tr>
<th>
parameter
</th>
<th>
index
</th>
<th>
</th>
<th>
</th>
<th>
</th>
<th>
</th>
<th>
</th>
<th>
</th>
<th>
</th>
<th>
</th>
<th>
</th>
<th>
</th>
</tr>
</thead>
<tbody>
<tr>
<th rowspan="9" valign="top">
loc_np0_beta
</th>
<th>
(0,)
</th>
<td>
kernel_00
</td>
<td>
69.469
</td>
<td>
94.074
</td>
<td>
-19.308
</td>
<td>
24.245
</td>
<td>
235.016
</td>
<td>
4000
</td>
<td>
4.338
</td>
<td>
11.485
</td>
<td>
3.937
</td>
</tr>
<tr>
<th>
(1,)
</th>
<td>
kernel_00
</td>
<td>
-34.872
</td>
<td>
29.647
</td>
<td>
-79.634
</td>
<td>
-31.955
</td>
<td>
12.604
</td>
<td>
4000
</td>
<td>
4.462
</td>
<td>
10.910
</td>
<td>
3.398
</td>
</tr>
<tr>
<th>
(2,)
</th>
<td>
kernel_00
</td>
<td>
-225.643
</td>
<td>
173.152
</td>
<td>
-489.674
</td>
<td>
-190.451
</td>
<td>
-2.248
</td>
<td>
4000
</td>
<td>
4.560
</td>
<td>
12.395
</td>
<td>
3.057
</td>
</tr>
<tr>
<th>
(3,)
</th>
<td>
kernel_00
</td>
<td>
-29.091
</td>
<td>
129.928
</td>
<td>
-242.193
</td>
<td>
-10.036
</td>
<td>
216.950
</td>
<td>
4000
</td>
<td>
4.723
</td>
<td>
11.857
</td>
<td>
2.639
</td>
</tr>
<tr>
<th>
(4,)
</th>
<td>
kernel_00
</td>
<td>
-712.810
</td>
<td>
135.987
</td>
<td>
-969.966
</td>
<td>
-716.210
</td>
<td>
-480.268
</td>
<td>
4000
</td>
<td>
4.894
</td>
<td>
11.407
</td>
<td>
2.437
</td>
</tr>
<tr>
<th>
(5,)
</th>
<td>
kernel_00
</td>
<td>
-82.669
</td>
<td>
42.168
</td>
<td>
-153.224
</td>
<td>
-73.919
</td>
<td>
-23.744
</td>
<td>
4000
</td>
<td>
5.571
</td>
<td>
12.573
</td>
<td>
1.928
</td>
</tr>
<tr>
<th>
(6,)
</th>
<td>
kernel_00
</td>
<td>
-130.846
</td>
<td>
30.345
</td>
<td>
-176.981
</td>
<td>
-132.647
</td>
<td>
-81.529
</td>
<td>
4000
</td>
<td>
5.781
</td>
<td>
18.869
</td>
<td>
1.858
</td>
</tr>
<tr>
<th>
(7,)
</th>
<td>
kernel_00
</td>
<td>
29.494
</td>
<td>
41.749
</td>
<td>
-25.851
</td>
<td>
29.180
</td>
<td>
105.536
</td>
<td>
4000
</td>
<td>
4.729
</td>
<td>
11.952
</td>
<td>
2.677
</td>
</tr>
<tr>
<th>
(8,)
</th>
<td>
kernel_00
</td>
<td>
20.231
</td>
<td>
9.427
</td>
<td>
5.199
</td>
<td>
20.017
</td>
<td>
38.063
</td>
<td>
4000
</td>
<td>
5.258
</td>
<td>
13.983
</td>
<td>
2.151
</td>
</tr>
<tr>
<th>
loc_np0_tau2_transformed
</th>
<th>
()
</th>
<td>
kernel_00
</td>
<td>
11.376
</td>
<td>
0.689
</td>
<td>
10.245
</td>
<td>
11.377
</td>
<td>
12.506
</td>
<td>
4000
</td>
<td>
10.741
</td>
<td>
23.437
</td>
<td>
1.277
</td>
</tr>
<tr>
<th>
loc_p0_beta
</th>
<th>
(0,)
</th>
<td>
kernel_00
</td>
<td>
-17.948
</td>
<td>
3.408
</td>
<td>
-23.355
</td>
<td>
-18.022
</td>
<td>
-12.076
</td>
<td>
4000
</td>
<td>
9.897
</td>
<td>
53.404
</td>
<td>
1.330
</td>
</tr>
<tr>
<th rowspan="9" valign="top">
scale_np0_beta
</th>
<th>
(0,)
</th>
<td>
kernel_00
</td>
<td>
-8.401
</td>
<td>
4.703
</td>
<td>
-16.181
</td>
<td>
-7.819
</td>
<td>
-2.281
</td>
<td>
4000
</td>
<td>
4.401
</td>
<td>
11.172
</td>
<td>
3.516
</td>
</tr>
<tr>
<th>
(1,)
</th>
<td>
kernel_00
</td>
<td>
4.151
</td>
<td>
6.940
</td>
<td>
-6.551
</td>
<td>
3.753
</td>
<td>
16.141
</td>
<td>
4000
</td>
<td>
25.731
</td>
<td>
81.149
</td>
<td>
1.100
</td>
</tr>
<tr>
<th>
(2,)
</th>
<td>
kernel_00
</td>
<td>
-14.150
</td>
<td>
8.351
</td>
<td>
-28.065
</td>
<td>
-13.854
</td>
<td>
-1.076
</td>
<td>
4000
</td>
<td>
9.980
</td>
<td>
64.035
</td>
<td>
1.310
</td>
</tr>
<tr>
<th>
(3,)
</th>
<td>
kernel_00
</td>
<td>
14.462
</td>
<td>
6.344
</td>
<td>
3.769
</td>
<td>
14.785
</td>
<td>
24.659
</td>
<td>
4000
</td>
<td>
10.223
</td>
<td>
34.126
</td>
<td>
1.301
</td>
</tr>
<tr>
<th>
(4,)
</th>
<td>
kernel_00
</td>
<td>
6.075
</td>
<td>
4.746
</td>
<td>
-1.435
</td>
<td>
6.085
</td>
<td>
13.940
</td>
<td>
4000
</td>
<td>
12.279
</td>
<td>
81.324
</td>
<td>
1.236
</td>
</tr>
<tr>
<th>
(5,)
</th>
<td>
kernel_00
</td>
<td>
5.980
</td>
<td>
2.517
</td>
<td>
1.815
</td>
<td>
6.063
</td>
<td>
9.952
</td>
<td>
4000
</td>
<td>
9.152
</td>
<td>
41.744
</td>
<td>
1.355
</td>
</tr>
<tr>
<th>
(6,)
</th>
<td>
kernel_00
</td>
<td>
1.213
</td>
<td>
2.341
</td>
<td>
-2.705
</td>
<td>
1.275
</td>
<td>
4.874
</td>
<td>
4000
</td>
<td>
13.211
</td>
<td>
102.812
</td>
<td>
1.209
</td>
</tr>
<tr>
<th>
(7,)
</th>
<td>
kernel_00
</td>
<td>
1.880
</td>
<td>
4.552
</td>
<td>
-4.995
</td>
<td>
1.612
</td>
<td>
9.654
</td>
<td>
4000
</td>
<td>
16.937
</td>
<td>
399.531
</td>
<td>
1.159
</td>
</tr>
<tr>
<th>
(8,)
</th>
<td>
kernel_00
</td>
<td>
-0.124
</td>
<td>
1.276
</td>
<td>
-2.398
</td>
<td>
-0.052
</td>
<td>
1.820
</td>
<td>
4000
</td>
<td>
26.935
</td>
<td>
1328.627
</td>
<td>
1.095
</td>
</tr>
<tr>
<th>
scale_np0_tau2_transformed
</th>
<th>
()
</th>
<td>
kernel_00
</td>
<td>
4.447
</td>
<td>
1.022
</td>
<td>
2.551
</td>
<td>
4.586
</td>
<td>
5.955
</td>
<td>
4000
</td>
<td>
9.088
</td>
<td>
29.637
</td>
<td>
1.357
</td>
</tr>
<tr>
<th>
scale_p0_beta
</th>
<th>
(0,)
</th>
<td>
kernel_00
</td>
<td>
3.004
</td>
<td>
0.076
</td>
<td>
2.878
</td>
<td>
3.005
</td>
<td>
3.127
</td>
<td>
4000
</td>
<td>
15.581
</td>
<td>
183.202
</td>
<td>
1.177
</td>
</tr>
</tbody>
</table>
<p>
<strong>Error summary:</strong>
</p>
<table border="0" class="dataframe">
<thead>
<tr style="text-align: right;">
<th>
</th>
<th>
</th>
<th>
</th>
<th>
</th>
<th>
count
</th>
<th>
relative
</th>
</tr>
<tr>
<th>
kernel
</th>
<th>
error_code
</th>
<th>
error_msg
</th>
<th>
phase
</th>
<th>
</th>
<th>
</th>
</tr>
</thead>
<tbody>
<tr>
<th rowspan="6" valign="top">
kernel_00
</th>
<th rowspan="2" valign="top">
1
</th>
<th rowspan="2" valign="top">
divergent transition
</th>
<th>
warmup
</th>
<td>
2586
</td>
<td>
0.129
</td>
</tr>
<tr>
<th>
posterior
</th>
<td>
0
</td>
<td>
0.000
</td>
</tr>
<tr>
<th rowspan="2" valign="top">
2
</th>
<th rowspan="2" valign="top">
maximum tree depth
</th>
<th>
warmup
</th>
<td>
15753
</td>
<td>
0.788
</td>
</tr>
<tr>
<th>
posterior
</th>
<td>
3994
</td>
<td>
0.998
</td>
</tr>
<tr>
<th rowspan="2" valign="top">
3
</th>
<th rowspan="2" valign="top">
divergent transition + maximum tree depth
</th>
<th>
warmup
</th>
<td>
797
</td>
<td>
0.040
</td>
</tr>
<tr>
<th>
posterior
</th>
<td>
0
</td>
<td>
0.000
</td>
</tr>
</tbody>
</table>
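
NUTS builds its trajectories with the leapfrog integrator, and a transition is typically flagged as divergent when the energy error along the trajectory becomes very large. The following sketch illustrates this on a one-dimensional standard normal target, where the leapfrog integrator is stable only for step sizes below 2. The target, the step sizes, and the helper names are assumptions for this example and are unrelated to Liesel's NUTS implementation.

```python
def leapfrog(q, p, step_size, num_steps):
    """Leapfrog integration for a standard normal target, U(q) = q**2 / 2."""

    def grad_potential(x):
        return x  # derivative of U(q) = q**2 / 2

    p = p - 0.5 * step_size * grad_potential(q)  # initial half step
    for i in range(num_steps):
        q = q + step_size * p
        if i < num_steps - 1:
            p = p - step_size * grad_potential(q)
    p = p - 0.5 * step_size * grad_potential(q)  # final half step
    return q, p


def energy(q, p):
    """Total energy: potential plus kinetic."""
    return 0.5 * q * q + 0.5 * p * p


def energy_error(step_size, num_steps=20, q0=1.0, p0=0.0):
    """Absolute change in total energy after a leapfrog trajectory."""
    q, p = leapfrog(q0, p0, step_size, num_steps)
    return abs(energy(q, p) - energy(q0, p0))


stable_error = energy_error(0.1)    # tiny: the trajectory is accurate
unstable_error = energy_error(2.5)  # enormous: the trajectory diverges
```

A very small step size avoids divergences but requires long trajectories, which is one way a sampler can end up at the maximum tree depth on almost every transition.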

``` python
fig = gs.plot_trace(results, "loc_p0_beta")
```

![](04-mcycle_files/figure-commonmark/nuts-traces-5.png)

``` python
fig = gs.plot_trace(results, "loc_np0_tau2_transformed")
```

![](04-mcycle_files/figure-commonmark/nuts-traces-6.png)

``` python
fig = gs.plot_trace(results, "loc_np0_beta")
```

![](04-mcycle_files/figure-commonmark/nuts-traces-7.png)

``` python
fig = gs.plot_trace(results, "scale_p0_beta")
```

![](04-mcycle_files/figure-commonmark/nuts-traces-8.png)

``` python
fig = gs.plot_trace(results, "scale_np0_tau2_transformed")
```

![](04-mcycle_files/figure-commonmark/nuts-traces-9.png)

``` python
fig = gs.plot_trace(results, "scale_np0_beta")
```

![](04-mcycle_files/figure-commonmark/nuts-traces-10.png)

Again, here is a plot of the estimated mean function:

``` python
summary = gs.Summary(results).to_dataframe().reset_index()
```

``` r
library(dplyr)
library(ggplot2)
library(reticulate)

summary <- py$summary
model <- py$model

beta <- summary %>%
  filter(variable == "loc_np0_beta") %>%
  group_by(var_index) %>%
  summarize(mean = mean(mean)) %>%
  ungroup()

beta <- beta$mean
X <- model$vars["loc_np0_X"]$value
f <- X %*% beta

beta0 <- summary %>%
  filter(variable == "loc_p0_beta") %>%
  group_by(var_index) %>%
  summarize(mean = mean(mean)) %>%
  ungroup()

beta0 <- beta0$mean

ggplot(data.frame(times = mcycle$times, mean = beta0 + f)) +
  geom_line(aes(times, mean), color = palette()[2], size = 1) +
  geom_point(aes(times, accel), data = mcycle) +
  ggtitle("Estimated mean function") +
  theme_minimal()
```

![](04-mcycle_files/figure-commonmark/nuts-spline-17.png)
