from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
import jax
import jax.experimental
import jax.numpy as jnp
import numpy as np
import optax
import pandas as pd
from tqdm import tqdm
from ..model import Model, Node, Var
from .interface import LieselInterface
from .types import Array, KeyArray, ModelState, Position
def array_to_dict(
    x: Array, names_prefix: str = "x", prefix_1d: bool = False
) -> dict[str, Array]:
    """
    Turns a scalar, 1d-array or 2d-array into a dict of columns.

    Parameters
    ----------
    x
        A scalar (Python number or 0d array), a 1d array, or a 2d array.
    names_prefix
        Prefix used for the keys of the returned dictionary.
    prefix_1d
        If ``True``, scalars and 1d arrays are stored under ``f"{names_prefix}0"``
        instead of plain ``names_prefix``.

    Returns
    -------
    For scalars and 1d arrays, a dictionary with a single entry holding ``x``
    itself. For 2d arrays, one entry ``f"{names_prefix}{i}"`` per column
    ``x[:, i]``.

    Raises
    ------
    ValueError
        If ``x`` has more than two dimensions.
    """
    # np.ndim works for Python scalars, NumPy arrays and JAX arrays alike, so
    # 0d arrays and ints are handled like floats instead of failing.
    ndim = np.ndim(x)
    if ndim <= 1:
        key = f"{names_prefix}0" if prefix_1d else names_prefix
        return {key: x}
    if ndim == 2:
        return {f"{names_prefix}{i}": x[:, i] for i in range(x.shape[-1])}
    raise ValueError(f"x should have ndim <= 2, but it has x.ndim={ndim}")
[docs]
@dataclass
class OptimResult:
    """Holds the results of model optimization with :func:`.optim_flat`."""

    model_state: ModelState
    """Final model state after optimization, i.e. the training model state updated
    with :attr:`.position`."""
    position: Position
    """Position dictionary of optimized parameters with their final values."""
    iteration: int
    """Iteration counter of the last iteration."""
    iteration_best: int
    """Iteration counter of the iteration with the lowest validation loss within the
    final patience window of the :class:`.Stopper`. This is not necessarily the
    iteration with the lowest loss in the whole history."""
    history: dict[str, dict[str, Array] | Array | None]
    """History of loss evaluations and, if applicable, intermediate position values.

    Always contains ``"loss_train"`` and ``"loss_validation"``. The entry
    ``"position"`` is a dictionary of position histories, or ``None`` if the
    position history was not saved; ``"tracked"`` is only present if the position
    history was saved."""
    max_iter: int
    """Maximum number of iterations."""
    n_train: int
    """Number of training observations used for likelihood rescaling. If batching
    was disabled, this is ``1`` unless a sample size was supplied explicitly or
    discovered automatically."""
    n_validation: int
    """Number of validation observations used for likelihood rescaling. If batching
    was disabled and no sample size was supplied or discovered, this equals
    :attr:`.n_train`."""
def _find_observed(model: Model) -> dict[str, Array]:
    """
    Collects the values of all strong, observed variables of ``model``.

    Returns a dictionary mapping each variable name to its value, converted with
    ``jnp.array``. Weak variables are skipped; presumably their values are
    recomputed from their (batched) inputs when the model state is updated —
    TODO confirm against the Liesel model semantics.
    """
    return {
        var_.name: jnp.array(var_.value)
        for var_ in model.vars.values()
        if var_.observed and not var_.weak
    }
def batched_nodes(nodes: dict[str, Array], batch_indices: Array) -> dict[str, Array]:
    """
    Selects one batch of rows from every array in ``nodes``.

    Every leaf of the pytree ``nodes`` is indexed as ``leaf[batch_indices, ...]``,
    i.e. along its leading axis. The structure of ``nodes`` is preserved.
    """

    def _take_batch(leaf):
        return leaf[batch_indices, ...]

    return jax.tree_util.tree_map(_take_batch, nodes)
def _generate_batch_indices(
    key: KeyArray, n: int, batch_size: int, shuffle: bool = True
) -> Array:
    """
    Splits the observation indices ``0, ..., n - 1`` into full batches.

    Parameters
    ----------
    key
        PRNG key used for the permutation if ``shuffle`` is ``True``.
    n
        Number of observations.
    batch_size
        Number of observations per batch.
    shuffle
        Whether to permute the indices before splitting them.

    Returns
    -------
    An integer array of shape ``(n // batch_size, batch_size)``; row ``j`` holds the
    indices of batch ``j``. Observations that do not fill a complete batch are left
    out; with shuffling, which ones are left out depends on ``key``.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1 or larger than ``n``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, but got {batch_size=}.")
    n_full_batches = n // batch_size
    if n_full_batches < 1:
        raise ValueError(
            f"batch_size={batch_size} exceeds the number of observations n={n}; at"
            " least one full batch is required."
        )
    indices = jax.random.permutation(key, n) if shuffle else jnp.arange(n)
    # All batches have the same size, so splitting is a plain reshape.
    return indices[: n_full_batches * batch_size].reshape(n_full_batches, batch_size)
def _find_sample_size(model: Model) -> int:
    """
    Returns the common leading-axis length of all observed variables in ``model``.

    Raises
    ------
    ValueError
        If the model has no observed variables, if an observed variable is a
        scalar (it has no leading axis), or if the observed variables disagree in
        their leading-axis length.
    """
    sizes: set[int] = set()
    for var_ in model.vars.values():
        if not var_.observed:
            continue
        # jnp.shape also works for values that are not arrays (e.g. lists).
        shape = jnp.shape(var_.value)
        if len(shape) == 0:
            raise ValueError(
                f"The observed variable {var_.name!r} is a scalar, so its number of"
                " observations cannot be determined."
            )
        sizes.add(int(shape[0]))

    if not sizes:
        raise ValueError(
            "The model has no observed variables, so its number of observations"
            " cannot be determined."
        )
    if len(sizes) > 1:
        raise ValueError(
            "The observed variables must have the same number of observations."
        )
    return sizes.pop()
def _validate_sample_size(value: int | None, name: str) -> int | None:
if value is None:
return None
if isinstance(value, bool) or not isinstance(value, int) or value < 1:
raise ValueError(f"{name} must be a positive integer or None.")
return value
[docs]
@dataclass
class Stopper:
    """
    Handles (early) stopping for :func:`.optim_flat`.

    Parameters
    ----------
    max_iter
        The maximum number of optimization steps.
    patience
        Length of the recent loss window considered for early stopping. Early stopping\
        is checked only after more than ``patience`` optimization steps have been\
        evaluated. In other words, because ``i`` is zero-based, ``stop_early()`` first\
        returns ``True`` no earlier than ``i == patience + 1``.
    atol
        The non-negative absolute tolerance for early stopping.
    rtol
        The non-negative relative tolerance for early stopping. With the default of \
        ``0.0``, the relative criterion is only met if the oldest loss in the window \
        is also the best one, which is the same condition as without tolerances.

    Notes
    -----
    Early stopping is based on the window of the most recent ``patience`` loss values
    ending at the current zero-based iteration ``i``. Without tolerances, early stopping
    happens when the oldest loss value in this window is also the best loss value in
    this window. This is a rolling-window rule, not a best-so-far rule that counts the
    number of iterations since the global best loss. It can therefore continue while
    the recent window still contains newer improvements, even if the global best loss
    was observed before the current window. A simplified pseudo-implementation is:

    .. code-block:: python

        def stop(patience, i, loss_history):
            current_history = loss_history[: i + 1]
            recent_history = current_history[-patience:]
            oldest_within_patience = recent_history[0]
            best_within_patience = np.min(recent_history)
            return oldest_within_patience <= best_within_patience

    Absolute and relative tolerance make it possible to stop even in cases when the
    oldest loss within patience is *not* the best. Instead, the algorithm stops when
    the absolute *or* relative difference between the oldest loss within patience and
    the best loss within patience is so small that it can be neglected.
    To be clear: If either of the two conditions is met, then early stopping happens.
    The relative magnitude of the difference is calculated with respect to the best
    loss within patience. A simplified pseudo-implementation is:

    .. code-block:: python

        def stop(patience, i, loss_history, atol, rtol):
            current_history = loss_history[: i + 1]
            recent_history = current_history[-patience:]
            oldest_within_patience = recent_history[0]
            best_within_patience = np.min(recent_history)
            diff = oldest_within_patience - best_within_patience
            rel_diff = diff / np.abs(best_within_patience)
            abs_improvement_is_neglectable = diff <= atol
            rel_improvement_is_neglectable = rel_diff <= rtol
            return (abs_improvement_is_neglectable | rel_improvement_is_neglectable)
    """

    max_iter: int
    patience: int
    atol: float = 1e-3
    rtol: float = 0.0

    def __post_init__(self):
        if self.max_iter < 1:
            raise ValueError("max_iter must be at least 1.")
        if self.patience < 1:
            raise ValueError("patience must be at least 1.")
        if self.patience > self.max_iter:
            raise ValueError("patience must be less than or equal to max_iter.")
        if self.atol < 0:
            raise ValueError("atol must be non-negative.")
        if self.rtol < 0:
            raise ValueError("rtol must be non-negative.")

    def stop_early(self, i: int | Array, loss_history: Array):
        """
        Whether the recent loss window signals convergence at iteration ``i``.

        Compares the oldest loss in the window of the last ``patience`` entries
        ending at ``i`` with the best loss in that window, using ``atol`` and
        ``rtol`` as described in the class notes. Always ``False`` while
        ``i <= patience``. Works with traced ``i`` inside ``jax.lax.while_loop``.
        """
        p = self.patience
        lower = jnp.maximum(i - p + 1, 0)
        recent_history = jax.lax.dynamic_slice(
            loss_history, start_indices=(lower,), slice_sizes=(p,)
        )

        best_loss_in_recent = jnp.min(recent_history)
        oldest_loss_in_recent = recent_history[0]

        diff = oldest_loss_in_recent - best_loss_in_recent
        abs_improvement_is_neglectable = diff <= self.atol

        rel_diff = diff / jnp.abs(best_loss_in_recent)
        rel_improvement_is_neglectable = rel_diff <= self.rtol

        # Stopping happens only if we actually went through a full patience period.
        current_i_is_after_patience = i > p

        stop = abs_improvement_is_neglectable | rel_improvement_is_neglectable
        return stop & current_i_is_after_patience

    def stop_now(self, i: int | Array, loss_history: Array):
        """Whether optimization should stop now."""
        stop_early = self.stop_early(i=i, loss_history=loss_history)
        stop_max_iter = i >= (self.max_iter - 1)
        return stop_early | stop_max_iter

    def continue_(self, i: int | Array, loss_history: Array):
        """Whether optimization should continue (inverse of :meth:`.stop_now`)."""
        return ~self.stop_now(i=i, loss_history=loss_history)

    def which_best_in_recent_history(self, i: int, loss_history: Array):
        """
        Identifies the index of the best observation in the recent loss window.

        The recent loss window contains the last ``p`` entries of ``loss_history``,
        looking backwards from the current zero-based iteration ``i``, where ``p`` is
        the patience. This returns the best index in that recent window, not
        necessarily the global best index in the full loss history. If fewer than
        ``p`` entries exist up to ``i``, only the entries ``0, ..., i`` are
        considered.
        """
        p = self.patience
        n = loss_history.shape[0]
        # Compute the start the same way dynamic_slice clamps it, so that the
        # returned index always refers to the slice that was actually taken.
        lower = jnp.minimum(jnp.maximum(i - p + 1, 0), n - p)
        recent_history = jax.lax.dynamic_slice(
            loss_history, start_indices=(lower,), slice_sizes=(p,)
        )
        # Entries after iteration i (possible when i < p - 1) must not be chosen.
        window_indices = lower + jnp.arange(p)
        recent_history = jnp.where(window_indices <= i, recent_history, jnp.inf)
        return lower + jnp.argmin(recent_history)
def _validate_log_prob_decomposition(
    interface: LieselInterface, position: Position, state: ModelState
) -> bool:
    """
    Checks that the model log probability equals log likelihood plus log prior.

    The model state is updated with ``position`` and the values of the nodes
    ``_model_log_prob``, ``_model_log_lik`` and ``_model_log_prior`` are compared
    with ``jnp.allclose``.

    Returns
    -------
    ``True`` if the decomposition holds.

    Raises
    ------
    ValueError
        If the log probability is not (approximately) the sum of log likelihood and
        log prior.
    """
    updated_state = interface.update_state(position, state)
    log_prob = updated_state["_model_log_prob"].value
    log_lik = updated_state["_model_log_lik"].value
    log_prior = updated_state["_model_log_prior"].value

    if not jnp.allclose(log_prob, log_lik + log_prior):
        raise ValueError(
            f"Your model's {log_prob=} cannot correctly be decomposed into the"
            f" {log_prior=} and the {log_lik=}. Check whether your observed variables"
            " have the attribute observed=True and whether your parameter variables"
            " have the attribute parameter=True."
        )

    return True
[docs]
def optim_flat(
    model_train: Model,
    params: Sequence[str],
    optimizer: optax.GradientTransformation | None = None,
    stopper: Stopper | None = None,
    batch_size: int | None = None,
    batch_seed: int | None = None,
    save_position_history: bool = True,
    model_validation: Model | None = None,
    restore_best_position: bool = True,
    prune_history: bool = True,
    validate_log_prob_decomposition: bool = True,
    progress_bar: bool = True,
    progress_n_updates: int = 20,
    track_keys: list[str] | None = None,
    scale_loss: bool = False,
    n_train: int | None = None,
    n_validation: int | None = None,
    auto_n_obs: bool = False,
) -> OptimResult:
    """
    Optimize the parameters of a Liesel :class:`.Model`.

    Approximates maximum a posteriori (MAP) parameter estimates by minimizing the
    negative log posterior probability of the model. If you use batching, be aware that
    the batching functionality implemented here assumes a "flat" model structure.
    See below for details.

    .. warning::
        This function is experimental.
        The API may change more quickly than in other parts of the library.
        Check your results carefully. If you encounter puzzling results, try to disable
        batching.

    Parameters
    ----------
    model_train
        The Liesel model to optimize.
    params
        List of parameter names to optimize. All other parameters of the model are held\
        fixed.
    optimizer
        An optimizer from the ``optax`` library. If ``None`` , \
        ``optax.adam(learning_rate=1e-2)`` is used.
    stopper
        A :class:`.Stopper` that carries information about the maximum number of\
        iterations and early stopping. If ``None``, \
        ``Stopper(max_iter=10_000, patience=10)`` is used.
    batch_size
        The batch size. If ``None``, batching is disabled and each optimization step\
        uses the full model log probability. In this case, the result stores\
        ``n_train == n_validation == 1`` because no observation count is needed for\
        likelihood rescaling, unless explicit sample sizes are supplied or\
        ``auto_n_obs=True``. If supplied, it must be at least ``1`` and at most the\
        number of training observations. Observations that do not fill a complete\
        batch are left out in the respective iteration.
    batch_seed
        Batches are assembled randomly in each iteration. This is the seed used for \
        shuffling in this step. If ``None``, a seed is drawn with NumPy.
    save_position_history
        If ``True``, the position history is saved to the results object.
    model_validation
        If supplied, this model serves as a validation model, which means that early\
        stopping is based on the validation loss evaluated using this model. If\
        ``None``, the training model is also used as the validation model, so training\
        and validation losses are identical.
    restore_best_position
        If ``True``, the position with the lowest loss within the patience defined\
        by the supplied :class:`.Stopper` is restored as the final position. If \
        ``False``, the last iteration's position is used.
    prune_history
        If ``True``, the history is pruned to the length of the final iteration. This\
        means, the history can be shorter than the maximum number of iterations defined\
        by the supplied :class:`.Stopper`. If ``False``, unused history entries are set\
        to ``jax.numpy.nan`` if optimization stops early.
    validate_log_prob_decomposition
        Whether to check that the model log probability is equal to the sum of the\
        model log likelihood and model log prior before optimization starts. Disable\
        this only for models whose log probability intentionally cannot be decomposed\
        in this way.
    progress_bar
        Whether to use a progress bar.
    progress_n_updates
        How many times to update the progress bar in total.
    track_keys
        List of position keys to track and include in the history.
    scale_loss
        Whether to divide the training and validation losses by ``n_train``.
    n_train
        Number of training observations used for likelihood rescaling when batching
        is disabled. If ``None``, defaults to ``1`` unless ``auto_n_obs=True``.
    n_validation
        Number of validation observations used for likelihood rescaling when
        batching is disabled. If ``None`` and no separate validation model is
        supplied, defaults to ``n_train``. When a separate validation model is
        supplied, pass both ``n_train`` and ``n_validation`` or set
        ``auto_n_obs=True``.
    auto_n_obs
        Whether to discover ``n_train`` and ``n_validation`` from the leading
        dimension of observed variables when batching is disabled. Batched
        optimization always discovers sample sizes because they are needed to build
        batches.

    Returns
    -------
    A dataclass of type :class:`.OptimResult`, giving access to the results.

    Raises
    ------
    ValueError
        If ``n_train`` or ``n_validation`` are invalid or supplied together with
        ``batch_size``; if ``batch_size`` is smaller than ``1`` or larger than the
        number of training observations; if a separate validation model is used
        without batching and neither both sample sizes nor ``auto_n_obs=True`` are
        given; or if the log probability decomposition check fails.

    See Also
    --------
    .history_to_df : A helper function to turn the :attr:`.OptimResult.history` into
        a ``pandas.DataFrame`` - nice for quickly plotting results.

    Notes
    -----
    If ``batch_size`` is ``None``, batching is disabled. If you use batching, be aware
    that the batching functionality implemented here assumes a "flat" model structure.
    This means that this function assumes that, for all :class:`.Var` objects in your
    model, it is valid to index their values like this::

        var_object.value[batch_indices, ...]

    The batching functionality also assumes that all objects that should be batched
    are included as :class:`.Var` objects with ``Var.observed`` set to ``True``.

    With batching enabled, the training loss rescales the batched log likelihood by
    ``n_train / batch_size``. The validation loss rescales the validation log likelihood
    by ``n_train / n_validation`` when a separate validation model is supplied.

    With batching disabled, ``n_train`` defaults to ``1`` and ``n_validation``
    defaults to ``n_train``. When a separate validation model is supplied, you must
    either pass both values explicitly or set ``auto_n_obs=True``.

    Examples
    --------
    We show a minimal example. First, import ``tfd``.

    >>> import tensorflow_probability.substrates.jax.distributions as tfd

    Next, generate some data.

    >>> key = jax.random.PRNGKey(42)
    >>> key, subkey = jax.random.split(key)
    >>> x = jax.random.normal(key, (100,))
    >>> y = 0.5 + 1.2 * x + jax.random.normal(subkey, (100,))

    Next, set up a linear model. For simplicity, we assume the scale to be fixed to the
    true value of 1.

    >>> coef = lsl.Var.new_param(jnp.zeros(2), name="coef")
    >>> xvar = lsl.Var.new_obs(jnp.c_[jnp.ones_like(x), x], name="x")
    >>> mu = lsl.Var.new_calc(jnp.dot, xvar, coef, name="mu")
    >>> ydist = lsl.Dist(tfd.Normal, loc=mu, scale=1.0)
    >>> yvar = lsl.Var.new_obs(y, ydist, name="y")
    >>> model = lsl.Model([yvar])

    Now, we are ready to run the optimization.

    >>> stopper = gs.Stopper(max_iter=1000, patience=10, atol=0.01)
    >>> result = gs.optim_flat(model, params=["coef"], stopper=stopper)
    >>> {name: jnp.round(value, 2) for name, value in result.position.items()}
    {'coef': Array([0.38, 1.24], dtype=float32)}

    We can now, for example, use ``result.model_state`` in
    :meth:`.EngineBuilder.set_initial_values` to implement a "warm start" of MCMC
    sampling.
    """
    track_keys = track_keys if track_keys is not None else []

    # ---------------------------------------------------------------------------------
    # Validation input

    if restore_best_position:
        assert save_position_history, (
            "Cannot restore best position if history is not saved."
        )

    # ---------------------------------------------------------------------------------
    # Pre-process inputs

    batch_seed = (
        batch_seed if batch_seed is not None else np.random.randint(low=1, high=1000)
    )
    n_train = _validate_sample_size(n_train, "n_train")
    n_validation = _validate_sample_size(n_validation, "n_validation")

    if stopper is None:
        stopper = Stopper(max_iter=10_000, patience=10)

    if model_validation is None:
        model_validation = model_train

    if optimizer is None:
        optimizer = optax.adam(learning_rate=1e-2)

    do_batching = batch_size is not None

    if do_batching:
        if n_train is not None or n_validation is not None:
            raise ValueError(
                "n_train and n_validation can only be supplied when batch_size is None."
            )
        shuffle_batch_indices = True
        observed = _find_observed(model_train)
        n_train = _find_sample_size(model_train)
        n_validation = _find_sample_size(model_validation)
        batch_size = batch_size if batch_size is not None else n_train
        # At least one full batch must fit into the training data; otherwise the
        # batch construction would fail later with an obscure error.
        if not 1 <= batch_size <= n_train:
            raise ValueError(
                "batch_size must be between 1 and the number of training observations"
                f" ({n_train}), but got {batch_size=}."
            )
    else:
        shuffle_batch_indices = False
        # not because there are no observed, but because we don't need to update
        # observed with their batches
        observed = {}
        if auto_n_obs:
            n_train = n_train or _find_sample_size(model_train)
            n_validation = n_validation or _find_sample_size(model_validation)
        elif model_validation is not model_train and (
            n_train is None or n_validation is None
        ):
            raise ValueError(
                "When batch_size is None and model_validation is a separate model, "
                "pass both n_train and n_validation or set auto_n_obs=True."
            )
        else:
            n_train = n_train or 1
            n_validation = n_validation or n_train
        # Without batching, a single "batch" covers all n_train indices.
        batch_size = n_train

    interface_train = LieselInterface(model_train)
    position = interface_train.extract_position(params, model_train.state)
    track = interface_train.extract_position(track_keys, model_train.state)
    interface_train._model.auto_update = False

    interface_validation = LieselInterface(model_validation)
    interface_validation._model.auto_update = False

    # ---------------------------------------------------------------------------------
    # Validate model log prob decomposition

    if validate_log_prob_decomposition:
        _validate_log_prob_decomposition(
            interface_train, position=position, state=model_train.state
        )
        _validate_log_prob_decomposition(
            interface_validation, position=position, state=model_validation.state
        )

    # ---------------------------------------------------------------------------------
    # Define loss function(s)

    likelihood_scalar = n_train / batch_size
    likelihood_scalar_validation = n_train / n_validation

    def _batched_neg_log_prob(
        position: Position, model_state: ModelState, batch_indices: Array | None = None
    ):
        # Without batching, `observed` is empty and this leaves `position` as is.
        batched_observed = batched_nodes(observed, batch_indices)
        position = position | batched_observed  # type: ignore
        updated_state = interface_train.update_state(position, model_state)
        if not do_batching:
            log_prob = updated_state["_model_log_prob"].value
        else:
            # Rescale the batch log likelihood to the scale of the full data set.
            log_lik = likelihood_scalar * updated_state["_model_log_lik"].value
            log_prior = updated_state["_model_log_prior"].value
            log_prob = log_lik + log_prior
        nlp = -log_prob
        if scale_loss:
            nlp = nlp / n_train
        return nlp

    def _neg_log_prob_train(position: Position, model_state: ModelState):
        updated_state = interface_train.update_state(position, model_state)
        nlp = -updated_state["_model_log_prob"].value
        if scale_loss:
            nlp = nlp / n_train
        return nlp

    def _neg_log_prob_validation(position: Position, model_state: ModelState):
        updated_state = interface_validation.update_state(position, model_state)
        # Rescale the validation log likelihood to the scale of the training data.
        log_lik = likelihood_scalar_validation * updated_state["_model_log_lik"].value
        log_prior = updated_state["_model_log_prior"].value
        log_prob = log_lik + log_prior
        nlp = -log_prob
        if scale_loss:
            nlp = nlp / n_train
        return nlp

    if model_validation is model_train:
        _neg_log_prob_validation = _neg_log_prob_train

    neg_log_prob_grad = jax.grad(_batched_neg_log_prob, argnums=0)

    # ---------------------------------------------------------------------------------
    # Initialize history

    history: dict[str, Any] = dict()
    history["loss_train"] = jnp.zeros(shape=stopper.max_iter)
    history["loss_validation"] = jnp.zeros(shape=stopper.max_iter)

    if save_position_history:
        history["position"] = {
            name: jnp.zeros((stopper.max_iter,) + jnp.shape(value))
            for name, value in position.items()
        }
        history["position"] = jax.tree.map(
            lambda d, pos: d.at[0].set(pos), history["position"], position
        )
        history["tracked"] = {
            name: jnp.zeros((stopper.max_iter,) + jnp.shape(value))
            for name, value in track.items()
        }
        history["tracked"] = jax.tree.map(
            lambda d, pos: d.at[0].set(pos), history["tracked"], track
        )
    else:
        history["position"] = None

    loss_train_start = _neg_log_prob_train(
        position=position, model_state=model_train.state
    )
    loss_validation_start = _neg_log_prob_validation(
        position=position, model_state=model_validation.state
    )

    history["loss_train"] = history["loss_train"].at[0].set(loss_train_start)
    history["loss_validation"] = (
        history["loss_validation"].at[0].set(loss_validation_start)
    )

    # ---------------------------------------------------------------------------------
    # Initialize while loop carry dictionary

    init_val: dict[str, Any] = dict()
    init_val["while_i"] = 0
    init_val["history"] = history
    init_val["position"] = position
    init_val["tracked"] = track
    init_val["opt_state"] = optimizer.init(position)
    init_val["current_loss_train"] = history["loss_train"][0]
    init_val["current_loss_validation"] = history["loss_validation"][0]
    init_val["key"] = jax.random.PRNGKey(batch_seed)
    init_val["model_state_train"] = model_train.state
    init_val["model_state_validation"] = model_validation.state

    # ---------------------------------------------------------------------------------
    # Set up progress bar

    if progress_bar:
        if stopper.max_iter > progress_n_updates:
            print_rate = int(stopper.max_iter / progress_n_updates)
        else:
            print_rate = 1

        progress_bar_inst = tqdm(
            total=stopper.max_iter,
            desc=(
                f"Training loss: {loss_train_start:.3f}, Validation loss:"
                f" {loss_validation_start:.3f}"
            ),
            position=0,
        )

        def tqdm_update(val, update=print_rate):
            loss_train = val["current_loss_train"]
            loss_validation = val["current_loss_validation"]
            desc = (
                f"Training loss: {loss_train:.3f}, Validation loss:"
                f" {loss_validation:.3f}"
            )
            progress_bar_inst.update(update)
            progress_bar_inst.set_description(desc)
            return val

        def tqdm_callback(val):
            iter_num = val["while_i"] + 1
            _ = jax.lax.cond(
                # update tqdm every multiple of `print_rate` except at the end
                (iter_num % print_rate == 0),
                lambda _: jax.experimental.io_callback(tqdm_update, val, val),
                lambda _: val,
                operand=None,
            )

    else:

        def tqdm_callback(val):
            return None

    # ---------------------------------------------------------------------------------
    # Define while loop body

    def body_fun(val: dict):
        # Advance the PRNG key stored in the carry. Reusing the same key in every
        # iteration would produce the same permutation, and thus the same batches,
        # over and over again.
        val["key"], subkey = jax.random.split(val["key"])
        batches = _generate_batch_indices(
            key=subkey, n=n_train, batch_size=batch_size, shuffle=shuffle_batch_indices
        )

        # -----------------------------------------------------------------------------
        # Loop over batches

        def _fori_body(i, val):
            batch = batches[i]
            pos = val["position"]
            grad = neg_log_prob_grad(
                pos, batch_indices=batch, model_state=val["model_state_train"]
            )
            updates, opt_state = optimizer.update(grad, val["opt_state"], params=pos)
            val["position"] = optax.apply_updates(pos, updates)

            updated_state = interface_train.update_state(
                val["position"], val["model_state_train"]
            )
            val["tracked"] = interface_train.extract_position(track_keys, updated_state)

            val["opt_state"] = opt_state

            return val

        val = jax.lax.fori_loop(
            body_fun=_fori_body, init_val=val, lower=0, upper=len(batches)
        )

        # -----------------------------------------------------------------------------
        # Save values and increase counter

        val["while_i"] += 1

        loss_train = _neg_log_prob_train(
            val["position"], model_state=val["model_state_train"]
        )
        val["history"]["loss_train"] = (
            val["history"]["loss_train"].at[val["while_i"]].set(loss_train)
        )

        loss_validation = _neg_log_prob_validation(
            val["position"], model_state=val["model_state_validation"]
        )
        val["history"]["loss_validation"] = (
            val["history"]["loss_validation"].at[val["while_i"]].set(loss_validation)
        )

        val["current_loss_train"] = loss_train
        val["current_loss_validation"] = loss_validation

        if save_position_history:
            pos_hist = val["history"]["position"]
            val["history"]["position"] = jax.tree.map(
                lambda d, pos: d.at[val["while_i"]].set(pos), pos_hist, val["position"]
            )

            pos_hist = val["history"]["tracked"]
            val["history"]["tracked"] = jax.tree.map(
                lambda d, pos: d.at[val["while_i"]].set(pos), pos_hist, val["tracked"]
            )

        if progress_bar:
            tqdm_callback(val)

        return val

    # ---------------------------------------------------------------------------------
    # Run while loop

    val = jax.lax.while_loop(
        cond_fun=lambda val: stopper.continue_(
            val["while_i"], val["history"]["loss_validation"]
        ),
        body_fun=body_fun,
        init_val=init_val,
    )

    if progress_bar:
        print_remainder = int((val["while_i"] + 1) % print_rate)
        tqdm_update(val, print_remainder)
        progress_bar_inst.close()

    # Index of the final iteration (zero-based); history entries 0..max_iter are used.
    max_iter = val["while_i"]

    # ---------------------------------------------------------------------------------
    # Set final position and model state

    ibest = stopper.which_best_in_recent_history(
        i=max_iter, loss_history=val["history"]["loss_validation"]
    )

    if restore_best_position:
        final_position: Position = {
            name: pos[ibest] for name, pos in val["history"]["position"].items()
        }  # type: ignore
    else:
        final_position = val["position"]

    final_state = interface_train.update_state(final_position, model_train.state)

    # ---------------------------------------------------------------------------------
    # Set unused values in history to nan

    val["history"]["loss_train"] = (
        val["history"]["loss_train"].at[(max_iter + 1) :].set(jnp.nan)
    )
    val["history"]["loss_validation"] = (
        val["history"]["loss_validation"].at[(max_iter + 1) :].set(jnp.nan)
    )
    if save_position_history:
        for name, value in val["history"]["position"].items():
            val["history"]["position"][name] = value.at[(max_iter + 1) :, ...].set(
                jnp.nan
            )

        for name, value in val["history"]["tracked"].items():
            val["history"]["tracked"][name] = value.at[(max_iter + 1) :, ...].set(
                jnp.nan
            )

    # ---------------------------------------------------------------------------------
    # Remove unused values in history, if applicable

    if prune_history:
        val["history"]["loss_train"] = val["history"]["loss_train"][: (max_iter + 1)]
        val["history"]["loss_validation"] = val["history"]["loss_validation"][
            : (max_iter + 1)
        ]
        if save_position_history:
            for name, value in val["history"]["position"].items():
                val["history"]["position"][name] = value[: (max_iter + 1), ...]

            for name, value in val["history"]["tracked"].items():
                val["history"]["tracked"][name] = value[: (max_iter + 1), ...]

    # ---------------------------------------------------------------------------------
    # Initialize results object and return

    result = OptimResult(
        model_state=final_state,
        position=final_position,
        iteration=max_iter,
        iteration_best=ibest,
        history=val["history"],
        max_iter=stopper.max_iter,
        n_train=n_train,
        n_validation=n_validation,
    )

    return result
[docs]
def history_to_df(history: dict[str, Array]) -> pd.DataFrame:
    """
    Turns a :attr:`.OptimResult.history` dictionary into a ``pandas.DataFrame``.

    Every entry except ``"position"`` and ``"tracked"`` (e.g. the loss histories)
    becomes one column named after its key. Position and tracked histories are
    flattened to ``(n_iterations, n_values)``: single-valued parameters become one
    column named after the parameter, multi-valued parameters become columns
    ``<name>0``, ``<name>1``, and so on. An ``iteration`` column counting from zero
    is appended and all columns are cast to ``float``.
    """

    def _as_columns(value):
        # Collapse all trailing axes into one while keeping the leading (iteration)
        # axis. Unlike ``squeeze``, this keeps the iteration axis for histories with
        # a single entry and supports parameters with several non-singleton axes.
        if jnp.ndim(value) <= 1:
            return value
        flat = jnp.reshape(value, (jnp.shape(value)[0], -1))
        return flat[:, 0] if flat.shape[1] == 1 else flat

    data: dict[str, Array] = dict()

    position_history = history.get("position", None)
    tracked_history = history.get("tracked", None)

    for name, value in history.items():
        if name in ("position", "tracked"):
            continue
        data |= array_to_dict(value, names_prefix=name)

    for nested_history in (position_history, tracked_history):
        if nested_history is None:
            continue
        for name, value in nested_history.items():
            data |= array_to_dict(_as_columns(value), names_prefix=name)

    df = pd.DataFrame(data)
    # Count rows of the frame itself; a leftover loop variable would refer to
    # ``history["position"]``, which is ``None`` if positions were not saved.
    df["iteration"] = np.arange(len(df))
    return df.astype(float)