GEV responses


In this tutorial, we illustrate how to set up a distributional regression model with the generalized extreme value distribution as a response distribution. First, we simulate some data in R:

  • The location parameter (\(\mu\)) is a function of an intercept and a non-linear covariate effect.

  • The scale parameter (\(\sigma\)) is a function of an intercept and a linear effect and uses a log-link.

  • The shape or concentration parameter (\(\xi\)) is a function of an intercept and a linear effect.
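
Written out with the true values used in the simulation code of this tutorial, the data-generating process for observation \(i\) is:

\[
\begin{aligned}
y_i &\sim \text{GEV}(\mu_i, \sigma_i, \xi_i), \\
\mu_i &= 0 + \sin(2 \pi x_{i0}), \\
\log \sigma_i &= -3 + x_{i1}, \\
\xi_i &= 0.1 + x_{i2}.
\end{aligned}
\]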

After simulating the data, we can configure the model with a single call to the rliesel::liesel() function.

library(rliesel)
library(VGAM)
Loading required package: stats4

Loading required package: splines
set.seed(13)

n <- 1000

x0 <- runif(n)
x1 <- runif(n)
x2 <- runif(n)

y <- rgev(
  n,
  location = 0 + sin(2 * pi * x0),
  scale = exp(-3 + x1),
  shape = 0.1 + x2
)

plot(y)
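
For readers who want to see what the simulation does under the hood, here is a minimal pure-Python sketch of the same design using inverse-CDF sampling. It assumes the standard GEV quantile function with the location/scale/shape parameterization; the helper name `gev_quantile` is made up for this illustration, and the random numbers will not match the R output because the generators differ.

```python
import math
import random


def gev_quantile(p, loc, scale, shape):
    """GEV quantile function (hypothetical helper, standard parameterization)."""
    if shape == 0.0:
        # Gumbel limit of the GEV distribution
        return loc - scale * math.log(-math.log(p))
    return loc + scale * ((-math.log(p)) ** (-shape) - 1.0) / shape


rng = random.Random(13)  # seed chosen for illustration only
n = 1000

y = []
for _ in range(n):
    x0, x1, x2 = rng.random(), rng.random(), rng.random()
    # keep u strictly inside (0, 1) so that both logarithms are finite
    u = rng.uniform(1e-12, 1.0 - 1e-12)
    y.append(
        gev_quantile(
            u,
            loc=math.sin(2 * math.pi * x0),
            scale=math.exp(-3 + x1),
            shape=0.1 + x2,
        )
    )
```

Inverse-CDF sampling is used here only because it needs nothing beyond the standard library; it is not a claim about how `rgev` is implemented.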

model <- liesel(
  response = y,
  distribution = "GeneralizedExtremeValue",
  predictors = list(
    loc = predictor(~ s(x0)),
    scale = predictor(~ x1, inverse_link = "Exp"),
    concentration = predictor(~ x2)
  )
)
Did not find response 'y' in data. Using 'y' found in parent environment.

Now, we can continue in Python and use the lsl.dist_reg_mcmc() function to set up a sampling algorithm with IWLS kernels for the regression coefficients (\(\boldsymbol{\beta}\)) and a Gibbs kernel for the smoothing parameter (\(\tau^2\)) of the spline.

The support of the GEV distribution depends on the parameter values (compare Wikipedia). To ensure that the initial parameter values support the observed data, we set \(\xi = 0.1\) (intercept 0.1, slope 0) and disable jittering of the variance and regression parameters. For the latter, we pass user-defined jitter functions to lsl.dist_reg_mcmc() that simply return the parameter value unchanged.
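
To make the support condition concrete: for \(\xi \neq 0\), an observation \(y\) lies in the support only if \(1 + \xi (y - \mu) / \sigma > 0\). The following sketch checks this for a set of observations. The function name `gev_in_support` and the example values are made up for illustration, and the parameters are scalars for simplicity, whereas in the model above they vary by observation.

```python
def gev_in_support(y, loc, scale, shape):
    """Return True if every observation lies in the GEV support.

    Hypothetical helper; assumes scale > 0 and scalar parameters.
    """
    if shape == 0.0:
        # Gumbel case: the support is the whole real line
        return True
    return all(1.0 + shape * (yi - loc) / scale > 0.0 for yi in y)


y_example = [-0.5, 0.0, 2.0]  # made-up observations

# shape > 0: the support is bounded from below at loc - scale / shape = -10
print(gev_in_support(y_example, loc=0.0, scale=1.0, shape=0.1))

# shape < 0: the support is bounded from above at loc - scale / shape = 1.67,
# so the observation 2.0 falls outside
print(gev_in_support(y_example, loc=0.0, scale=1.0, shape=-0.6))
```

If a check like this fails for the initial values, the log-likelihood is not finite at the starting point, which is why the initial values are fixed rather than jittered here.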

import liesel.model as lsl
import jax.numpy as jnp

model = r.model

# concentration == 0.0 seems to break the sampler
model.vars["concentration_p0_beta"].value = jnp.array([0.1, 0.0])

builder = lsl.dist_reg_mcmc(
    model,
    seed=42,
    num_chains=4,
    # identity jitter functions: keep the initial values unchanged
    tau2_jitter_fn=lambda key, val: val,
    beta_jitter_fn=lambda key, val: val,
)
builder.set_duration(warmup_duration=1000, posterior_duration=1000)

engine = builder.build()
engine.sample_all_epochs()

Some tabular summary statistics of the posterior samples:

import liesel.goose as gs

results = engine.get_results()
gs.Summary(results)

Parameter summary:

parameter              index  kernel       mean     sd  q_0.05   q_0.5  q_0.95  sample_size  ess_bulk  ess_tail   rhat
concentration_p0_beta  (0,)   kernel_00   0.104  0.054   0.016   0.103   0.193         4000   372.271   988.761  1.004
                       (1,)   kernel_00   0.964  0.099   0.796   0.967   1.121         4000   207.805   645.642  1.010
loc_np0_beta           (0,)   kernel_03   0.469  0.207   0.121   0.469   0.807         4000    54.068   156.325  1.067
                       (1,)   kernel_03  -0.147  0.129  -0.358  -0.149   0.067         4000    51.923   105.754  1.081
                       (2,)   kernel_03   0.473  0.139   0.241   0.472   0.696         4000    85.108   129.521  1.037
                       (3,)   kernel_03  -0.008  0.073  -0.132  -0.005   0.113         4000    61.796   168.336  1.093
                       (4,)   kernel_03   0.472  0.070   0.362   0.470   0.589         4000    64.460   135.078  1.074
                       (5,)   kernel_03   0.458  0.031   0.412   0.457   0.512         4000    87.095   127.444  1.023
                       (6,)   kernel_03  -5.911  0.031  -5.964  -5.913  -5.862         4000    75.689   136.909  1.069
                       (7,)   kernel_03   0.375  0.069   0.253   0.375   0.488         4000    87.037   169.840  1.040
                       (8,)   kernel_03  -1.794  0.026  -1.837  -1.794  -1.753         4000    87.187   160.277  1.059
loc_np0_tau2           ()     kernel_02   5.967  4.374   2.276   4.946  12.862         4000  3610.352  3848.888  1.001
loc_p0_beta            (0,)   kernel_04  -0.027  0.002  -0.031  -0.027  -0.023         4000    90.065   390.131  1.050
scale_p0_beta          (0,)   kernel_01  -3.093  0.059  -3.190  -3.090  -2.999         4000   151.457   348.065  1.057
                       (1,)   kernel_01   1.197  0.081   1.067   1.196   1.332         4000   246.643   530.955  1.038

Error summary:

kernel     error_code  error_msg            phase      count  relative
kernel_00  90          nan acceptance prob  warmup        69     0.017
                                            posterior      1     0.000
kernel_01  90          nan acceptance prob  warmup        32     0.008
                                            posterior      0     0.000
kernel_03  90          nan acceptance prob  warmup        25     0.006
                                            posterior      0     0.000
kernel_04  90          nan acceptance prob  warmup        23     0.006
                                            posterior      0     0.000

And the corresponding trace plots:

fig = gs.plot_trace(results, "loc_p0_beta")

fig = gs.plot_trace(results, "loc_np0_tau2")

fig = gs.plot_trace(results, "loc_np0_beta")

fig = gs.plot_trace(results, "scale_p0_beta")

fig = gs.plot_trace(results, "concentration_p0_beta")

We need to reset the index of the summary data frame before we can transfer it to R.

summary = gs.Summary(results).to_dataframe().reset_index()

After transferring the summary data frame to R, we can process it with packages like dplyr and ggplot2. Here is a visualization of the estimated spline vs. the true function:

library(dplyr)
Attaching package: 'dplyr'

The following objects are masked from 'package:stats':

    filter, lag

The following objects are masked from 'package:base':

    intersect, setdiff, setequal, union
library(ggplot2)
library(reticulate)
summary <- py$summary

beta <- summary %>%
  filter(variable == "loc_np0_beta") %>%
  group_by(var_index) %>%
  summarize(mean = mean(mean)) %>%
  ungroup()

beta <- beta$mean
X <- py_to_r(model$vars["loc_np0_X"]$value)
estimate <- X %*% beta

true <- sin(2 * pi * x0)

ggplot(data.frame(x0 = x0, estimate = estimate, true = true)) +
  geom_line(aes(x0, estimate), color = palette()[2]) +
  geom_line(aes(x0, true), color = palette()[4]) +
  ggtitle("Estimated spline (red) vs. true function (blue)") +
  ylab("f") +
  theme_minimal()