from collections.abc import Sequence
from dataclasses import dataclass, replace
from typing import Any

import jax
import jax.numpy as jnp
import numpy as np
import optax
import pandas as pd

from ..model import Model, Node, Var
from .interface import LieselInterface
from .types import Array, KeyArray, ModelState, Position

def array_to_dict(
    x: Array, names_prefix: str = "x", prefix_1d: bool = False
) -> dict[str, Array]:
    """
    Turns an array with at most two dimensions into a dict.

    Parameters
    ----------
    x
        A scalar, a 1d array, or a 2d array. The columns of a 2d array are stored
        as separate entries.
    names_prefix
        Prefix used for the dictionary keys.
    prefix_1d
        If ``True``, a scalar or 1d input is stored under ``f"{names_prefix}0"``
        instead of plain ``names_prefix``.

    Returns
    -------
    A dictionary mapping names to arrays. For a 2d input, column ``i`` is stored
    under ``f"{names_prefix}{i}"``.

    Raises
    ------
    ValueError
        If ``x`` has more than two dimensions.
    """
    # np.ndim handles Python scalars, NumPy arrays and JAX arrays alike.
    ndim = np.ndim(x)

    if ndim <= 1:
        if prefix_1d:
            return {f"{names_prefix}0": x}
        return {names_prefix: x}

    if ndim == 2:
        return {f"{names_prefix}{i}": x[:, i] for i in range(x.shape[-1])}

    raise ValueError(f"x should have ndim <= 2, but it has x.ndim={ndim}")

[docs] @dataclass class OptimResult: """Holds the results of model optimization with :func:`.optim_flat`.""" model_state: ModelState """Final model state after optimization.""" position: Position """Position dictionary of optimized parameters with their final values.""" iteration: int """Iteration counter of the last iteration.""" iteration_best: int """Iteration counter of the iteration with lowest loss.""" history: dict[str, dict[str, Array] | Array] """History of loss evaluations and, if applicable, intermediate position values.""" max_iter: int """Maximum number of iterations.""" n_train: int """Number of training observations.""" n_validation: int """Number of test observations."""
def _find_observed(model: Model) -> dict[str, Var | Node]: obs = { jnp.array(var_.value) for var_ in model.vars.values() if var_.observed } return obs def batched_nodes(nodes: dict[str, Array], batch_indices: Array) -> dict[str, Array]: """Returns a subset of the model state using the given batch indices.""" return jax.tree_util.tree_map(lambda x: x[batch_indices, ...], nodes) def _generate_batch_indices(key: KeyArray, n: int, batch_size: int) -> Array: n_full_batches = n // batch_size shuffled_indices = jax.random.permutation(key, n) shuffled_indices_subset = shuffled_indices[0 : n_full_batches * batch_size] list_of_batch_indices = jnp.array_split(shuffled_indices_subset, n_full_batches) return jnp.asarray(list_of_batch_indices) def _find_sample_size(model: Model) -> int: obs = { var_ for var_ in model.vars.values() if var_.observed} n_set = {int(np.array(var_.value.shape)[0, ...]) for var_ in obs.values()} if len(n_set) > 1: raise ValueError( "The observed variables must have the same number of observations." ) return n_set.pop()
[docs] @dataclass class Stopper: """ Handles (early) stopping for :func:`.optim_flat`. Parameters ---------- max_iter The maximum number of optimization steps. patience Early stopping happens only, if there was no improvement for the number of\ patience iterations. atol The absolute tolerance for early stopping. If the change in the negative log\ probability (compared to the best value observed within the patience period)\ is smaller than this value, optimization stops early. rtol The relative tolerance for early stopping. If the relative absolute change in \ the negative log probability is smaller than this value, the optimization stops. """ max_iter: int patience: int atol: float = 1e-3 rtol: float = 1e-12
[docs] def stop_early(self, i: int | Array, loss_history: Array): """ Includes loss at iterations *before* i, but excluding i itself. """ p = self.patience lower = jnp.max(jnp.array([(i - 1) - p, 0])) recent_history = jax.lax.dynamic_slice( loss_history, start_indices=(lower,), slice_sizes=(p,) ) best_loss_in_recent = jnp.min(recent_history) current_loss = loss_history[i] change = current_loss - best_loss_in_recent """ If current_loss is better than best_loss_in_recent, this is negative. If current_loss is worse, this is positive. """ rel_change = jnp.abs(jnp.abs(change) / best_loss_in_recent) no_improvement = change > self.atol """ If the current loss has not improved upon the best loss in the patience period, we always want to stop. However, we actually allow for slightly worse losses, defined by the absolute tolerance here. """ no_rel_change = ~no_improvement & (rel_change < self.rtol) """ Let's say the current value *does* improve upon the best value within patience, such that no_improvement=False. In this case, if the improvement is very small compared to the best observed loss in the patience period, we may still want to stop. """ return (no_improvement | no_rel_change) & (i > p)
[docs] def stop_now(self, i: int | Array, loss_history: Array): """Whether optimization should stop now.""" stop_early = self.stop_early(i=i, loss_history=loss_history) stop_max_iter = i >= self.max_iter return stop_early | stop_max_iter
[docs] def continue_(self, i: int | Array, loss_history: Array): """Whether optimization should continue (inverse of :meth:`.stop_now`).""" return ~self.stop_now(i=i, loss_history=loss_history)
[docs] def which_best_in_recent_history(self, i: int, loss_history: Array): """ Identifies the index of the best observation in recent history. Recent history includes the last ``p`` iterations looking backwards from the current iteration `ì``., where ``p`` is the patience. """ p = self.patience recent_history = jax.lax.dynamic_slice( loss_history, start_indices=(i - p,), slice_sizes=(p,) ) imin = jnp.argmin(recent_history) return i - self.patience + imin
def _validate_log_prob_decomposition( interface: LieselInterface, position: Position, state: ModelState ) -> bool: updated_state = interface.update_state(position, state) log_prob = updated_state["_model_log_prob"].value log_lik = updated_state["_model_log_lik"].value log_prior = updated_state["_model_log_prior"].value if not jnp.allclose(log_prob, log_lik + log_prior): raise ValueError( f"You model's {log_prob=} cannot correctly be decomposed into the" f" {log_prior=} and the {log_lik=}. Check whether your observed variables" " have the attribute observed=True and whether your parameter variables" " have the attribute parameter=True." ) return True
def optim_flat(
    model_train: Model,
    params: Sequence[str],
    optimizer: optax.GradientTransformation | None = None,
    stopper: Stopper = Stopper(max_iter=10_000, patience=10),
    batch_size: int | None = None,
    batch_seed: int | None = None,
    save_position_history: bool = True,
    model_validation: Model | None = None,
    restore_best_position: bool = True,
    prune_history: bool = True,
) -> OptimResult:
    """
    Optimize the parameters of a Liesel :class:`.Model`.

    Approximates maximum a posteriori (MAP) parameter estimates by minimizing the
    negative log posterior probability of the model. If you use batching, be aware
    that the batching functionality implemented here assumes a "flat" model
    structure. See below for details.

    Parameters
    ----------
    model_train
        The Liesel model to optimize.
    params
        List of parameter names to optimize. All other parameters of the model are
        held fixed.
    optimizer
        An optimizer from the ``optax`` library. If ``None``,
        ``optax.adam(learning_rate=1e-2)`` is used.
    stopper
        A :class:`.Stopper` that carries information about the maximum number of
        iterations and early stopping. The supplied object is not modified.
    batch_size
        The batch size. If ``None``, the whole dataset is used for each optimization
        step. Must be between 1 and the number of training observations.
    batch_seed
        Batches are assembled randomly in each iteration. This is the seed used for
        shuffling. If ``None``, a random seed is drawn.
    save_position_history
        If ``True``, the position history is saved to the results object.
    model_validation
        If supplied, this model serves as a validation model, which means that early
        stopping is based on the negative log probability evaluated using the
        observed data in this model (with its log likelihood rescaled to the size of
        the training data). If ``None``, no early stopping is conducted.
    restore_best_position
        If ``True``, the position with the lowest loss within the patience defined
        by the supplied :class:`.Stopper` is restored as the final position. If
        ``False``, the last iteration's position is used.
    prune_history
        If ``True``, the history is pruned to the length of the final iteration.
        This means the history can be shorter than the maximum number of iterations
        defined by the supplied :class:`.Stopper`. If ``False``, unused history
        entries are set to ``jax.numpy.nan`` if optimization stops early.

    Returns
    -------
    A dataclass of type :class:`.OptimResult`, giving access to the results.

    Raises
    ------
    ValueError
        If ``restore_best_position=True`` but ``save_position_history=False``, if
        ``stopper.max_iter`` is smaller than 1, if ``batch_size`` is not between 1
        and the number of training observations, or if the model log probability
        cannot be decomposed into log likelihood and log prior.

    See Also
    --------
    .history_to_df : A helper function to turn the :attr:`.OptimResult.history` into
        a ``pandas.DataFrame`` - nice for quickly plotting results.

    Notes
    -----
    If you use batching, be aware that the batching functionality implemented here
    assumes a "flat" model structure. This means that this function assumes that, for
    all :class:`.Var` objects in your model, it is valid to index their values like
    this::

        var_object.value[batch_indices, ...]

    The batching functionality also assumes that all objects that should be batched
    are included as :class:`.Var` objects with ``Var.observed`` set to ``True``.

    Examples
    --------
    We show a minimal example. First, import ``tfd``.

    >>> import tensorflow_probability.substrates.jax.distributions as tfd

    Next, generate some data.

    >>> key = jax.random.PRNGKey(42)
    >>> key, subkey = jax.random.split(key)
    >>> x = jax.random.normal(key, (100,))
    >>> y = 0.5 + 1.2 * x + jax.random.normal(subkey, (100,))

    Next, set up a linear model. For simplicity, we assume the scale to be fixed to
    the true value of 1.

    >>> coef = lsl.param(jnp.zeros(2), name="coef")
    >>> xvar = lsl.obs(jnp.c_[jnp.ones_like(x), x], name="x")
    >>> mu = lsl.Var(lsl.Calc(jnp.dot, xvar, coef), name="mu")
    >>> ydist = lsl.Dist(tfd.Normal, loc=mu, scale=1.0)
    >>> yvar = lsl.obs(y, ydist, name="y")
    >>> model = lsl.GraphBuilder().add(yvar).build_model()

    Now, we are ready to run the optimization.

    >>> result = gs.optim_flat(model, params=["coef"])
    >>> {name: jnp.round(value, 2) for name, value in result.position.items()}
    {'coef': Array([0.52, 1.29], dtype=float32)}

    We can now, for example, use ``result.model_state`` in
    :meth:`.EngineBuilder.set_initial_values` to implement a "warm start" of MCMC
    sampling.
    """
    # ---------------------------------------------------------------------------------
    # Validate input

    if restore_best_position and not save_position_history:
        raise ValueError("Cannot restore best position if history is not saved.")

    if stopper.max_iter < 1:
        raise ValueError(
            f"stopper.max_iter must be at least 1, but got {stopper.max_iter}."
        )

    # ---------------------------------------------------------------------------------
    # Pre-process inputs

    batch_seed = (
        batch_seed if batch_seed is not None else np.random.randint(low=1, high=1000)
    )

    if optimizer is None:
        optimizer = optax.adam(learning_rate=1e-2)

    # The stopper that drives the while loop. It is a copy whenever it has to differ
    # from the user's stopper, so neither the caller's object nor the shared default
    # instance is ever mutated (not even if an exception occurs).
    loop_stopper = stopper
    if model_validation is None:
        model_validation = model_train
        # A patience of max_iter disables early stopping: stop_early requires
        # i > patience, and the loop index never exceeds max_iter.
        loop_stopper = replace(stopper, patience=stopper.max_iter)

    n_train = _find_sample_size(model_train)
    n_validation = _find_sample_size(model_validation)
    observed = _find_observed(model_train)

    batch_size = batch_size if batch_size is not None else n_train
    if not 1 <= batch_size <= n_train:
        raise ValueError(
            "batch_size must be between 1 and the number of training observations"
            f" ({n_train}), but got batch_size={batch_size}."
        )

    interface_train = LieselInterface(model_train)
    position = interface_train.extract_position(params, model_train.state)
    # NOTE(review): auto_update is not restored afterwards, so this setting persists
    # on the user's model objects.
    interface_train._model.auto_update = False

    interface_validation = LieselInterface(model_validation)
    interface_validation._model.auto_update = False

    # ---------------------------------------------------------------------------------
    # Validate model log prob decomposition

    _validate_log_prob_decomposition(
        interface_train, position=position, state=model_train.state
    )
    _validate_log_prob_decomposition(
        interface_validation, position=position, state=model_validation.state
    )

    # ---------------------------------------------------------------------------------
    # Define loss function(s)

    # Both scalars put the (batch / validation) log likelihood on the scale of the
    # full training data set.
    likelihood_scalar = n_train / batch_size
    likelihood_scalar_validation = n_train / n_validation

    def _batched_neg_log_prob(position: Position, batch_indices: Array | None = None):
        if batch_indices is not None:
            batched_observed = batched_nodes(observed, batch_indices)
            position = position | batched_observed  # type: ignore
        updated_state = interface_train.update_state(position, model_train.state)
        log_lik = likelihood_scalar * updated_state["_model_log_lik"].value
        log_prior = updated_state["_model_log_prior"].value
        log_prob = log_lik + log_prior
        return -log_prob

    def _neg_log_prob_train(position: Position):
        # The training state must be evaluated with the training interface.
        updated_state = interface_train.update_state(position, model_train.state)
        return -updated_state["_model_log_prob"].value

    def _neg_log_prob_validation(position: Position):
        updated_state = interface_validation.update_state(
            position, model_validation.state
        )
        log_lik = likelihood_scalar_validation * updated_state["_model_log_lik"].value
        log_prior = updated_state["_model_log_prior"].value
        log_prob = log_lik + log_prior
        return -log_prob

    neg_log_prob_grad = jax.grad(_batched_neg_log_prob, argnums=0)

    # ---------------------------------------------------------------------------------
    # Initialize history

    history: dict[str, Any] = dict()
    history["loss_train"] = jnp.zeros(shape=stopper.max_iter)
    history["loss_validation"] = jnp.zeros(shape=stopper.max_iter)

    if save_position_history:
        history["position"] = {
            name: jnp.zeros((stopper.max_iter,) + jnp.shape(value))
            for name, value in position.items()
        }
    else:
        history["position"] = None

    # NOTE(review): the first loop iteration also writes to index 0 (while_i starts
    # at 0), so these starting values are overwritten.
    loss_train_start = _neg_log_prob_train(position=position)
    loss_validation_start = _neg_log_prob_validation(position=position)

    history["loss_train"] = history["loss_train"].at[0].set(loss_train_start)
    history["loss_validation"] = (
        history["loss_validation"].at[0].set(loss_validation_start)
    )

    # ---------------------------------------------------------------------------------
    # Initialize while loop carry dictionary

    init_val: dict[str, Any] = dict()
    init_val["while_i"] = 0
    init_val["history"] = history
    init_val["position"] = position
    init_val["opt_state"] = optimizer.init(position)
    init_val["key"] = jax.random.PRNGKey(batch_seed)

    # ---------------------------------------------------------------------------------
    # Define while loop body

    def body_fun(val: dict):
        # Advance the key in the carry so that every epoch gets a fresh shuffle.
        key, subkey = jax.random.split(val["key"])
        val["key"] = key
        batches = _generate_batch_indices(key=subkey, n=n_train, batch_size=batch_size)

        # -----------------------------------------------------------------------------
        # Loop over batches

        def _fori_body(i, val):
            batch = batches[i]
            pos = val["position"]
            grad = neg_log_prob_grad(pos, batch_indices=batch)
            # Passing the params supports optax transformations that need them.
            updates, opt_state = optimizer.update(grad, val["opt_state"], pos)
            val["position"] = optax.apply_updates(pos, updates)
            val["opt_state"] = opt_state
            return val

        val = jax.lax.fori_loop(
            body_fun=_fori_body, init_val=val, lower=0, upper=len(batches)
        )

        # -----------------------------------------------------------------------------
        # Save values and increase counter

        loss_train = _neg_log_prob_train(val["position"])
        val["history"]["loss_train"] = (
            val["history"]["loss_train"].at[val["while_i"]].set(loss_train)
        )

        loss_validation = _neg_log_prob_validation(val["position"])
        val["history"]["loss_validation"] = (
            val["history"]["loss_validation"].at[val["while_i"]].set(loss_validation)
        )

        if save_position_history:
            pos_hist = val["history"]["position"]
            val["history"]["position"] = jax.tree_util.tree_map(
                lambda d, pos: d.at[val["while_i"]].set(pos), pos_hist, val["position"]
            )

        val["while_i"] += 1
        return val

    # ---------------------------------------------------------------------------------
    # Run while loop

    # The stopping rule is evaluated at the most recently written history index.
    val = jax.lax.while_loop(
        cond_fun=lambda val: loop_stopper.continue_(
            jnp.maximum(val["while_i"] - 1, 0), val["history"]["loss_validation"]
        ),
        body_fun=body_fun,
        init_val=init_val,
    )

    max_iter = val["while_i"] - 1

    # ---------------------------------------------------------------------------------
    # Set final position and model state

    # The best position is searched within the user's patience, not loop_stopper's.
    ibest = stopper.which_best_in_recent_history(
        i=max_iter, loss_history=val["history"]["loss_validation"]
    )

    if restore_best_position:
        final_position: Position = {
            name: pos[ibest] for name, pos in val["history"]["position"].items()
        }  # type: ignore
    else:
        final_position = val["position"]

    final_state = interface_train.update_state(final_position, model_train.state)

    # ---------------------------------------------------------------------------------
    # Set unused values in history to nan

    val["history"]["loss_train"] = (
        val["history"]["loss_train"].at[max_iter:].set(jnp.nan)
    )
    val["history"]["loss_validation"] = (
        val["history"]["loss_validation"].at[max_iter:].set(jnp.nan)
    )

    if save_position_history:
        for name, value in val["history"]["position"].items():
            val["history"]["position"][name] = value.at[max_iter:, ...].set(jnp.nan)

    # ---------------------------------------------------------------------------------
    # Remove unused values in history, if applicable

    if prune_history:
        val["history"]["loss_train"] = val["history"]["loss_train"][:max_iter]
        val["history"]["loss_validation"] = val["history"]["loss_validation"][:max_iter]

        if save_position_history:
            for name, value in val["history"]["position"].items():
                val["history"]["position"][name] = value[:max_iter, ...]

    # ---------------------------------------------------------------------------------
    # Initialize results object and return

    return OptimResult(
        model_state=final_state,
        position=final_position,
        iteration=max_iter,
        iteration_best=ibest,
        history=val["history"],
        max_iter=stopper.max_iter,
        n_train=n_train,
        n_validation=n_validation,
    )
def history_to_df(history: dict[str, Array]) -> pd.DataFrame:
    """
    Turns a :attr:`.OptimResult.history` dictionary into a ``pandas.DataFrame``.

    Every entry except ``"position"`` is converted with :func:`array_to_dict`. If a
    position history is present (not ``None``), each of its parameters is converted
    the same way. Finally, an ``"iteration"`` column numbering the rows from 0 is
    added.
    """
    data: dict[str, Array] = dict()
    position_history = history.get("position", None)

    for name, value in history.items():
        if name == "position":
            continue
        data |= array_to_dict(value, names_prefix=name)

    if position_history is not None:
        for name, value in position_history.items():
            data |= array_to_dict(value, names_prefix=name)

    df = pd.DataFrame(data)
    # Use the frame's own length rather than the last loop variable, which is
    # undefined (or None) when the history holds no arrays.
    df["iteration"] = np.arange(len(df))
    return df