# Source code for liesel.goose.summary_m

"""
Posterior statistics and diagnostics.
"""

from __future__ import annotations

from collections.abc import Sequence
from typing import Any, Literal, NamedTuple

import arviz as az
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd

from liesel.__version__ import __version__
from liesel.goose.engine import ErrorLog, SamplingResults
from liesel.goose.epoch import EpochType
from liesel.goose.pytree import slice_leaves, stack_leaves
from liesel.goose.types import Array, Position, TransitionInfo
from liesel.option import Option


class ErrorSummaryForOneCode(NamedTuple):
    """
    Summary of how often one error code occurred for one kernel.

    Fields
    ------
    error_code
        The numeric error code that is summarized.
    error_msg
        Message associated with the error code. ``_make_error_summary`` uses the
        empty string if the kernel class is not available in the error log.
    count_per_chain
        Number of occurrences of the error code, one entry per chain.
    count_per_chain_posterior
        Number of occurrences per chain in the posterior error log, or ``None`` if no
        posterior error log was supplied.
    """

    error_code: int
    error_msg: str
    count_per_chain: np.ndarray
    # Starts out as None and is replaced by an array via ``_replace`` once the
    # posterior counts are known.
    count_per_chain_posterior: np.ndarray | None


# Nested mapping: kernel identifier -> error code -> summary for that error code.
ErrorSummary = dict[str, dict[int, ErrorSummaryForOneCode]]
"""
See docstring of ``_make_error_summary``.
"""


def _make_error_summary(
    error_log: ErrorLog,
    posterior_error_log: Option[ErrorLog],
) -> ErrorSummary:
    """
    Builds a per-kernel, per-error-code summary from the error log.

    The result is a nested dictionary of the form::

        {
            kernel_identifier: {
                error_code: (error_code, error_msg, count, count_in_posterior),
                error_code: (error_code, error_msg, count, count_in_posterior),
                ...
            },
            ...
        }

    Error code ``0`` is skipped. ``count`` and ``count_in_posterior`` hold one entry
    per chain; ``count_in_posterior`` stays ``None`` if ``posterior_error_log`` is
    empty. The ``error_msg`` is the empty string if the kernel class is not supplied
    in the ``error_log``.
    """
    summary = {}
    for kernel_log in error_log.values():
        # all distinct codes of this kernel, except 0
        codes = [code for code in np.unique(kernel_log.error_codes) if code != 0]

        # overall counts per chain, together with the matching message
        per_code: dict[int, ErrorSummaryForOneCode] = {}
        for code in codes:
            total_per_chain = np.sum(kernel_log.error_codes == code, axis=1)
            # type ignore is ok since the type must implement the kernel protocol.
            message = kernel_log.kernel_cls.map_or(
                "",
                lambda krn_cls: krn_cls.error_book[code],  # type: ignore
            )
            per_code[code] = ErrorSummaryForOneCode(
                code, message, total_per_chain, None
            )

        # counts per chain restricted to the posterior error log
        if posterior_error_log.is_some():
            posterior_log = posterior_error_log.unwrap()[kernel_log.kernel_ident]
            for code in codes:
                posterior_per_chain = np.sum(posterior_log.error_codes == code, axis=1)
                per_code[code] = per_code[code]._replace(
                    count_per_chain_posterior=posterior_per_chain
                )

        summary[kernel_log.kernel_ident] = per_code

    return summary


# Static type of a single summary statistic name.
SummaryQuantities = Literal[
    "mean",
    "sd",
    "var",
    "quantiles",
    "hdi",
    "ess_bulk",
    "ess_tail",
    "rhat",
    "mcse_mean",
    "mcse_sd",
]

# Runtime counterpart of ``SummaryQuantities``: the same names as a tuple. Keep
# both listings in sync when adding or removing a statistic.
summary_quantities: Sequence[SummaryQuantities] = (
    "mean",
    "sd",
    "var",
    "quantiles",
    "hdi",
    "ess_bulk",
    "ess_tail",
    "rhat",
    "mcse_mean",
    "mcse_sd",
)


def _summarize_acceptance_probabilities(
    transition_infos: dict[str, TransitionInfo], phase: str
) -> list[dict[str, Any]]:
    """
    Averages acceptance probabilities and position moves per kernel and chain.

    Parameters
    ----------
    transition_infos
        Transition infos by kernel identifier. ``acceptance_prob`` must be
        two-dimensional, with the chain as the first axis; ``position_moved`` is
        indexed by chain in the same way.
    phase
        Label stored in the ``"phase"`` entry of every returned row.

    Returns
    -------
    One dictionary per kernel and chain with the keys ``"kernel"``, ``"phase"``,
    ``"chain"``, ``"acceptance_probability"`` and ``"position_moved"``. Both
    averages are plain Python floats, so that pandas treats them as numbers.
    """
    data: list[dict[str, Any]] = []
    for kernel, tinfo in transition_infos.items():
        ap = jnp.asarray(tinfo.acceptance_prob)
        pm = jnp.asarray(tinfo.position_moved)
        chains, _ = ap.shape
        for c in range(chains):
            moved = float(jnp.mean(pm[c, ...]))
            # A chain average that rounds to 99 is reported as missing.
            # NOTE(review): presumably 99 is a sentinel for "not recorded" in
            # ``position_moved`` - confirm against the kernel implementations.
            if moved > 1.0 and round(moved) == 99:
                moved = float("nan")

            data.append(
                {
                    "kernel": kernel,
                    "phase": phase,
                    "chain": c,
                    "acceptance_probability": float(ap[c, ...].mean()),
                    "position_moved": moved,
                }
            )
    return data


def summarize_acceptance_probabilities(
    results: SamplingResults,
) -> list[dict[str, Any]]:
    """
    Summarizes the acceptance probabilities of the warmup and the posterior phase.

    Returns one row (dictionary) per kernel and chain for each phase; all warmup
    rows come first, followed by the posterior rows.
    """
    rows: list[dict[str, Any]] = []
    rows.extend(
        _summarize_acceptance_probabilities(
            results.get_warmup_transition_infos(), "warmup"
        )
    )
    rows.extend(
        _summarize_acceptance_probabilities(
            results.get_posterior_transition_infos(), "posterior"
        )
    )
    return rows


class Summary:
    """
    Posterior summary and diagnostics for :class:`.SamplingResults`.

    Offers two main use cases:

    1. View an overall summary by displaying a summary instance (its ``repr``, or
       the HTML/markdown representation in a notebook), including a summary table
       of the posterior samples and a summary of sampling errors. Note that
       ``str(summary)``, and therefore ``print(summary)``, shows the table returned
       by :meth:`.to_dataframe` instead.
    2. Programmatically access summary statistics via
       ``quantities[quantity_name][var_name]``. Please refer to the documentation of
       the attribute :attr:`.quantities` for details.

    Additionally, the summary object can be turned into a
    :class:`~pandas.DataFrame` using :meth:`.to_dataframe`.

    If ``per_chain=False``, statistics are computed over all posterior chains and
    draws. If ``per_chain=True``, each chain is summarized separately. The low-level
    computations for HDIs, effective sample sizes, R-hat, and Monte Carlo standard
    errors are delegated to `ArviZ <https://python.arviz.org/>`_.

    By default, the summary contains the following statistics:

    - ``mean``: Posterior mean.
    - ``sd``: Posterior standard deviation.
    - ``var``: Posterior variance.
    - ``quantiles``: Posterior quantiles at the probabilities given by
      ``quantiles``. These are stored as ``"quantile"`` in :attr:`.quantities` and
      become columns named ``q_<probability>`` in :meth:`.to_dataframe`.
    - ``hdi``: Highest density interval with probability mass ``hdi_prob``. This is
      the narrowest posterior interval reported by ArviZ at that probability level.
      In :meth:`.to_dataframe`, it becomes ``hdi_low`` and ``hdi_high``.
    - ``ess_bulk``: Bulk effective sample size, a diagnostic for Monte Carlo
      precision in the central part of the posterior distribution.
    - ``ess_tail``: Tail effective sample size, a diagnostic for Monte Carlo
      precision in the posterior tails.
    - ``rhat``: Rank-normalized split R-hat, a between-chain convergence
      diagnostic. Values close to 1 indicate better agreement between chains. This
      statistic is only computed when more than one chain is summarized together.
    - ``mcse_mean``: Monte Carlo standard error of the posterior mean.
    - ``mcse_sd``: Monte Carlo standard error of the posterior standard deviation.

    Use ``which`` to compute only a subset of these statistics.

    Parameters
    ----------
    results
        The sampling results to summarize.
    additional_chain
        Can be supplied to add more parameters to the summary output. Must be a
        position chain which matches chain and time dimension of the posterior
        chain as returned by :meth:`~.goose.SamplingResults.get_posterior_samples`.
    quantiles
        Posterior quantile probabilities to compute when ``"quantiles"`` is included
        in ``which``.
    hdi_prob
        Posterior probability mass of the highest density interval to compute when
        ``"hdi"`` is included in ``which``.
    selected, deselected
        Allow to get a summary only for a subset of the position keys.
    per_chain
        If *True*, the summary is calculated on a per-chain basis. Certain measures
        like ``rhat`` are not available if ``per_chain`` is *True*.
    which
        Names of the summary statistics to compute. Supported values are
        ``"mean"``, ``"sd"``, ``"var"``, ``"quantiles"``, ``"hdi"``, ``"ess_bulk"``,
        ``"ess_tail"``, ``"rhat"``, ``"mcse_mean"``, and ``"mcse_sd"``.

    Raises
    ------
    ValueError
        If ``which`` is empty or contains an unsupported name.

    Notes
    -----
    This class is still considered experimental. The API may still undergo larger
    changes.
    """

    per_chain: bool
    """
    Whether results are summarized for individual chains (*True*), or aggregated
    over chains (*False*).
    """

    quantities: dict[str, dict[str, np.ndarray]]
    """
    Dict of summarizing quantities. Let ``summary`` be a :class:`.Summary` instance.
    The hierarchy is::

        q = summary.quantities["quantity_name"]["parameter_name"]

    Available quantity names are ``"mean"``, ``"sd"``, ``"var"``, ``"quantile"``,
    ``"hdi"``, ``"ess_bulk"``, ``"ess_tail"``, ``"rhat"``, ``"mcse_mean"``, and
    ``"mcse_sd"``, depending on the ``which`` argument. Note that ``which`` uses
    ``"quantiles"`` to request quantiles, while :attr:`.quantities` stores the
    result under ``"quantile"``.

    The extracted object is an ``np.ndarray``. If ``per_chain=True``, the arrays for
    the ``"quantile"`` and ``"hdi"`` quantities have the following dimensions:

    1. First index refers to the chain
    2. Second index refers to the quantile/interval
    3. Third and subsequent indices refer to individual parameters.

    If ``per_chain=True``, the arrays for the other quantities have the dimensions:

    1. First index refers to the chain
    2. Second and subsequent indices refer to individual parameters.

    If ``per_chain=False``, the first index is removed for all quantities.
    """

    config: dict
    """
    A dictionary of config settings for this summary object. Should NOT be changed
    after initialization; such changes have no effect on the computed summary
    values.
    """

    sample_info: dict
    """
    Dictionary of meta-information about the mcmc samples used to create this
    summary object. Contains ``num_chains``, ``sample_size_per_chain``,
    ``warmup_size_per_chain``, ``thinning_warmup``, and ``thinning_posterior``.
    """

    error_summary: ErrorSummary
    """
    Contains error information for each kernel.
    """

    kernels_by_pos_key: dict[str, str]
    """
    A dict, linking parameter names (the keys) to the kernel identifier (the
    values). The identifier refers to the kernel that was used to sample the
    respective parameter.
    """

    liesel_version: str
    """
    The specific version of Liesel used to produce the results.
    """

    def __init__(
        self,
        results: SamplingResults,
        additional_chain: Position | None = None,
        quantiles: Sequence[float] = (0.05, 0.5, 0.95),
        hdi_prob: float = 0.9,
        selected: list[str] | None = None,
        deselected: list[str] | None = None,
        per_chain: bool = False,
        which: Sequence[SummaryQuantities] = summary_quantities,
    ):
        if not which:
            raise ValueError(
                f"Argument 'which' must not be empty. "
                f"Supported keys are: {summary_quantities}"
            )

        for _which_key in which:
            if _which_key not in summary_quantities:
                # report the offending key, not the whole sequence
                raise ValueError(
                    f"Key {_which_key!r} in 'which' is not supported. "
                    f"Supported keys are: {summary_quantities}"
                )

        posterior_chain = results.get_posterior_samples()

        if additional_chain:
            for k, v in additional_chain.items():
                posterior_chain[k] = v

        if selected:
            posterior_chain = Position(
                {
                    key: value
                    for key, value in posterior_chain.items()
                    if key in selected
                }
            )

        if deselected is not None:
            for key in deselected:
                del posterior_chain[key]

        # get some general infos on the sampling; the first two axes of every
        # chain array are (chain, draw)
        param_chain = next(iter(posterior_chain.values()))
        epochs = results.positions.get_epochs()

        # number of stored warmup draws per chain, i.e. after thinning
        warmup_size = np.sum(
            [
                int(epoch.duration / epoch.thinning)
                for epoch in epochs
                if epoch.type.is_warmup(epoch.type)
            ]
        )

        thinning_warmup = np.unique(
            [epoch.thinning for epoch in epochs if epoch.type.is_warmup(epoch.type)]
        )
        thinning_posterior = np.unique(
            [epoch.thinning for epoch in epochs if epoch.type is EpochType.POSTERIOR]
        )

        sample_info = {
            "num_chains": param_chain.shape[0],
            "sample_size_per_chain": param_chain.shape[1],
            "warmup_size_per_chain": warmup_size,
            "thinning_warmup": thinning_warmup.squeeze(),
            "thinning_posterior": thinning_posterior.squeeze(),
        }

        # convert everything to numpy array
        for key in posterior_chain:
            posterior_chain[key] = np.asarray(posterior_chain[key])

        # calculate quantiles either per chain and merge the results or all at once
        if per_chain:
            single_chain_summaries = []
            for chain_idx in range(sample_info["num_chains"]):
                # ``None`` keeps a chain axis of length one for the computation
                single_chain = slice_leaves(
                    posterior_chain, jnp.s_[None, chain_idx, ...]
                )
                qdict = _create_quantity_dict(single_chain, quantiles, hdi_prob, which)
                single_chain_summaries.append(qdict)
            quantities = stack_leaves(single_chain_summaries, axis=0)
        else:
            quantities = _create_quantity_dict(
                posterior_chain, quantiles, hdi_prob, which
            )

        config = {
            "quantiles": quantiles,
            "hdi_prob": hdi_prob,
            "chains_merged": not per_chain,
        }

        error_summary = _make_error_summary(
            results.get_error_log(False).unwrap(), results.get_error_log(True)
        )

        pos_keys_by_kernels = [
            {"kernel": k, "positions": ", ".join(v)}
            for k, v in results.get_pos_keys_by_kernels().items()
        ]

        if len(pos_keys_by_kernels) == 1:
            posdf = pd.DataFrame(pos_keys_by_kernels, index=pd.Index([0]))
        else:
            posdf = pd.DataFrame(pos_keys_by_kernels)
        posdf = posdf.set_index(["kernel"])

        self._which = which
        self.per_chain = per_chain
        self.quantities = quantities
        self.config = config
        self.sample_info = sample_info
        self.error_summary = error_summary
        self.pos_keys_by_kernels_df = posdf
        self._acceptance_prob_summary = summarize_acceptance_probabilities(results)
        self.kernels_by_pos_key = results.get_kernels_by_pos_key()
        self.liesel_version = __version__

    def to_dataframe(self) -> pd.DataFrame:
        """Turns Summary object into a :class:`~pandas.DataFrame` object."""
        # shallow copy: entries are added and removed below without touching
        # ``self.quantities``
        quants = self.quantities.copy()

        # with ``per_chain=True`` all arrays carry a leading chain axis that has to
        # be kept when picking a single quantile or interval bound
        lead = (slice(None),) if self.per_chain else ()

        # replace the stacked quantiles by one entry per probability
        if "quantiles" in self._which:
            quantile_arrays = quants.pop("quantile")
            for i, q in enumerate(self.config["quantiles"]):
                quants[f"q_{q}"] = {
                    k: v[(*lead, i, ...)] for k, v in quantile_arrays.items()
                }

        # replace the stacked hdi bounds by one entry per bound
        if "hdi" in self._which:
            hdi_arrays = quants.pop("hdi")
            quants["hdi_low"] = {k: v[(*lead, 0, ...)] for k, v in hdi_arrays.items()}
            quants["hdi_high"] = {k: v[(*lead, 1, ...)] for k, v in hdi_arrays.items()}

        # create one row per entry
        df_dict = {}
        first_quant = list(quants.values())[0]
        for var in first_quant.keys():
            it = np.nditer(first_quant[var], flags=["multi_index"])
            for _ in it:
                var_fqn = (
                    var if len(it.multi_index) == 0 else f"{var}{list(it.multi_index)}"
                )

                quant_per_elem: dict[str, Any] = {}
                quant_per_elem["variable"] = var
                quant_per_elem["kernel"] = self.kernels_by_pos_key.get(var, "-")

                if self.config["chains_merged"]:
                    quant_per_elem["var_index"] = it.multi_index
                    quant_per_elem["sample_size"] = (
                        self.sample_info["sample_size_per_chain"]
                        * self.sample_info["num_chains"]
                    )
                else:
                    quant_per_elem["chain_index"] = it.multi_index[0]
                    quant_per_elem["var_index"] = it.multi_index[1:]
                    quant_per_elem["sample_size"] = self.sample_info[
                        "sample_size_per_chain"
                    ]

                for quant_name, quant_dict in quants.items():
                    quant_per_elem[quant_name] = quant_dict[var][it.multi_index]

                # convert jax.Arrays (scalar) to floats so that pandas treats them
                # correctly
                for key, val in quant_per_elem.items():
                    if isinstance(val, jax.Array):
                        # value should be a scalar
                        assert val.shape == ()
                        # replace dict element with value converted to a float
                        quant_per_elem[key] = float(val)

                df_dict[var_fqn] = quant_per_elem

        # convert to dataframe and use varname as index
        df = pd.DataFrame.from_dict(df_dict, orient="index")
        df = df.reset_index()
        df = df.rename(columns={"index": "var_fqn"})
        df = df.set_index("variable")

        return df

    def aggregate_diagnostics(
        self,
        by: Literal[
            "min/max", "mean", "median", "std", "var", "min", "max"
        ] = "min/max",
    ) -> pd.DataFrame:
        """
        Aggregates effective sample sizes (ESS) and rhat.

        Parameters
        ----------
        by
            How to aggregate. The available options are:

            - ``"min/max"``: Aggregate ESS by taking the minimum ESS per parameter
              block, and the rhat by taking the maximum rhat per parameter block.
              This corresponds to a worst-case summary.
            - ``"mean"``: Aggregate ESS and rhat by averaging inside parameter
              blocks.
            - ``"median"``: Aggregate ESS and rhat by taking the median inside
              parameter blocks.
            - ``"std"``: Aggregate ESS and rhat by taking the standard deviation
              inside parameter blocks.
            - ``"var"``: Aggregate ESS and rhat by taking the variance inside
              parameter blocks.
            - ``"min"``: Aggregate ESS and rhat by taking the minimum inside
              parameter blocks.
            - ``"max"``: Aggregate ESS and rhat by taking the maximum inside
              parameter blocks.

        Raises
        ------
        KeyError
            If none of ``ess_bulk``, ``ess_tail`` and ``rhat`` is available in this
            summary.

        Notes
        -----
        If :attr:`.per_chain` is ``True``, rhat cannot be computed and is therefore
        not present in the output dataframe. More generally, diagnostics that are
        not part of this summary (because they were not requested via ``which``, or
        because rhat is not computed for a single chain) are left out of the
        output.
        """
        df = self.to_dataframe()
        df.index.name = "parameter"

        by_ess, by_rhat = ("min", "max") if by == "min/max" else (by, by)

        wanted = {"ess_bulk": by_ess, "ess_tail": by_ess}
        if not self.per_chain:
            wanted["rhat"] = by_rhat

        # only aggregate the diagnostics that were actually computed
        named_aggs = {
            col: (col, how) for col, how in wanted.items() if col in df.columns
        }
        if not named_aggs:
            raise KeyError(
                "None of the diagnostics 'ess_bulk', 'ess_tail' and 'rhat' is "
                "available in this summary."
            )

        if self.per_chain:
            diagnostics = (
                df.loc[:, ["chain_index", *named_aggs]]
                .groupby(["parameter", "chain_index"])
                .agg(**named_aggs)
            )
        else:
            diagnostics = (
                df.loc[:, list(named_aggs)].groupby("parameter").agg(**named_aggs)
            )

        if by == "min/max":
            diagnostics["aggregated_by"] = "min (ess); max (rhat)"
        else:
            diagnostics["aggregated_by"] = by

        return diagnostics

    def _param_df(self):
        """Returns the condensed parameter table used by the ``repr`` methods."""
        df = self.to_dataframe()
        df.index.name = "parameter"
        df = df.rename(columns={"var_index": "index"})
        df = df.set_index("index", append=True)

        qtls = [f"q_{qtl}" for qtl in self.config["quantiles"]]
        cols = (
            ["kernel", "mean", "sd"]
            + qtls
            + ["sample_size", "ess_bulk", "ess_tail", "rhat"]
        )
        # columns that were not computed are silently left out
        cols = [col for col in cols if col in df.columns]
        df = df[cols]
        return df

    def acceptance_prob_df(self) -> pd.DataFrame:
        """Returns an overview of acceptance probabilities as a dataframe."""
        apdf = pd.DataFrame(self._acceptance_prob_summary)
        apdf = apdf.set_index(["kernel", "phase", "chain"])
        apdf = apdf.join(self.pos_keys_by_kernels_df, on="kernel")
        apdf = apdf.reset_index().set_index(["kernel", "positions", "phase", "chain"])
        if not self.per_chain:
            # average over chains
            apdf = apdf.groupby(level=["kernel", "positions", "phase"]).agg(
                {"acceptance_probability": "mean", "position_moved": "mean"}
            )
        return apdf

    def error_df(self, per_chain: bool = False) -> pd.DataFrame:
        """
        Returns an overview of the errors recorded during sampling as a dataframe.
        """
        return self._error_df(per_chain=per_chain)

    def _error_df(self, per_chain: bool = False) -> pd.DataFrame:
        """
        Builds the error overview; returns an empty dataframe if no kernel
        recorded an error.
        """
        error_summaries = {k: v for k, v in self.error_summary.items() if v}
        if not error_summaries:
            return pd.DataFrame()

        df = pd.concat(
            {
                kernel: pd.DataFrame.from_dict(code_summary, orient="index")
                for kernel, code_summary in error_summaries.items()
            }
        )

        df = df.reset_index(level=1, drop=True)
        df["error_code"] = df["error_code"].astype(int)
        df = df.set_index(["error_code", "error_msg"], append=True)
        df.index.names = ["kernel", "error_code", "error_msg"]

        df = df.rename(
            columns={
                "count_per_chain": "total",
                "count_per_chain_posterior": "posterior",
            }
        )

        # one row per chain, then one row per chain and phase
        df = df.explode(["total", "posterior"])
        df["warmup"] = df["total"] - df["posterior"]
        df = df.drop(columns="total")
        df = df.melt(
            value_vars=["warmup", "posterior"],
            var_name="phase",
            value_name="count",
            ignore_index=False,
        )
        df["phase"] = pd.Categorical(df["phase"], categories=["warmup", "posterior"])
        df = df.set_index("phase", append=True)
        # rows within one (kernel, code, message, phase) group are the chains
        df["chain"] = df.groupby(level=[0, 1, 2, 3], observed=True).cumcount()
        df = df.set_index("chain", append=True)
        df = df.sort_index()

        df["sample_size"] = None
        warmup_size = self.sample_info["warmup_size_per_chain"]
        posterior_size = self.sample_info["sample_size_per_chain"]
        df.loc[pd.IndexSlice[:, :, :, "warmup"], "sample_size"] = warmup_size
        df.loc[pd.IndexSlice[:, :, :, "posterior"], "sample_size"] = posterior_size

        df["thinning"] = None
        warmup_thinning = self.sample_info["thinning_warmup"]
        posterior_thinning = self.sample_info["thinning_posterior"]
        df.loc[pd.IndexSlice[:, :, :, "warmup"], "thinning"] = warmup_thinning
        df.loc[pd.IndexSlice[:, :, :, "posterior"], "thinning"] = posterior_thinning

        # number of iterations before thinning, used as reference for "relative"
        df["sample_size_total"] = df["sample_size"] * df["thinning"]
        df["relative"] = df["count"] / df["sample_size_total"]

        # NOTE(review): "chain" is turned into a regular column here and dropped by
        # the column selection below, so rows are not labelled by chain when
        # ``per_chain=True`` - confirm that this is intended.
        df = (
            df.join(self.pos_keys_by_kernels_df, on="kernel")
            .reset_index()
            .set_index(["kernel", "positions", "error_code", "error_msg", "phase"])
        )

        if not per_chain:
            df = df.groupby(level=[0, 1, 2, 3, 4], observed=True)
            df = df.aggregate(
                {
                    "count": "sum",
                    "relative": "mean",
                    "sample_size": "sum",
                    "sample_size_total": "sum",
                }
            )

        df = df.sort_index()

        # re-order columns
        cols = ["count", "sample_size", "sample_size_total", "relative"]

        return df[cols]

    def __repr__(self):
        param_df = self._param_df()
        error_df = self._error_df()

        txt = "Parameter summary:\n\n" + repr(param_df)

        if not error_df.empty:
            txt += "\n\nError summary:\n\n" + repr(error_df)

        return txt

    def _repr_html_(self):
        param_df = self._param_df()
        apdf = self.acceptance_prob_df()
        error_df = self._error_df()

        html = "\n<p><strong>Parameter summary:</strong></p>\n" + param_df.to_html()
        html += "\n<p><strong>Acceptance probabilities:</strong></p>\n" + apdf.to_html()

        if not error_df.empty:
            html += "\n<p><strong>Error summary:</strong></p>\n" + error_df.to_html()

        html += "\n"
        return html

    def _repr_markdown_(self):
        param_df = self._param_df()
        apdf = self.acceptance_prob_df()
        error_df = self._error_df()

        try:
            param_md = param_df.to_markdown()
            apdf_md = apdf.to_markdown()
            error_md = error_df.to_markdown()
        except ImportError:
            # fall back to plain-text tables in code fences
            param_md = f"```\n{repr(param_df)}\n```"
            apdf_md = f"```\n{repr(apdf)}\n```"
            error_md = f"```\n{repr(error_df)}\n```"

        md = "\n\n**Parameter summary:**\n\n" + param_md
        md += "\n\n**Acceptance probabilities:**\n\n" + apdf_md

        if not error_df.empty:
            md += "\n\n**Error summary:**\n\n" + error_md

        md += "\n\n"
        return md

    def __str__(self):
        return str(self.to_dataframe())
def _create_quantity_dict(
    chain: Position,
    quantiles: Sequence[float],
    hdi_prob: float,
    which: Sequence[SummaryQuantities] = summary_quantities,
) -> dict[str, dict[str, np.ndarray]]:
    """
    Computes the requested summary quantities for all variables in ``chain``.

    The chain is handed to ArviZ as the ``"posterior"`` group. The result maps the
    quantity name to a dict that maps the variable name to a plain numpy array.
    Quantiles requested via ``"quantiles"`` are stored under the key
    ``"quantile"``. For ``"hdi"``, the ``ci_bound`` dimension is moved to the
    front. ``"rhat"`` is only included if there is more than one chain.
    """
    azchain = az.from_dict({"posterior": chain})["posterior"].dataset
    sample_dims = ["chain", "draw"]

    # (name in ``which``, key in the result, deferred computation); the order of
    # this table determines the order of the keys in the result
    recipes = (
        ("mean", "mean", lambda: azchain.mean(dim=sample_dims)),
        ("var", "var", lambda: azchain.var(dim=sample_dims)),
        ("sd", "sd", lambda: azchain.std(dim=sample_dims)),
        (
            "quantiles",
            "quantile",
            lambda: azchain.quantile(q=quantiles, dim=sample_dims),
        ),
        ("hdi", "hdi", lambda: az.hdi(azchain, prob=hdi_prob)),
        ("ess_bulk", "ess_bulk", lambda: az.ess(azchain, method="bulk")),
        ("ess_tail", "ess_tail", lambda: az.ess(azchain, method="tail")),
        ("mcse_mean", "mcse_mean", lambda: az.mcse(azchain, method="mean")),
        ("mcse_sd", "mcse_sd", lambda: az.mcse(azchain, method="sd")),
    )

    datasets = {}
    for requested_name, result_key, compute in recipes:
        if requested_name in which:
            datasets[result_key] = compute()

    if "rhat" in which and azchain.sizes["chain"] > 1:
        datasets["rhat"] = az.rhat(azchain)

    # convert to simple dict[str, np.ndarray]
    quantities = {}
    for result_key, dataset in datasets.items():
        arrays_by_var = {}
        for var_name, data_array in dataset.data_vars.items():
            if result_key == "hdi":
                other_dims = [dim for dim in data_array.dims if dim != "ci_bound"]
                data_array = data_array.transpose("ci_bound", *other_dims)
            arrays_by_var[var_name] = data_array.values
        quantities[result_key] = arrays_by_var

    return quantities
class SamplesSummary:
    """
    Posterior summary and diagnostics for a dictionary of sample arrays.

    See :class:`.Summary` for the full description of the computed statistics,
    their interpretation, the ``quantities`` layout, and the behavior of
    ``quantiles``, ``hdi_prob``, ``per_chain``, and ``which``.

    This class computes the same sample-based statistics as :class:`.Summary`, but
    takes a plain dictionary of sample arrays instead of a
    :class:`.SamplingResults` object and does not include sampling-error or
    acceptance-probability diagnostics.

    Offers two main use cases:

    1. The summary object can be turned into a :class:`~pandas.DataFrame` using
       :meth:`.to_dataframe`.
    2. Programmatically access summary statistics via
       ``quantities[quantity_name][var_name]``. Please refer to the documentation of
       the attribute :attr:`.quantities` for details.

    Parameters
    ----------
    samples
        The dictionary of samples to summarize. Each array is expected to have
        leading dimensions ``(nchains, ndraws, ...)``.
    quantiles
        Posterior quantile probabilities to compute when ``"quantiles"`` is included
        in ``which``.
    hdi_prob
        Posterior probability mass of the highest density interval to compute when
        ``"hdi"`` is included in ``which``.
    selected, deselected
        Allow to get a summary only for a subset of the position keys.
    per_chain
        If *True*, the summary is calculated on a per-chain basis. Certain measures
        like ``rhat`` are not available if ``per_chain`` is *True*.
    which
        Names of the summary statistics to compute. Supported values are the same
        as for :class:`.Summary`.

    Raises
    ------
    ValueError
        If ``which`` is empty or contains an unsupported name.

    Notes
    -----
    This class is still considered experimental. The API may still undergo larger
    changes.
    """

    per_chain: bool
    """
    Whether results are summarized for individual chains (*True*), or aggregated
    over chains (*False*).
    """

    quantities: dict[str, dict[str, np.ndarray]]
    """
    Dict of summarizing quantities, accessed as
    ``quantities["quantity_name"]["parameter_name"]``.
    """

    config: dict
    """
    Settings used for this summary: ``quantiles``, ``hdi_prob``, and
    ``chains_merged``.
    """

    sample_info: dict
    """
    Meta-information about the samples: ``num_chains`` and
    ``sample_size_per_chain``.
    """

    def __init__(
        self,
        samples: dict[str, Array],
        quantiles: Sequence[float] = (0.05, 0.5, 0.95),
        hdi_prob: float = 0.9,
        selected: list[str] | None = None,
        deselected: list[str] | None = None,
        per_chain: bool = False,
        which: Sequence[SummaryQuantities] = summary_quantities,
    ):
        if not which:
            raise ValueError(
                f"Argument 'which' must not be empty. "
                f"Supported keys are: {summary_quantities}"
            )

        for _which_key in which:
            if _which_key not in summary_quantities:
                # report the offending key, not the whole sequence
                raise ValueError(
                    f"Key {_which_key!r} in 'which' is not supported. "
                    f"Supported keys are: {summary_quantities}"
                )

        posterior_chain = Position(samples)

        if selected:
            posterior_chain = Position(
                {
                    key: value
                    for key, value in posterior_chain.items()
                    if key in selected
                }
            )

        if deselected is not None:
            for key in deselected:
                del posterior_chain[key]

        # get some general infos on the sampling; the first two axes of every
        # sample array are (chain, draw)
        param_chain = next(iter(posterior_chain.values()))

        sample_info = {
            "num_chains": param_chain.shape[0],
            "sample_size_per_chain": param_chain.shape[1],
        }

        # convert everything to numpy array
        for key in posterior_chain:
            posterior_chain[key] = np.asarray(posterior_chain[key])

        # calculate quantiles either per chain and merge the results or all at once
        if per_chain:
            single_chain_summaries = []
            for chain_idx in range(sample_info["num_chains"]):
                # ``None`` keeps a chain axis of length one for the computation
                single_chain = slice_leaves(
                    posterior_chain, jnp.s_[None, chain_idx, ...]
                )
                qdict = _create_quantity_dict(single_chain, quantiles, hdi_prob, which)
                single_chain_summaries.append(qdict)
            quantities = stack_leaves(single_chain_summaries, axis=0)
        else:
            quantities = _create_quantity_dict(
                posterior_chain, quantiles, hdi_prob, which
            )

        config = {
            "quantiles": quantiles,
            "hdi_prob": hdi_prob,
            "chains_merged": not per_chain,
        }

        self._which = which
        self.per_chain = per_chain
        self.quantities = quantities
        self.config = config
        self.sample_info = sample_info

    @classmethod
    def from_array(
        cls,
        a: Array,
        quantiles: Sequence[float] = (0.05, 0.5, 0.95),
        hdi_prob: float = 0.9,
        selected: list[str] | None = None,
        deselected: list[str] | None = None,
        per_chain: bool = False,
        name: str = "v",
        which: Sequence[SummaryQuantities] = summary_quantities,
    ) -> SamplesSummary:
        """
        Initializes the summary from an array of samples.

        Parameters
        ----------
        a
            The array of samples to summarize. Expected to have leading dimensions
            ``(nchains, ndraws, ...)``.
        quantiles
            Posterior quantile probabilities to compute when ``"quantiles"`` is
            included in ``which``.
        hdi_prob
            Posterior probability mass of the highest density interval to compute
            when ``"hdi"`` is included in ``which``.
        selected, deselected
            Allow to get a summary only for a subset of the position keys.
        per_chain
            If *True*, the summary is calculated on a per-chain basis. Certain
            measures like ``rhat`` are not available if ``per_chain`` is *True*.
        name
            Variable name to use for labelling in :meth:`.to_dataframe`.
        which
            Names of the summary statistics to compute. Supported values are the
            same as for :class:`.Summary`.
        """
        samples = {name: a}
        return cls(samples, quantiles, hdi_prob, selected, deselected, per_chain, which)

    def to_dataframe(self) -> pd.DataFrame:
        """Turns SamplesSummary object into a :class:`~pandas.DataFrame` object."""
        # shallow copy: entries are added and removed below without touching
        # ``self.quantities``
        quants = self.quantities.copy()

        # with ``per_chain=True`` all arrays carry a leading chain axis that has to
        # be kept when picking a single quantile or interval bound
        lead = (slice(None),) if self.per_chain else ()

        # replace the stacked quantiles by one entry per probability
        if "quantiles" in self._which:
            quantile_arrays = quants.pop("quantile")
            for i, q in enumerate(self.config["quantiles"]):
                quants[f"q_{q}"] = {
                    k: v[(*lead, i, ...)] for k, v in quantile_arrays.items()
                }

        # replace the stacked hdi bounds by one entry per bound
        if "hdi" in self._which:
            hdi_arrays = quants.pop("hdi")
            quants["hdi_low"] = {k: v[(*lead, 0, ...)] for k, v in hdi_arrays.items()}
            quants["hdi_high"] = {k: v[(*lead, 1, ...)] for k, v in hdi_arrays.items()}

        # create one row per entry
        df_dict = {}
        first_quant = list(quants.values())[0]
        for var in first_quant.keys():
            it = np.nditer(first_quant[var], flags=["multi_index"])
            for _ in it:
                var_fqn = (
                    var if len(it.multi_index) == 0 else f"{var}{list(it.multi_index)}"
                )

                quant_per_elem: dict[str, Any] = {}
                quant_per_elem["variable"] = var

                if self.config["chains_merged"]:
                    quant_per_elem["var_index"] = it.multi_index
                    quant_per_elem["sample_size"] = (
                        self.sample_info["sample_size_per_chain"]
                        * self.sample_info["num_chains"]
                    )
                else:
                    quant_per_elem["chain_index"] = it.multi_index[0]
                    quant_per_elem["var_index"] = it.multi_index[1:]
                    quant_per_elem["sample_size"] = self.sample_info[
                        "sample_size_per_chain"
                    ]

                for quant_name, quant_dict in quants.items():
                    quant_per_elem[quant_name] = quant_dict[var][it.multi_index]

                # convert jax.Arrays (scalar) to floats so that pandas treats them
                # correctly
                for key, val in quant_per_elem.items():
                    if isinstance(val, jax.Array):
                        # value should be a scalar
                        assert val.shape == ()
                        # replace dict element with value converted to a float
                        quant_per_elem[key] = float(val)

                df_dict[var_fqn] = quant_per_elem

        # convert to dataframe and use varname as index
        df = pd.DataFrame.from_dict(df_dict, orient="index")
        df = df.reset_index()
        df = df.rename(columns={"index": "var_fqn"})
        df = df.set_index("variable")

        return df

    def _param_df(self):
        """Returns a condensed parameter table with a fixed column order."""
        df = self.to_dataframe()
        df.index.name = "parameter"
        df = df.rename(columns={"var_index": "index"})
        df = df.set_index("index", append=True)

        qtls = [f"q_{qtl}" for qtl in self.config["quantiles"]]
        cols = (
            ["kernel", "mean", "sd"]
            + qtls
            + ["sample_size", "ess_bulk", "ess_tail", "rhat"]
        )
        # columns that are not present (e.g. "kernel") are silently left out
        cols = [col for col in cols if col in df.columns]
        df = df[cols]
        return df

    def aggregate_diagnostics(
        self,
        by: Literal[
            "min/max", "mean", "median", "std", "var", "min", "max"
        ] = "min/max",
    ) -> pd.DataFrame:
        """
        Aggregates effective sample sizes (ESS) and rhat.

        Parameters
        ----------
        by
            How to aggregate. The available options are:

            - ``"min/max"``: Aggregate ESS by taking the minimum ESS per parameter
              block, and the rhat by taking the maximum rhat per parameter block.
              This corresponds to a worst-case summary.
            - ``"mean"``: Aggregate ESS and rhat by averaging inside parameter
              blocks.
            - ``"median"``: Aggregate ESS and rhat by taking the median inside
              parameter blocks.
            - ``"std"``: Aggregate ESS and rhat by taking the standard deviation
              inside parameter blocks.
            - ``"var"``: Aggregate ESS and rhat by taking the variance inside
              parameter blocks.
            - ``"min"``: Aggregate ESS and rhat by taking the minimum inside
              parameter blocks.
            - ``"max"``: Aggregate ESS and rhat by taking the maximum inside
              parameter blocks.

        Raises
        ------
        KeyError
            If none of ``ess_bulk``, ``ess_tail`` and ``rhat`` is available in this
            summary.

        Notes
        -----
        If :attr:`.per_chain` is ``True``, rhat cannot be computed and is therefore
        not present in the output dataframe. More generally, diagnostics that are
        not part of this summary (because they were not requested via ``which``, or
        because rhat is not computed for a single chain) are left out of the
        output.
        """
        df = self.to_dataframe()
        df.index.name = "parameter"

        by_ess, by_rhat = ("min", "max") if by == "min/max" else (by, by)

        wanted = {"ess_bulk": by_ess, "ess_tail": by_ess}
        if not self.per_chain:
            wanted["rhat"] = by_rhat

        # only aggregate the diagnostics that were actually computed
        named_aggs = {
            col: (col, how) for col, how in wanted.items() if col in df.columns
        }
        if not named_aggs:
            raise KeyError(
                "None of the diagnostics 'ess_bulk', 'ess_tail' and 'rhat' is "
                "available in this summary."
            )

        if self.per_chain:
            diagnostics = (
                df.loc[:, ["chain_index", *named_aggs]]
                .groupby(["parameter", "chain_index"])
                .agg(**named_aggs)
            )
        else:
            diagnostics = (
                df.loc[:, list(named_aggs)].groupby("parameter").agg(**named_aggs)
            )

        if by == "min/max":
            diagnostics["aggregated_by"] = "min (ess); max (rhat)"
        else:
            diagnostics["aggregated_by"] = by

        return diagnostics
def concatenate_arrays_in_dict(
    x: dict[str, jax.typing.ArrayLike], n_leading_axes: int = 2
) -> jax.Array:
    """
    Joins all arrays of a dictionary into one array along a flattened last axis.

    Every value is passed through :func:`jax.numpy.atleast_3d`, its first
    ``n_leading_axes`` axes are kept and all remaining axes are flattened into one.
    The flattened arrays are then concatenated along the last axis, in the order of
    the dictionary.

    Returns
    -------
    An array of dimension ``(leading1, leading2, nobs)``, where ``nobs`` is the
    total number of elements in the non-leading axes of the dictionary values. In
    the default case, the output is expected to have shape
    ``(nchains, nsamples, nobs)``.
    """

    def _flatten_trailing(value: jax.typing.ArrayLike) -> jax.Array:
        arr = jnp.atleast_3d(jnp.asarray(value))
        return jnp.reshape(arr, arr.shape[:n_leading_axes] + (-1,))

    return jnp.concatenate([_flatten_trailing(v) for v in x.values()], axis=-1)


def _apply_loo_scale(
    result: az.ELPDData,
    scale: Literal["log", "negative_log", "deviance"],
) -> az.ELPDData:
    """
    Converts a LOO result from the log scale to ``scale``, modifying it in place.

    For any scale other than ``"log"``, ``elpd``, ``se`` and ``elpd_i`` are
    rescaled (factor ``-1`` for ``"negative_log"``, ``-2`` otherwise) and ``scale``
    is updated. In all cases, the aliases ``elpd_loo`` and ``p_loo`` are set from
    ``elpd`` and ``p``. The modified ``result`` is returned.
    """
    if scale != "log":
        factor = {"negative_log": -1}.get(scale, -2)
        result.elpd = factor * result.elpd
        # the standard error stays non-negative
        result.se = abs(factor) * result.se
        result.elpd_i = factor * result.elpd_i
        result.scale = scale

    result.elpd_loo = result.elpd
    result.p_loo = result.p
    return result
def loo(
    lpp: dict[str, jax.typing.ArrayLike] | jax.typing.ArrayLike,
    samples: dict[str, jax.typing.ArrayLike] | None,
    reff: float | None = None,
    scale: Literal["log", "negative_log", "deviance"] = "log",
) -> az.ELPDData:
    """
    Compute Pareto-smoothed importance sampling leave-one-out cross-validation
    (PSIS-LOO-CV) statistic via ArviZ.

    Parameters
    ----------
    lpp
        Dictionary or array of pointwise log probability evaluations. If passed as
        a dictionary, the values are concatenated along a flattened last axis, such
        that each value needs the same two leading sample axes. If passed as an
        array, it is expected to have two leading sample axes followed by the
        observation axis. The array is handed to ArviZ as the ``log_likelihood``
        group (NOTE(review): ArviZ presumably reads the leading axes as
        ``(chain, draw)`` - confirm the intended axis order).
    samples
        Dictionary of posterior samples. Only used to estimate ``reff`` if ``reff``
        is ``None``. Each array is expected to have leading dimensions
        ``(nchains, ndraws, ...)``.
    reff
        Relative MCMC efficiency, ess / n i.e. number of effective samples divided
        by the number of actual samples. By default, it is computed as the average
        bulk effective sample size of ``samples`` divided by the product of the two
        leading axis lengths of the log probability array.
    scale
        Output scale. The options are:

        - ``log``: (default) log probability scale.
        - ``negative_log``: ``-1 * log``
        - ``deviance``: ``-2 * log``

        A higher log probability (or a lower deviance or negative log_score)
        indicates a model with better predictive accuracy.

    Raises
    ------
    ValueError
        If both ``samples`` and ``reff`` are ``None``.

    References
    ----------
    - Computations are carried out via ArviZ: https://python.arviz.org/en/stable/
    - Theoretical background: Vehtari, A., Gelman, A., & Gabry, J. (2017).
      Practical Bayesian model evaluation using leave-one-out cross-validation and
      WAIC. Statistics and Computing, 27(5), 1413–1432.
      https://doi.org/10.1007/s11222-016-9696-4
    """
    from collections.abc import Mapping

    if samples is None and reff is None:
        raise ValueError(
            "Both 'samples' and 'reff' are None, so relative MCMC efficiency is not "
            "available."
        )

    # Dispatch on the input type explicitly, so that a failing array conversion
    # surfaces its own error instead of being mistaken for dictionary input.
    if isinstance(lpp, Mapping):
        lpp_array = concatenate_arrays_in_dict(lpp)
    else:
        lpp_array = jnp.asarray(lpp)

    lpp_array = np.asarray(lpp_array)
    idat = az.from_dict({"log_likelihood": {"observed": lpp_array}})

    if reff is None and samples is not None:
        avg_ess = (
            SamplesSummary(samples, which=["ess_bulk"])
            .to_dataframe()["ess_bulk"]
            .mean()
        )
        nsamples = lpp_array.shape[0] * lpp_array.shape[1]
        reff = avg_ess / nsamples

    # the guard at the top guarantees that reff is not None at this point
    return _apply_loo_scale(az.loo(idat, reff=reff), scale)