"""
Posterior statistics and diagnostics.
"""
from __future__ import annotations
import typing
from collections.abc import Sequence
from typing import Any, NamedTuple
import arviz as az
import jax.numpy as jnp
import jaxlib.xla_extension
import numpy as np
import pandas as pd
from deprecated.sphinx import deprecated
from liesel.goose.engine import ErrorLog, SamplingResults
from liesel.goose.pytree import slice_leaves, stack_leaves
from liesel.goose.types import Position
from liesel.option import Option
class ErrorSummaryForOneCode(NamedTuple):
error_code: int
error_msg: str
count_per_chain: np.ndarray
count_per_chain_posterior: np.ndarray | None
ErrorSummary = dict[str, dict[int, ErrorSummaryForOneCode]]
"""
See docstring of ``_make_error_summary``.
"""
def _make_error_summary(
error_log: ErrorLog,
posterior_error_log: Option[ErrorLog],
) -> ErrorSummary:
"""
Creates an error summary from the error log.
The returned value looks like this::
{
kernel_identifier: {
error_code: (error_code, error_msg, count, count_in_posterior),
error_code: (error_code, error_msg, count, count_in_posterior),
...
},
...
}
The ``error_msg`` is the empty string if the kernel class is not supplied in the
``error_log``.
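For example, a kernel with the hypothetical identifier ``"kernel_00"`` whose
error code 1 occurred three times in the first chain (once in the posterior
phase) and never in the second chain could be summarized as::
{"kernel_00": {1: ErrorSummaryForOneCode(1, "", np.array([3, 0]), np.array([1, 0]))}}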
"""
error_summary = {}
for kel in error_log.values():
counter_dict: dict[int, np.ndarray] = {}
# calculate the overall counts
ec_unique = np.unique(kel.error_codes)
for ec in ec_unique:
if ec == 0:
continue
occurrences_per_chain = np.sum(kel.error_codes == ec, axis=1)
counter_dict[ec] = occurrences_per_chain
krnl_summary: dict[int, ErrorSummaryForOneCode] = {}
for key, count in counter_dict.items():
ec = key
# type ignore is ok since the type must implement the kernel protocol.
error_msg = kel.kernel_cls.map_or(
"", lambda krn_cls: krn_cls.error_book[ec] # type: ignore
)
krnl_summary[ec] = ErrorSummaryForOneCode(ec, error_msg, count, None)
# calculate the counts in the posterior
if posterior_error_log.is_some():
posterior_error_log_unwrapped = posterior_error_log.unwrap()
kel_post = posterior_error_log_unwrapped[kel.kernel_ident]
for ec in ec_unique:
if ec == 0:
continue
occurrences_per_chain = np.sum(kel_post.error_codes == ec, axis=1)
krnl_summary[ec] = krnl_summary[ec]._replace(
count_per_chain_posterior=occurrences_per_chain
)
error_summary[kel.kernel_ident] = krnl_summary
return error_summary
class Summary:
"""
A summary object.
Offers two main use cases:
1. View an overall summary by printing a summary instance, including a summary table
of the posterior samples and a summary of sampling errors.
2. Programmatically access summary statistics via
``quantities[quantity_name][var_name]``. Please refer to the documentation of the
attribute :attr:`.quantities` for details.
Additionally, the summary object can be turned into a :class:`~pandas.DataFrame`
using :meth:`.to_dataframe`. A short usage sketch is given at the end of this
docstring.
Parameters
----------
results
The sampling results to summarize.
additional_chain
An optional position chain of additional parameters to include in the summary
output. Must match the chain and time dimensions of the posterior chain as
returned by :meth:`.SamplingResults.get_posterior_samples`.
hdi_prob
Probability level of the posterior highest density intervals.
selected, deselected
Allow restricting the summary to a subset of the position keys.
per_chain
If *True*, the summary is calculated on a per-chain basis. Certain measures like
``rhat`` are not available if ``per_chain`` is *True*.
Notes
-----
This class is still considered experimental. The API may still undergo major
changes.
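A minimal usage sketch, assuming ``results`` is a :class:`.SamplingResults`
object from a finished sampling run and ``"beta"`` is a hypothetical position
key::
summary = Summary(results, quantiles=(0.05, 0.5, 0.95), hdi_prob=0.9)
print(summary)  # parameter summary table plus error summary
beta_mean = summary.quantities["mean"]["beta"]  # programmatic access
df = summary.to_dataframe()  # one row per parameter element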
"""
per_chain: bool
"""
Whether results are summarized for individual chains (*True*), or aggregated
over chains (*False*).
"""
quantities: dict[str, dict[str, np.ndarray]]
"""
Dict of summarizing quantities.
The dict is organized hierarchically. Let ``summary`` be a :class:`.Summary`
instance; a quantity is then accessed as::
q = summary.quantities["quantity_name"]["parameter_name"]
The extracted object is an ``np.ndarray``. If ``per_chain=True``, the arrays for
the ``"quantile"`` and ``"hdi"`` quantities have the following dimensions:
1. First index refers to the chain
2. Second index refers to the quantile/interval
3. Third and subsequent indices refer to individual parameters.
If ``per_chain=True``, the arrays for the other quantities have the dimensions:
1. First index refers to the chain
2. Second and subsequent indices refer to individual parameters.
If ``per_chain=False``, the first index is removed for all quantities.
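For example, with ``per_chain=False`` and a hypothetical scalar parameter
``"beta"``, the posterior mean and the HDI bounds could be retrieved as
(a sketch)::
beta_mean = summary.quantities["mean"]["beta"]
hdi_low, hdi_high = summary.quantities["hdi"]["beta"]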
"""
config: dict
"""
A dictionary of config settings for this summary object. Should NOT be changed
after initialization; such changes have no effect on the computed summary values.
"""
sample_info: dict
"""
Dictionary of meta-information about the MCMC samples used to create this summary
object.
Contains ``num_chains``, ``sample_size_per_chain``, and ``warmup_size_per_chain``.
"""
error_summary: ErrorSummary
"""
Contains error information for each kernel.
"""
kernels_by_pos_key: dict[str, str]
"""
A dict linking parameter names (the keys) to kernel identifiers (the values).
Each identifier refers to the kernel that was used to sample the respective
parameter.
"""
def __init__(
self,
results: SamplingResults,
additional_chain: Position | None = None,
quantiles: Sequence[float] = (0.05, 0.5, 0.95),
hdi_prob: float = 0.9,
selected: list[str] | None = None,
deselected: list[str] | None = None,
per_chain: bool = False,
):
posterior_chain = results.get_posterior_samples()
if additional_chain:
for k, v in additional_chain.items():
posterior_chain[k] = v
if selected:
posterior_chain = Position(
{
key: value
for key, value in posterior_chain.items()
if key in selected
}
)
if deselected is not None:
for key in deselected:
del posterior_chain[key]
# gather some general information about the sampling
param_chain = next(iter(posterior_chain.values()))
epochs = results.positions.get_epochs()
warmup_size = np.sum(
[epoch.duration for epoch in epochs if epoch.type.is_warmup(epoch.type)]
)
sample_info = {
"num_chains": param_chain.shape[0],
"sample_size_per_chain": param_chain.shape[1],
"warmup_size_per_chain": warmup_size,
}
# convert everything to numpy array
for key in posterior_chain:
posterior_chain[key] = np.asarray(posterior_chain[key])
# calculate the summarizing quantities either per chain (and stack the results) or all at once
single_chain_summaries = []
if per_chain:
for chain_idx in range(sample_info["num_chains"]):
single_chain = slice_leaves(
posterior_chain, jnp.s_[None, chain_idx, ...]
)
qdict = _create_quantity_dict(single_chain, quantiles, hdi_prob)
single_chain_summaries.append(qdict)
quantities = stack_leaves(single_chain_summaries, axis=0)
else:
quantities = _create_quantity_dict(posterior_chain, quantiles, hdi_prob)
config = {
"quantiles": quantiles,
"hdi_prob": hdi_prob,
"chains_merged": not per_chain,
}
error_summary = _make_error_summary(
results.get_error_log(False).unwrap(), results.get_error_log(True)
)
self.per_chain = per_chain
self.quantities = quantities
self.config = config
self.sample_info = sample_info
self.error_summary = error_summary
self.kernels_by_pos_key = results.get_kernels_by_pos_key()
def to_dataframe(self) -> pd.DataFrame:
"""Turns Summary object into a :class:`~pandas.DataFrame` object."""
# don't change the original data
quants = self.quantities.copy()
# make new entries for the quantiles
if self.per_chain:
for i, q in enumerate(self.config["quantiles"]):
quants[f"q_{q}"] = {
k: v[:, i, ...] for k, v in quants["quantile"].items()
}
quants["hdi_low"] = {k: v[:, 0, ...] for k, v in quants["hdi"].items()}
quants["hdi_high"] = {k: v[:, 1, ...] for k, v in quants["hdi"].items()}
else:
for i, q in enumerate(self.config["quantiles"]):
quants[f"q_{q}"] = {k: v[i, ...] for k, v in quants["quantile"].items()}
quants["hdi_low"] = {k: v[0, ...] for k, v in quants["hdi"].items()}
quants["hdi_high"] = {k: v[1, ...] for k, v in quants["hdi"].items()}
# remove the old entries
del quants["quantile"]
del quants["hdi"]
# create one row per entry
df_dict = {}
for var in quants["mean"].keys():
it = np.nditer(quants["mean"][var], flags=["multi_index"])
for _ in it:
var_fqn = (
var if len(it.multi_index) == 0 else f"{var}{list(it.multi_index)}"
)
quant_per_elem: dict[str, Any] = {}
quant_per_elem["variable"] = var
quant_per_elem["kernel"] = self.kernels_by_pos_key.get(var, "-")
if self.config["chains_merged"]:
quant_per_elem["var_index"] = it.multi_index
quant_per_elem["sample_size"] = (
self.sample_info["sample_size_per_chain"]
* self.sample_info["num_chains"]
)
else:
quant_per_elem["chain_index"] = it.multi_index[0]
quant_per_elem["var_index"] = it.multi_index[1:]
quant_per_elem["sample_size"] = self.sample_info[
"sample_size_per_chain"
]
for quant_name, quant_dict in quants.items():
quant_per_elem[quant_name] = quant_dict[var][it.multi_index]
# convert DeviceArrays (scalar) to floats so that
# pandas treats them correctly
for key, val in quant_per_elem.items():
if type(val) == jaxlib.xla_extension.DeviceArray:
# make mypy happy
val = typing.cast(jnp.ndarray, val)
# value should be a scalar
assert val.shape == ()
# extract the value as a numpy scalar
val = np.atleast_1d(np.asarray(val))[0]
quant_per_elem[key] = val
df_dict[var_fqn] = quant_per_elem
# convert to dataframe and use varname as index
df = pd.DataFrame.from_dict(df_dict, orient="index")
df = df.reset_index()
df = df.rename(columns={"index": "var_fqn"})
df = df.set_index("variable")
return df
def _param_df(self):
df = self.to_dataframe()
df.index.name = "parameter"
df = df.rename(columns={"var_index": "index"})
df = df.set_index("index", append=True)
qtls = [f"q_{qtl}" for qtl in self.config["quantiles"]]
cols = (
["kernel", "mean", "sd"]
+ qtls
+ ["sample_size", "ess_bulk", "ess_tail", "rhat"]
)
cols = [col for col in cols if col in df.columns]
df = df[cols]
return df
def _error_df(self, per_chain=False):
# fmt: off
df = pd.concat({
kernel: pd.DataFrame.from_dict(code_summary, orient="index")
for kernel, code_summary in self.error_summary.items()
})
# fmt: on
if df.empty:
return df
df = df.reset_index(level=1, drop=True)
df["error_code"] = df["error_code"].astype(int)
df = df.set_index(["error_code", "error_msg"], append=True)
df.index.names = ["kernel", "error_code", "error_msg"]
# fmt: off
df = df.rename(columns={
"count_per_chain": "total",
"count_per_chain_posterior": "posterior",
})
# fmt: on
df = df.explode(["total", "posterior"])
df["warmup"] = df["total"] - df["posterior"]
df = df.drop(columns="total")
df = df.melt(
value_vars=["warmup", "posterior"],
var_name="phase",
value_name="count",
ignore_index=False,
)
df["phase"] = pd.Categorical(df["phase"], categories=["warmup", "posterior"])
df = df.set_index("phase", append=True)
df["chain"] = df.groupby(level=[0, 1, 2, 3]).cumcount()
df = df.set_index("chain", append=True)
df = df.sort_index()
df["sample_size"] = None
warmup_size = self.sample_info["warmup_size_per_chain"]
posterior_size = self.sample_info["sample_size_per_chain"]
df.loc[pd.IndexSlice[:, :, :, "warmup"], "sample_size"] = warmup_size
df.loc[pd.IndexSlice[:, :, :, "posterior"], "sample_size"] = posterior_size
df["relative"] = df["count"] / df["sample_size"]
df = df.drop(columns="sample_size")
if not per_chain:
df = df.groupby(level=[0, 1, 2, 3], observed=True)
df = df.aggregate({"count": "sum", "relative": "mean"})
df = df.sort_index()
return df
def __repr__(self):
param_df = self._param_df()
error_df = self._error_df()
txt = "Parameter summary:\n\n" + repr(param_df)
if not error_df.empty:
txt += "\n\nError summary:\n\n" + repr(error_df)
return txt
def _repr_html_(self):
param_df = self._param_df()
error_df = self._error_df()
html = "\n<p><strong>Parameter summary:</strong></p>\n" + param_df.to_html()
if not error_df.empty:
html += "\n<p><strong>Error summary:</strong></p>\n" + error_df.to_html()
html += "\n"
return html
def _repr_markdown_(self):
param_df = self._param_df()
error_df = self._error_df()
try:
param_md = param_df.to_markdown()
error_md = error_df.to_markdown()
except ImportError:
param_md = f"```\n{repr(param_df)}\n```"
error_md = f"```\n{repr(error_df)}\n```"
md = "\n\n**Parameter summary:**\n\n" + param_md
if not error_df.empty:
md += "\n\n**Error summary:**\n\n" + error_md
md += "\n\n"
return md
def __str__(self):
return str(self.to_dataframe())
@classmethod
@deprecated(reason="Functionality moved directly to the __init__.", version="0.1.4")
def from_result(
cls,
result: SamplingResults,
additional_chain: Position | None = None,
quantiles: Sequence[float] = (0.05, 0.5, 0.95),
hdi_prob: float = 0.9,
selected: list[str] | None = None,
deselected: list[str] | None = None,
per_chain=False,
) -> Summary:
"""
Alias for :meth:`.from_results` for backwards compatibility.
In addition to the name, there are two further subtle differences from
:meth:`.from_results`.
- The argument ``result`` is named in the singular, while :meth:`.from_results`
uses the plural.
- The argument ``result`` is of type :class:`.SamplingResult`, which itself is
an alias for :class:`.SamplingResults`.
"""
return cls.from_results(
results=result,
additional_chain=additional_chain,
quantiles=quantiles,
hdi_prob=hdi_prob,
selected=selected,
deselected=deselected,
per_chain=per_chain,
)
@staticmethod
@deprecated(reason="Functionality moved directly to the __init__.", version="0.1.4")
def from_results(
results: SamplingResults,
additional_chain: Position | None = None,
quantiles: Sequence[float] = (0.05, 0.5, 0.95),
hdi_prob: float = 0.9,
selected: list[str] | None = None,
deselected: list[str] | None = None,
per_chain=False,
) -> Summary:
"""
Creates a :class:`.Summary` object from a results object.
"""
return Summary(
results=results,
additional_chain=additional_chain,
quantiles=quantiles,
hdi_prob=hdi_prob,
selected=selected,
deselected=deselected,
per_chain=per_chain,
)
def _create_quantity_dict(
chain: Position, quantiles: Sequence[float], hdi_prob: float
) -> dict[str, dict[str, np.ndarray]]:
azchain = az.convert_to_inference_data(chain).posterior
# calculate quantities
mean = azchain.mean(dim=["chain", "draw"])
var = azchain.var(dim=["chain", "draw"])
sd = azchain.std(dim=["chain", "draw"])
quantile = azchain.quantile(q=quantiles, dim=["chain", "draw"])
hdi = az.hdi(azchain, hdi_prob=hdi_prob)
ess_bulk = az.ess(azchain, method="bulk")
ess_tail = az.ess(azchain, method="tail")
mcse_mean = az.mcse(azchain, method="mean")
mcse_sd = az.mcse(azchain, method="sd")
# place quantities in dict
quantities = {
"mean": mean,
"var": var,
"sd": sd,
"quantile": quantile,
"hdi": hdi,
"rhat": None,
"ess_bulk": ess_bulk,
"ess_tail": ess_tail,
"mcse_mean": mcse_mean,
"mcse_sd": mcse_sd,
}
if azchain.chain.size > 1:
quantities["rhat"] = az.rhat(azchain)
else:
del quantities["rhat"]
# convert to simple dict[str, np.ndarray]
for key, val in quantities.items():
quantities[key] = {k: v.values for k, v in val.data_vars.items()}
# hdi shape BEFORE
# VarIDX --- HDI
# special treatment for hdi, since az.hdi places the lower/upper interval
# bounds on the last axis
for k, v in quantities["hdi"].items():
quantities["hdi"][k] = np.moveaxis(v, -1, 0)
# hdi shape AFTER
# HDI --- VarIDX
return quantities