# Source code for liesel.goose.summary_viz

"""
Diagnostic plots of the posterior samples.
"""

from collections.abc import Sequence
from typing import Any

import arviz
import jax
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from liesel.goose.engine import SamplingResults


def _raise_chain_indices_error(
    chain_indices: Sequence[int], num_original_chains: int
) -> None:
    """Display informative error message with valid ``chain_indices`` inputs."""
    if any(
        chain_index not in range(num_original_chains) for chain_index in chain_indices
    ):
        raise ValueError(
            f"All chain indices must be between 0 and {num_original_chains-1} "
            "(bounds inclusive)."
        )


def _validate_chain_indices(
    chain_indices: int | Sequence[int] | None,
    num_original_chains: int,
) -> Sequence[int]:
    """
    Convert ``int`` or ``None`` input of ``chain_indices`` to sequence of integers.
    """
    if chain_indices is None:
        return list(range(num_original_chains))

    if isinstance(chain_indices, int):
        chain_indices = [chain_indices]

    _raise_chain_indices_error(chain_indices, num_original_chains)
    return chain_indices


def _raise_dimension_error(param: str, num_dim: int) -> None:
    """Check for correct array dimensions of posterior samples."""
    if num_dim not in (2, 3):
        raise ValueError(
            f"Array of posterior samples for {param} has the wrong number of"
            f"dimensions.\nExpected 2 or 3, got {num_dim}."
        )


def _adjust_dimensions(param_chains: np.ndarray, num_dim: int) -> np.ndarray:
    """
    Make shape of posterior samples for one dimensional parameters (e.g. ``log_sigma``)
    consistent with multi-dimensional parameters.
    """
    if num_dim == 2:
        param_chains = np.expand_dims(param_chains, axis=-1)
    return param_chains


def _raise_param_indices_error(
    param_indices: Sequence[int], num_original_subparams: int, param: str
) -> None:
    """
    Display informative error message with valid ``param_indices`` inputs for this
    specific ``param``.
    """
    if any(
        param_index not in range(num_original_subparams)
        for param_index in param_indices
    ):
        raise ValueError(
            f"All param indices for {param} must be between "
            f"0 and {num_original_subparams-1} (bounds inclusive)."
        )


def _validate_param_indices(
    param_indices: int | Sequence[int] | None, num_original_subparams: int, param: str
) -> Sequence[int]:
    """
    Convert ``int`` or ``None`` input of ``param_indices`` to sequence of integers.
    """
    if param_indices is None:
        return list(range(num_original_subparams))

    if isinstance(param_indices, int):
        param_indices = [param_indices]

    _raise_param_indices_error(param_indices, num_original_subparams, param)

    return param_indices


def _move_col_first(df: pd.DataFrame, colname: str) -> pd.DataFrame:
    """Move last column of a :class:`~pandas.DataFrame` to the first column."""
    return df[[colname] + [col for col in df.columns if col != colname]]


def _validate_params(
    posterior_samples: dict[str, jax.Array], params: str | list[str] | None
) -> list[str]:
    """Convert ``str`` or ``None`` input of ``params`` to sequence of strings."""
    posterior_keys = list(posterior_samples.keys())
    if params is None:
        return posterior_keys

    if isinstance(params, str):
        params = [params]

    if any(param not in posterior_keys for param in params):
        raise KeyError(f"All params must be in {posterior_keys}.")

    return params


def _subparam_chains_to_df(
    subparam_chains: np.ndarray, param_index: int
) -> pd.DataFrame:
    """
    Convert the samples of one subparameter (e.g. ``beta[0]``) to long format.

    Rows of ``subparam_chains`` are treated as chains and columns as iterations.
    The result has one row per (chain, iteration) pair, sorted by chain and then by
    iteration, and the columns ``param_index``, ``chain_index``, ``iteration`` and
    ``value`` (in this order).
    """
    wide_df = pd.DataFrame(subparam_chains)
    # keep the row labels (= chain numbers) so they survive the reshaping
    long_df = wide_df.melt(var_name="iteration", ignore_index=False)
    long_df = long_df.reset_index().rename(columns={"index": "chain_index"})
    long_df = long_df.sort_values(by=["chain_index", "iteration"], ignore_index=True)
    long_df = long_df.assign(param_index=param_index)
    return _move_col_first(long_df, colname="param_index")


def _preprocess_param_chains(
    posterior_samples: dict[str, jax.Array], param: str
) -> np.ndarray:
    """
    Return the posterior samples of ``param`` as a three-dimensional NumPy array.

    Scalar parameters (two-dimensional samples) get a trailing axis of size 1, so
    all parameters share the same number of dimensions.

    Raises
    ------
    ValueError
        If the samples of ``param`` have neither 2 nor 3 dimensions.
    """
    chains = np.array(posterior_samples[param])
    _raise_dimension_error(param, chains.ndim)
    return _adjust_dimensions(chains, chains.ndim)


def _convert_to_sequence(indices: int | Sequence[int]) -> Sequence[int]:
    """Convert integer parameter and chain indices to list or tuple."""

    if isinstance(indices, int):
        indices = [indices]

    return indices


def _filter_param_df(
    param_df: pd.DataFrame,
    param_indices: int | Sequence[int] | None,
    chain_indices: int | Sequence[int] | None,
    max_chains: int | None,
) -> pd.DataFrame:
    """
    Filters the plotting data to contain only the specified parameter and chain indices.
    Output data contains not more than `max_chains` chain indices.
    """

    if chain_indices is not None:
        chain_indices = _convert_to_sequence(chain_indices)
        param_df = param_df.loc[param_df["chain_index"].isin(chain_indices)]

    if param_indices is not None:
        param_indices = _convert_to_sequence(param_indices)
        param_df = param_df.loc[param_df["param_index"].isin(param_indices)]

    if max_chains is not None and param_df["chain_index"].nunique() > max_chains:
        last_chain_index = sorted(param_df["chain_index"].unique())[max_chains - 1]
        param_df = param_df.loc[param_df["chain_index"] <= last_chain_index]

    return param_df


def _postprocess_param_df(
    param_df: pd.DataFrame,
    param: str,
    param_indices: int | Sequence[int] | None,
    chain_indices: int | Sequence[int] | None,
    max_chains: int | None,
) -> pd.DataFrame:
    """
    Add a ``param_label`` column and filter the chain and parameter indices.

    The label is the parameter name itself if ``param_df`` has a single
    subparameter, and ``name[index]`` (e.g. ``beta[0]``) otherwise. The requested
    ``chain_indices`` and ``param_indices`` are validated against ``range`` of the
    number of distinct chains/subparameters in ``param_df`` before filtering. The
    input data frame is not modified.

    Raises
    ------
    ValueError
        If a requested chain or parameter index is out of range.
    """
    num_original_chains = param_df["chain_index"].nunique()
    num_original_subparams = param_df["param_index"].nunique()

    labels = param_df["param"].astype(str)
    if num_original_subparams > 1:
        labels = labels + "[" + param_df["param_index"].astype(str) + "]"
    # ``assign`` returns a new frame instead of writing into the caller's frame
    param_df = param_df.assign(param_label=labels)

    chain_indices = _validate_chain_indices(chain_indices, num_original_chains)
    param_indices = _validate_param_indices(
        param_indices, num_original_subparams, param
    )
    return _filter_param_df(param_df, param_indices, chain_indices, max_chains)


def _collect_subparam_dfs(
    posterior_samples: dict[str, jax.Array],
    param: str,
    param_indices: int | Sequence[int] | None,
    chain_indices: int | Sequence[int] | None,
    max_chains: int | None,
) -> pd.DataFrame:
    """
    Build the long-format data frame of a single parameter.

    The samples of every subparameter (last array axis) are converted separately and
    stacked. A ``param`` column with the parameter name is placed first, and the
    result is labeled and filtered by :func:`_postprocess_param_df`.
    """
    chains = _preprocess_param_chains(posterior_samples, param)

    subparam_dfs = []
    for index in range(chains.shape[-1]):
        subparam_dfs.append(_subparam_chains_to_df(chains[..., index], index))

    stacked = pd.concat(subparam_dfs).assign(param=param).reset_index(drop=True)
    stacked = _move_col_first(stacked, colname="param")
    return _postprocess_param_df(
        stacked, param, param_indices, chain_indices, max_chains
    )


def _collect_param_dfs(
    results: SamplingResults,
    params: str | list[str] | None = None,
    param_indices: int | Sequence[int] | None = None,
    chain_indices: int | Sequence[int] | None = None,
    max_chains: int | None = 5,
    include_warmup: bool = False,
) -> pd.DataFrame:
    """
    Stack the long-format data frames of all selected parameters.

    With ``include_warmup=True`` the samples come from ``results.get_samples()``,
    otherwise from ``results.get_posterior_samples()``. The row index of the result
    is reset to a default range index.
    """
    get_samples = (
        results.get_samples if include_warmup else results.get_posterior_samples
    )
    samples = get_samples()
    selected_params = _validate_params(samples, params)

    param_dfs = [
        _collect_subparam_dfs(samples, name, param_indices, chain_indices, max_chains)
        for name in selected_params
    ]
    return pd.concat(param_dfs).reset_index(drop=True)


def _setup_plot_df(
    results: SamplingResults,
    params: str | list[str] | None,
    param_indices: int | Sequence[int] | None,
    chain_indices: int | Sequence[int] | None,
    max_chains: int | None,
    include_warmup: bool = False,
) -> pd.DataFrame:
    """
    Build the long-format data frame used by all plotting functions.

    The ``chain_index`` column is converted to a categorical column (presumably so
    that chains are treated as discrete groups when used as ``hue``).
    """
    plot_df = _collect_param_dfs(
        results,
        params=params,
        param_indices=param_indices,
        chain_indices=chain_indices,
        max_chains=max_chains,
        include_warmup=include_warmup,
    )
    return plot_df.astype({"chain_index": "category"})


def _setup_scatterplot_df(
    results: SamplingResults,
    params: str | list[str] | None,
    param_indices: int | Sequence[int] | None,
    chain_indices: int | Sequence[int] | None,
    max_chains: int | None,
    include_warmup: bool = False,
) -> pd.DataFrame:
    """
    Build the data input for ``plot_scatter``.

    If at least two params *and* a sequence of param indices are given, the first
    index selects from the first param and the second index from the second param;
    only the first two params are used in that case. Otherwise, all arguments are
    passed to :func:`_setup_plot_df` unchanged.
    """
    single_selection = (
        params is None
        or isinstance(params, str)
        or param_indices is None
        or isinstance(param_indices, int)
    )
    if single_selection or len(params) <= 1:
        return _setup_plot_df(
            results, params, param_indices, chain_indices, max_chains, include_warmup
        )

    # pair the i-th param with the i-th index, for the first two params only
    pair_dfs = [
        _setup_plot_df(
            results,
            params[i],
            param_indices[i],
            chain_indices,
            max_chains,
            include_warmup,
        )
        for i in (0, 1)
    ]
    return pd.concat(pair_dfs, ignore_index=True)


def _set_plot_cols(plot_df: pd.DataFrame, ncol: int) -> int:
    """Determines number of facets within each row of the grid."""

    num_subparams = plot_df["param_label"].nunique()
    return min(ncol, num_subparams)


def _set_aesthetics(
    g: sns.FacetGrid, title: str | None, title_spacing: float, xlabel: str, ylabel: str
) -> sns.FacetGrid:
    """
    Add facet titles, axis labels, the legend title and an optional figure title.

    Each facet is titled with its column value only. If ``title`` is given, it is
    set as the figure's suptitle and ``title_spacing`` is passed to the ``top``
    argument of ``fig.subplots_adjust()`` to leave room for it.

    Returns the same ``FacetGrid`` for chaining.
    """
    g.set_titles(col_template="{col_name}")
    g.set_axis_labels(x_var=xlabel, y_var=ylabel)
    g.tight_layout()

    # The grid has no legend if it was disabled by the caller (e.g. ``legend=False``
    # passed through ``**kwargs`` of the plotting functions); ``g.legend`` is then
    # ``None`` and setting its title would raise an AttributeError.
    if g.legend is not None:
        g.legend.set_title("Chain")

    if title is not None:
        g.fig.suptitle(title)
        g.fig.subplots_adjust(top=title_spacing)

    return g


def save_figure(g: sns.FacetGrid | None = None, save_path: str | None = None) -> None:
    """
    Save the plot to ``save_path`` or, if no path is given, display it.

    If ``g`` is given, its figure is saved; otherwise the current matplotlib figure
    is saved via ``plt.savefig``. When ``save_path`` is ``None`` nothing is written
    and ``plt.show()`` is called instead.
    """
    if save_path is None:
        plt.show()
        return

    if g is None:
        plt.savefig(save_path)
    else:
        g.fig.savefig(save_path)


def plot_trace(
    results: SamplingResults,
    params: str | list[str] | None = None,
    param_indices: int | Sequence[int] | None = None,
    chain_indices: int | Sequence[int] | None = None,
    max_chains: int | None = 5,
    title: str | None = None,
    title_spacing: float = 0.85,
    xlabel: str = "Iteration",
    style: str = "whitegrid",
    color_palette: str | list[str] | dict[int, str] | None = None,
    ncol: int = 3,
    height: int = 3,
    aspect_ratio: int = 1,
    save_path: str | None = None,
    include_warmup: bool = False,
    **kwargs,
) -> sns.FacetGrid:
    """
    Visualizes posterior samples over time with a trace plot.

    Parameters
    ----------
    results
        Result object of the sampling process. Must provide
        ``get_posterior_samples()``, which extracts all samples from the posterior
        distribution, and ``get_samples()`` if ``include_warmup=True``.
    params
        Names of the model parameters to include in the plot. Must be keys of the
        ``Position`` holding the posterior samples. If ``None``, all parameters are
        included.
    param_indices
        Indices of each model parameter to include in the plot, e.g. to select
        ``beta[0]`` from a ``beta`` parameter vector. A single index can be given as
        an integer or as a sequence containing one integer. If ``None``, all
        subparameters are included.
    chain_indices
        Indices of the chains to include for each model subparameter, e.g. chain 0
        and chain 2 out of several chains. A single index can be given as an
        integer or as a sequence containing one integer. If ``None``, all chains
        are included.
    max_chains
        Upper bound on the number of chains within each subplot/facet, to avoid
        overplotting. Chains are selected from the lowest chain index upwards; use
        ``chain_indices`` to select specific chains. If ``None``, all chains
        contained in ``results`` are plotted.
    title
        Plot title.
    title_spacing
        Margin/whitespace between the plot title (set with ``fig.suptitle()``) and
        the first row of subplots/facets. Passed to the ``top`` argument of
        ``fig.subplots_adjust()``.
    xlabel
        Label of the x-axis.
    style
        Passed to the ``style`` argument of ``sns.set_theme()``. Valid options are
        ``"darkgrid"``, ``"whitegrid"``, ``"dark"``, ``"white"``, and ``"ticks"``.
    color_palette
        Passed to the ``palette`` argument of ``sns.relplot()``. Strings must be
        valid inputs of ``sns.color_palette()``, such as the name of a seaborn
        color palette or a matplotlib colormap. Custom colors can be given as a
        list of color strings or as a dictionary mapping chain indices to color
        strings; the number of colors must match the number of plotted chains. If
        ``None``, seaborn's default palette (set by ``sns.set_theme()``) is used.
    ncol
        Number of subplots/facets within each row of the grid.
    height
        Height in inches of each subplot/facet within the grid.
    aspect_ratio
        Ratio of width to height of each subplot/facet within the grid, i.e.
        ``width = aspect_ratio * height``.
    save_path
        File path where the plot is saved. If ``None``, the plot is shown instead.
    include_warmup
        Whether to include the warmup samples in the trace plot.
    **kwargs
        Further keyword arguments passed to the seaborn ``relplot()`` function.

    Returns
    -------
    A seaborn ``FacetGrid``.
    """
    # NOTE: Docstring duplications
    # The entries `results` to `max_chains` are shared with the `summary()`
    # and all user plotting functions.
    # The entries `title` to `save_path` are shared with `plot_density()`, `plot_cor()`
    # and partially with `plot_param()`.

    sns.set_theme(style=style)

    plot_df = _setup_plot_df(
        results,
        params,
        param_indices,
        chain_indices,
        max_chains,
        include_warmup,
    )
    facets_per_row = _set_plot_cols(plot_df, ncol)

    g = sns.relplot(
        data=plot_df,
        x="iteration",
        y="value",
        kind="line",
        hue="chain_index",
        palette=color_palette,
        col="param_label",
        col_wrap=facets_per_row,
        height=height,
        aspect=aspect_ratio,
        facet_kws={"sharex": True, "sharey": False},
        **kwargs,
    )

    g = _set_aesthetics(g, title, title_spacing, xlabel, ylabel="")
    save_figure(g, save_path)
    return g
def plot_density(
    results: SamplingResults,
    params: str | list[str] | None = None,
    param_indices: int | Sequence[int] | None = None,
    chain_indices: int | Sequence[int] | None = None,
    max_chains: int | None = 5,
    title: str | None = None,
    title_spacing: float = 0.85,
    xlabel: str = "Value",
    style: str = "whitegrid",
    color_palette: str | list[str] | dict[int, str] | None = None,
    ncol: int = 3,
    height: int = 3,
    aspect_ratio: int = 1,
    save_path: str | None = None,
    **kwargs,
) -> sns.FacetGrid:
    """
    Visualizes posterior distributions with a density plot.

    Parameters
    ----------
    results
        Result object of the sampling process. Must have a method
        ``get_posterior_samples()`` which extracts all samples from the posterior
        distribution.
    params
        Names of the model parameters to include in the plot. Must be keys of the
        ``Position`` holding the posterior samples. If ``None``, all parameters are
        included.
    param_indices
        Indices of each model parameter to include in the plot, e.g. to select
        ``beta[0]`` from a ``beta`` parameter vector. A single index can be given as
        an integer or as a sequence containing one integer. If ``None``, all
        subparameters are included.
    chain_indices
        Indices of the chains to include for each model subparameter, e.g. chain 0
        and chain 2 out of several chains. A single index can be given as an
        integer or as a sequence containing one integer. If ``None``, all chains
        are included.
    max_chains
        Upper bound on the number of chains within each subplot/facet, to avoid
        overplotting. Chains are selected from the lowest chain index upwards; use
        ``chain_indices`` to select specific chains. If ``None``, all chains
        contained in ``results`` are plotted.
    title
        Plot title.
    title_spacing
        Margin/whitespace between the plot title (set with ``fig.suptitle()``) and
        the first row of subplots/facets. Passed to the ``top`` argument of
        ``fig.subplots_adjust()``.
    xlabel
        Label of the x-axis.
    style
        Passed to the ``style`` argument of ``sns.set_theme()``. Valid options are
        ``"darkgrid"``, ``"whitegrid"``, ``"dark"``, ``"white"``, and ``"ticks"``.
    color_palette
        Passed to the ``palette`` argument of ``sns.displot()``. Strings must be
        valid inputs of ``sns.color_palette()``, such as the name of a seaborn
        color palette or a matplotlib colormap. Custom colors can be given as a
        list of color strings or as a dictionary mapping chain indices to color
        strings; the number of colors must match the number of plotted chains. If
        ``None``, seaborn's default palette (set by ``sns.set_theme()``) is used.
    ncol
        Number of subplots/facets within each row of the grid.
    height
        Height in inches of each subplot/facet within the grid.
    aspect_ratio
        Ratio of width to height of each subplot/facet within the grid, i.e.
        ``width = aspect_ratio * height``.
    save_path
        File path where the plot is saved. If ``None``, the plot is shown instead.
    **kwargs
        Further keyword arguments passed to the seaborn ``displot()`` function.

    Returns
    -------
    A seaborn ``FacetGrid``.
    """
    # NOTE: Docstring duplications
    # The entries `results` to `max_chains` are shared with the `summary()`
    # and all user plotting functions.
    # The entries `title` to `save_path` are shared with `plot_trace()`, `plot_cor()`
    # and partially with `plot_param()`.

    sns.set_theme(style=style)

    plot_df = _setup_plot_df(results, params, param_indices, chain_indices, max_chains)
    facets_per_row = _set_plot_cols(plot_df, ncol)

    g = sns.displot(
        data=plot_df,
        x="value",
        y=None,
        kind="kde",
        hue="chain_index",
        palette=color_palette,
        col="param_label",
        col_wrap=facets_per_row,
        height=height,
        aspect=aspect_ratio,
        facet_kws={"sharex": False, "sharey": False},
        **kwargs,
    )

    g = _set_aesthetics(g, title, title_spacing, xlabel, ylabel="")
    save_figure(g, save_path)
    return g
def _compute_max_lags( plot_df: pd.DataFrame, max_lags: int | None, ) -> int: """ Determines number time lags that are shown on the x-axis of the autocorrelation plot. """ num_iterations = plot_df["iteration"].max() max_lags = np.min([num_iterations, 30]) if max_lags is None else max_lags return max_lags
def plot_cor(
    results: SamplingResults,
    params: str | list[str] | None = None,
    param_indices: int | Sequence[int] | None = None,
    chain_indices: int | Sequence[int] | None = None,
    max_chains: int | None = 5,
    max_lags: int | None = None,
    title: str | None = None,
    title_spacing: float = 0.85,
    xlabel: str = "Lag",
    style: str = "whitegrid",
    color_palette: str | list[str] | dict[int, str] | None = None,
    ncol: int = 3,
    height: int = 3,
    aspect_ratio: int = 1,
    save_path: str | None = None,
    **kwargs,
) -> sns.FacetGrid:
    """
    Visualizes autocorrelations of posterior samples.

    Parameters
    ----------
    results
        Result object of the sampling process. Must have a method
        ``get_posterior_samples()`` which extracts all samples from the posterior
        distribution.
    params
        Names of the model parameters that are contained in the plot. Must coincide
        with the dictionary keys of the ``Position`` with the posterior samples. If
        ``None``, all parameters are included.
    param_indices
        Indices of each model parameter that are contained in the plot. Selects e.g.
        ``beta[0]`` out of a ``beta`` parameter vector. A single index can be
        specified as an integer or a sequence containing one integer. If ``None``,
        all subparameters are included.
    chain_indices
        Indices of chains for each model subparameter that are contained in the
        plot. Selects e.g. chain 0 and chain 2 out of multiple chains. A single
        index can be specified as an integer or a sequence containing one integer.
        If ``None``, all chains are included.
    max_chains
        Upper bound how many chains are included within each subplot/facet. Avoids
        overplotting. If ``None``, all chains contained in the ``results`` input are
        plotted. Always starts chain selection from the lowest chain index upwards.
        For selecting specific chains use the argument ``chain_indices``.
    max_lags
        Maximum number of time lags shown on the x-axis of the autocorrelation plot.
        If ``None``, the minimum of 30 and the number of iterations minus one is
        chosen. Values larger than a chain's length are allowed; such chains are
        drawn up to their last available lag.
    title
        Plot title.
    title_spacing
        Determines the margin/whitespace between the plot title (set with
        ``fig.suptitle()``) and the first row of subplots/facets. Passed to the
        ``top`` argument of ``fig.subplots_adjust()``.
    xlabel
        Label of the x-axis.
    style
        Passed to the ``style`` argument of ``sns.set_theme()``. Valid options are
        ``darkgrid``, ``whitegrid``, ``dark``, ``white``, and ``ticks``.
    color_palette
        Passed to the palette argument of ``sns.FacetGrid()``. String values must be
        valid inputs of ``sns.color_palette()`` such as a seaborn color palette or a
        matplotlib colormap. Custom colors can be set with a list of color strings
        or a dictionary with the chain indices as keys and color strings as values.
        The number of color strings must coincide with the number of plotted
        chains. If ``None``, seaborn's default palette (set by ``sns.set_theme()``)
        is used.
    ncol
        Number of subplots/facets within each row of the grid.
    height
        Height in inches of each subplot/facet within the grid.
    aspect_ratio
        Ratio of width / height of each subplot/facet within the grid, i.e.
        ``width = aspect_ratio * height``.
    save_path
        File path where the plot is saved. If ``None``, the plot is shown instead.
    **kwargs
        Further keyword arguments passed to the seaborn ``FacetGrid()`` function.

    Returns
    -------
    A seaborn ``FacetGrid``.
    """
    # NOTE: Docstring duplications
    # The entries `results` to `max_chains` are shared with the `summary()`
    # and all user plotting functions.
    # The entries `title` to `save_path` are shared with `plot_trace()`,
    # `plot_density()` and partially with `plot_param()`.
    # The entry `max_lags` is shared with `plot_param()`.

    sns.set_theme(style=style)
    plot_df = _setup_plot_df(results, params, param_indices, chain_indices, max_chains)
    max_lags = _compute_max_lags(plot_df, max_lags)

    def do_acor_plot(x, maxlags, **plot_kws):
        """Draw the autocorrelation of one chain of one subparameter."""
        x = np.asarray(x)
        acor = arviz.autocorr(x)[..., 0:maxlags]
        # ``acor`` is shorter than ``maxlags`` if ``maxlags`` exceeds the number of
        # samples; use its actual length so that x and y always match.
        return sns.lineplot(x=range(acor.shape[-1]), y=acor, **plot_kws)

    g = (
        sns.FacetGrid(
            data=plot_df,
            hue="chain_index",
            col="param_label",
            col_wrap=_set_plot_cols(plot_df, ncol),
            palette=color_palette,
            height=height,
            aspect=aspect_ratio,
            **kwargs,
        )
        .map(
            do_acor_plot,
            "value",
            maxlags=max_lags,
        )
        .set(
            xlim=(0, max_lags),
            ylim=(-0.2, 1.1),
        )
        .add_legend()
    )

    g = _set_aesthetics(g, title, title_spacing, xlabel, ylabel="Autocorrelation")
    save_figure(g, save_path)
    return g
def _raise_multi_param_error(plot_df: pd.DataFrame, param: str) -> None:
    """
    :func:`.plot_param` function can only display all three diagnostic plots for a
    single subparameter. Throws an informative error otherwise.

    Raises
    ------
    ValueError
        If ``plot_df`` contains more than one distinct ``param_label``.
    """
    if plot_df["param_label"].nunique() > 1:
        raise ValueError(
            f"{param} has more than one index. "
            "Please specify a single `param_index` for plotting."
        )


def _set_colors(plot_df: pd.DataFrame, color_list: list[str] | None) -> list[str]:
    """
    Determines colors of different chains in each plot.

    Parameters
    ----------
    plot_df
        Long-format samples with a ``chain_index`` column.
    color_list
        Custom colors, one per chain. If ``None``, the currently active seaborn
        color palette is used and repeated as often as needed so that every chain
        receives a color.

    Returns
    -------
    The colors, truncated to the number of distinct chains in ``plot_df``.
    """
    num_chains = plot_df["chain_index"].nunique()

    if color_list is None:
        # currently active seaborn palette (entries are RGB tuples, not strings)
        palette = list(sns.color_palette())
        # repeat the palette based on its actual length instead of assuming
        # exactly 10 colors; ceiling division, at least one copy
        num_repeats = max(1, -(-num_chains // len(palette)))
        color_list = palette * num_repeats  # type: ignore

    return color_list[:num_chains]


def _setup_grid(
    figure_size: tuple[int | float, int | float]
) -> tuple[plt.Figure, Any, Any, Any]:
    """
    Initializes plotting grid with one large subplot for the trace plot and two
    smaller subplots for the density and autocorrelation plot.

    Returns
    -------
    The figure, the top axis spanning both columns (trace plot), and the bottom
    left and bottom right axes (density and autocorrelation plot).
    """
    fig = plt.figure(figsize=figure_size)
    ax1 = plt.subplot2grid(shape=(2, 2), loc=(0, 0), colspan=2)
    ax2 = plt.subplot2grid(shape=(2, 2), loc=(1, 0))
    ax3 = plt.subplot2grid(shape=(2, 2), loc=(1, 1))
    return fig, ax1, ax2, ax3


def _add_lineplot(plot_df: pd.DataFrame, ax: Any, color_list: list[str]) -> None:
    """Adds trace plot to plotting grid (one line per chain, with legend)."""
    sns.lineplot(
        data=plot_df,
        x="iteration",
        y="value",
        hue="chain_index",
        palette=color_list,
        ax=ax,
        legend="full",
    ).set(xlabel="Iteration", ylabel="")


def _add_kdeplot(plot_df: pd.DataFrame, ax: Any, color_list: list[str]) -> None:
    """Adds density plot to plotting grid (one KDE per chain, no legend)."""
    sns.kdeplot(
        data=plot_df,
        x="value",
        hue="chain_index",
        palette=color_list,
        ax=ax,
        legend=False,
    ).set(xlabel="Value", ylabel="")


def _add_corplot(
    plot_df: pd.DataFrame, ax: Any, max_lags: int | None, color_list: list[str]
) -> None:
    """
    Adds correlation plot to plotting grid.

    One autocorrelation line is drawn per chain. Chains are visited in ascending
    order of ``chain_index`` so that the colors agree with the hue mapping seaborn
    uses in the trace and density plots (seaborn sorts numeric hue levels;
    assumes ``chain_index`` is numeric).
    """
    max_lags = _compute_max_lags(plot_df, max_lags)

    chain_order = sorted(plot_df["chain_index"].unique())
    for chain_index, col in zip(chain_order, color_list):
        x = np.asarray(plot_df.loc[plot_df["chain_index"] == chain_index, "value"])
        acor = arviz.autocorr(x)[0:max_lags]
        # a chain shorter than ``max_lags`` yields fewer lags; size the x-axis
        # values by the actual number of autocorrelations to avoid a mismatch
        sns.lineplot(
            x=range(len(acor)),
            y=acor,
            marker="",
            linestyle="-",
            color=col,
            ax=ax,
        )

    ax.set(xlim=(0, max_lags), ylim=(-0.2, 1.1), xlabel="Lag", ylabel="")


def _get_title(plot_df: pd.DataFrame, title: str | None) -> str:
    """Sets either a custom or the default plot title."""
    if title is not None:
        return title
    # positional access: the frame's index labels need not start at 0
    return f"Diagnostic plots for '{plot_df['param'].iloc[0]}'"
def plot_param(
    results: SamplingResults,
    param: str,
    param_index: int | None = None,
    chain_indices: int | Sequence[int] | None = None,
    max_chains: int | None = 5,
    max_lags: int | None = None,
    title: str | None = None,
    title_spacing: float = 0.9,
    style: str = "whitegrid",
    color_list: list[str] | None = None,
    figure_size: tuple[int | float, int | float] = (9, 6),
    # default values chosen for default figure size of (9, 6)
    legend_position: tuple[float, float] = (1.2, 0.4),
    save_path: str | None = None,
) -> None:
    """
    Visualizes trace plot, density plot and autocorrelation plot of a single
    subparameter.

    The trace plot occupies the upper half of the figure; the density plot and
    the autocorrelation plot share the lower half.

    Parameters
    ----------
    results
        Result object of the sampling process. Must have a method
        ``get_posterior_samples()`` which extracts all samples from the posterior
        distribution.
    param
        Name of a single model parameter that is contained in the plot. Must
        coincide with one dictionary key of the ``Position`` with the posterior
        samples.
    param_index
        A single index of the selected model parameter that is contained in the
        plot. Selects e.g. ``beta[0]`` out of a ``beta`` parameter vector. Can be
        specified as an integer or as a sequence containing one integer. If
        ``None``, the parameter is assumed to have only a single index.
    chain_indices
        Indices of chains for each model subparameter that are contained in the
        plot. Selects e.g. chain 0 and chain 2 out of multiple chains. A single
        index can be specified as an integer or a sequence containing one integer.
        If ``None``, all chains are included.
    max_chains
        Upper bound how many chains are included within each subplot. Avoids
        overplotting. If ``None``, all chains contained in the ``results`` input
        are plotted. Always starts chain selection from the lowest chain index
        upwards. For selecting specific chains use the argument ``chain_indices``.
    max_lags
        Maximum number of time lags shown on the x-axis of the autocorrelation
        plot. If ``None``, the minimum of the chain lengths and 30 is chosen.
    title
        Plot title. If ``None``, a default title naming the parameter is used.
    title_spacing
        Determines the margin/whitespace between the plot title (set with
        ``fig.suptitle()``) and the first row of subplots. Passed to the ``top``
        argument of ``fig.subplots_adjust()``.
    style
        Passed to the ``style`` argument of ``sns.set_theme()``. Valid options are
        ``darkgrid``, ``whitegrid``, ``dark``, ``white``, and ``ticks``.
    color_list
        Determines the chain colors for all three subplots. Custom colors can be
        passed with a list of color strings. The length of the list must match
        the number of chains. If ``None``, the seaborn color palette that is
        active after ``sns.set_theme()`` is used.
    figure_size
        Size of the entire plot grid. Passed to the ``figsize`` argument of
        ``plt.figure()``. When changing the figure size consider changing the
        ``legend_position`` as well. The default ``(9, 6)`` corresponds to a
        width-to-height ratio of 3:2.
    legend_position
        Determines the color legend position. Passed to the ``bbox_to_anchor``
        argument of the legend of the upper panel (trace plot), so coordinates are
        relative to that panel. The first coordinate specifies the horizontal,
        the second coordinate the vertical position. Might require an adjustment
        when changing the ``figure_size`` values or the number of chains.
    save_path
        File path where the plot is saved.
    """
    # NOTE: Docstring duplications
    # The entries `results`, `chain_indices` and `max_chains` are shared with the
    # `summary()` and all user plotting functions.
    # The entries `title`, `title_spacing`, `style` and `save_path` are shared with
    # `plot_trace()`, `plot_density()` and `plot_cor()`.
    # The entry `max_lags` is shared with `plot_cor()`.

    sns.set_theme(style=style)

    samples_df = _setup_plot_df(
        results, param, param_index, chain_indices, max_chains
    )
    _raise_multi_param_error(samples_df, param)
    chain_colors = _set_colors(samples_df, color_list)

    # top: trace plot; bottom left: density; bottom right: autocorrelation
    fig, trace_ax, density_ax, acor_ax = _setup_grid(figure_size)
    _add_lineplot(samples_df, trace_ax, chain_colors)
    _add_kdeplot(samples_df, density_ax, chain_colors)
    _add_corplot(samples_df, acor_ax, max_lags, chain_colors)

    # only the trace plot carries a legend; it is shared by all three panels
    trace_ax.legend(title="Chain", bbox_to_anchor=legend_position, frameon=False)

    fig.tight_layout()
    fig.suptitle(_get_title(samples_df, title))
    fig.subplots_adjust(top=title_spacing)

    save_figure(save_path=save_path)
def plot_scatter(
    results: SamplingResults,
    params: list[str],
    param_indices: tuple[int, int],
    chain_indices: int | Sequence[int] | None = None,
    max_chains: int | None = 5,
    alpha: float = 0.2,
    title: str | None = None,
    title_spacing: float = 0.9,
    style: str = "whitegrid",
    color_list: list[str] | None = None,
    figure_size: tuple[int | float, int | float] = (9, 6),
    legend_position: tuple[float, float] | str = "best",
    save_path: str | None = None,
    include_warmup: bool = False,
) -> None:
    """
    Produces a scatterplot of two parameters.

    Parameters
    ----------
    results
        Result object of the sampling process. Must have a method
        ``get_posterior_samples()`` which extracts all samples from the posterior
        distribution.
    params
        Names of the model parameters that are contained in the plot. Must
        coincide with the dictionary keys of the ``Position`` with the posterior
        samples.
    param_indices
        Indices of each model parameter that are contained in the plot. Selects
        e.g. ``beta[0]`` out of a ``beta`` parameter vector. If only one string is
        supplied as the value of ``params``, ``param_indices`` must contain two
        indices. If a sequence of two strings is supplied to ``params``, you can
        supply either a single integer or a tuple of two integers. A single
        integer will be used as the index for *both* parameters. If you use a
        tuple of two integers, the first element will be used as the index for
        the first parameter, and the second element will be used as the index for
        the second parameter.
    chain_indices
        Indices of chains for each model subparameter that are contained in the
        plot. Selects e.g. chain 0 and chain 2 out of multiple chains. A single
        index can be specified as an integer or a sequence containing one integer.
        If ``None``, all chains are included.
    max_chains
        Upper bound how many chains are included in the plot. Avoids
        overplotting. If ``None``, all chains contained in the ``results`` input
        are plotted. Always starts chain selection from the lowest chain index
        upwards. For selecting specific chains use the argument ``chain_indices``.
    alpha
        Amount of transparency; a float between 0 and 1.
    title
        Plot title. If ``None``, no title is set.
    title_spacing
        Determines the margin/whitespace between the plot title (set with
        ``fig.suptitle()``) and the plot. Passed to the ``top`` argument of
        ``fig.subplots_adjust()``. Only applied when a ``title`` is given.
    style
        Passed to the ``style`` argument of ``sns.set_theme()``. Valid options are
        ``darkgrid``, ``whitegrid``, ``dark``, ``white``, and ``ticks``.
    color_list
        Determines the chain colors in the scatterplot. Custom colors can be
        passed with a list of color strings, one per chain. If ``None``, the
        seaborn color palette that is active after ``sns.set_theme()`` is used.
    figure_size
        Size of the figure. Passed to the ``figsize`` argument of
        ``plt.subplots()``.
    legend_position
        Determines the color legend position. Passed to the ``loc`` argument of
        the axis' ``legend()`` method: either a location string such as
        ``"best"`` or ``"upper right"``, or a tuple ``(x, y)`` in axes
        coordinates.
    save_path
        File path where the plot is saved.
    include_warmup
        Include the warmup samples in the plot.

    Raises
    ------
    ValueError
        If the selected parameters and indices do not result in exactly two
        subparameters.
    """
    # NOTE: Docstring duplications
    # Multiple arguments in this docstring are shared with other plotting functions.
    sns.set_theme(style=style)
    plot_df = _setup_scatterplot_df(
        results, params, param_indices, chain_indices, max_chains, include_warmup
    )

    labels = plot_df.param_label.unique()
    if len(labels) != 2:
        raise ValueError(
            "'plot_scatter' can only plot exactly two parameters. Use 'plot_pairs'"
            " instead to plot more."
        )

    # wide format: one column per subparameter label, one row per
    # (chain, iteration) pair, so the two labels can serve as x and y
    plot_df = (
        plot_df.drop(["param_index", "param"], axis=1)
        .pivot(
            index=["chain_index", "iteration"], columns="param_label", values="value"
        )
        .reset_index()
        .drop(["iteration"], axis=1)
    )

    color_list = _set_colors(plot_df, color_list)

    fig, axis = plt.subplots(1, 1, figsize=figure_size)
    sns.scatterplot(
        data=plot_df,
        x=labels[0],
        y=labels[1],
        alpha=alpha,
        hue="chain_index",
        palette=color_list,
        ax=axis,
    )

    if title is not None:
        fig.suptitle(title)
        fig.subplots_adjust(top=title_spacing)

    # replaces the default hue legend created by seaborn
    axis.legend(title="Chain", loc=legend_position, frameon=False)
    save_figure(save_path=save_path)
def plot_pairs(
    results: SamplingResults,
    params: str | list[str] | None = None,
    param_indices: int | Sequence[int] | None = None,
    chain_indices: int | Sequence[int] | None = None,
    max_chains: int | None = 5,
    alpha: float = 0.2,
    title: str | None = None,
    title_spacing: float = 0.9,
    style: str = "whitegrid",
    diag_kind: str = "kde",
    color_palette: str | list[str] | dict[int, str] | None = None,
    height: float = 3,
    aspect_ratio: float = 1,
    save_path: str | None = None,
    include_warmup: bool = False,
) -> None:
    """
    Produces a pairplot panel.

    Every selected subparameter is plotted against every other one (scatterplots
    off the diagonal, marginal distributions on the diagonal), colored by chain.

    Parameters
    ----------
    results
        Result object of the sampling process. Must have a method
        ``get_posterior_samples()`` which extracts all samples from the posterior
        distribution.
    params
        Names of the model parameters that are contained in the plot. Must
        coincide with the dictionary keys of the ``Position`` with the posterior
        samples. If ``None``, all parameters are included.
    param_indices
        Indices of each model parameter that are contained in the plot. Selects
        e.g. ``beta[0]`` out of a ``beta`` parameter vector. A single index can be
        specified as an integer or a sequence containing one integer. If ``None``,
        all subparameters are included.
    chain_indices
        Indices of chains for each model subparameter that are contained in the
        plot. Selects e.g. chain 0 and chain 2 out of multiple chains. A single
        index can be specified as an integer or a sequence containing one integer.
        If ``None``, all chains are included.
    max_chains
        Upper bound how many chains are included within each subplot/facet.
        Avoids overplotting. If ``None``, all chains contained in the ``results``
        input are plotted. Always starts chain selection from the lowest chain
        index upwards. For selecting specific chains use the argument
        ``chain_indices``.
    alpha
        Amount of transparency of the off-diagonal scatterplots; a float between
        0 and 1.
    title
        Plot title. If ``None``, no title is set.
    title_spacing
        Determines the margin/whitespace between the plot title (set with
        ``fig.suptitle()``) and the first row of subplots/facets. Passed to the
        ``top`` argument of ``fig.subplots_adjust()``. Only applied when a
        ``title`` is given.
    style
        Passed to the ``style`` argument of ``sns.set_theme()``. Valid options are
        ``darkgrid``, ``whitegrid``, ``dark``, ``white``, and ``ticks``.
    diag_kind
        Kind of plot for the diagonal subplots. Can be 'kde' (default) for kernel
        density estimates or 'hist' for histograms.
    color_palette
        Passed to the palette argument of ``sns.pairplot()``. String values must
        be valid inputs of ``sns.color_palette()`` such as a seaborn color palette
        or a matplotlib colormap. Custom colors can be set with a list of color
        strings or a dictionary with the chain indices as keys and color strings
        as values. The number of color strings must coincide with the number of
        plotted chains. If ``None``, seaborn chooses its default palette.
    height
        Height in inches of each subplot/facet within the grid.
    aspect_ratio
        Ratio of width / height of each subplot/facet within the grid, i.e.
        ``width = aspect_ratio * height``.
    save_path
        File path where the plot is saved.
    include_warmup
        Include the warmup samples in the plot.
    """
    # NOTE: Docstring duplications
    # Multiple arguments in this docstring are shared with other plotting functions.

    sns.set_theme(style=style)
    plot_df = _setup_plot_df(
        results, params, param_indices, chain_indices, max_chains, include_warmup
    )

    # wide format: one column per subparameter label plus the chain index, which
    # is the layout ``sns.pairplot()`` expects
    plot_df = (
        plot_df.drop(["param_index", "param"], axis=1)
        .pivot(
            index=["chain_index", "iteration"], columns="param_label", values="value"
        )
        .reset_index()
        .drop(["iteration"], axis=1)
    )

    g = sns.pairplot(
        data=plot_df,
        hue="chain_index",
        plot_kws={"alpha": alpha},
        diag_kind=diag_kind,
        height=height,
        aspect=aspect_ratio,
        palette=color_palette,
    )

    if title is not None:
        g.fig.suptitle(title)
        g.fig.subplots_adjust(top=title_spacing)

    save_figure(save_path=save_path)