optim_flat()
liesel.goose.optim_flat(model_train, params, optimizer=None, stopper=Stopper(max_iter=10000, patience=10, atol=0.001, rtol=1e-12), batch_size=None, batch_seed=None, save_position_history=True, model_validation=None, restore_best_position=True, prune_history=True)
Optimize the parameters of a Liesel Model. Approximates maximum a posteriori (MAP) parameter estimates by minimizing the negative log posterior probability of the model. If you use batching, be aware that the batching functionality implemented here assumes a “flat” model structure. See the Notes section below for details.
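The MAP idea can be illustrated without Liesel. The following plain-Python sketch is a hypothetical toy, not the library's implementation: it assumes observations y_i ~ Normal(mu, 1) and a prior mu ~ Normal(0, prior_sd²), and minimizes the negative log posterior with plain gradient descent (optim_flat itself uses an optax optimizer). For this conjugate setup the minimizer has the closed form sum(y) / (n + 1 / prior_sd²), which gives something to check the iterative result against.

```python
def neg_log_posterior(mu, ys, prior_sd):
    """Negative log posterior of mu, up to an additive constant.

    Assumed toy model: y_i ~ Normal(mu, 1), prior mu ~ Normal(0, prior_sd**2).
    """
    neg_log_lik = 0.5 * sum((y - mu) ** 2 for y in ys)
    neg_log_prior = 0.5 * mu**2 / prior_sd**2
    return neg_log_lik + neg_log_prior


def neg_log_posterior_grad(mu, ys, prior_sd):
    """Derivative of neg_log_posterior with respect to mu."""
    return -sum(y - mu for y in ys) + mu / prior_sd**2


def map_estimate(ys, prior_sd=1.0, learning_rate=0.1, n_steps=500):
    """Approximate the MAP estimate of mu by plain gradient descent.

    For this quadratic objective the step size must stay below
    2 / (len(ys) + 1 / prior_sd**2); otherwise the iterates diverge.
    """
    mu = 0.0
    for _ in range(n_steps):
        mu -= learning_rate * neg_log_posterior_grad(mu, ys, prior_sd)
    return mu


ys = [1.0, 2.0, 3.0, 4.0]
closed_form = sum(ys) / (len(ys) + 1.0)  # posterior mode for prior_sd = 1
print(map_estimate(ys), closed_form)
```

Gradient descent is used here only because it is the simplest optimizer to write by hand; any optimizer that minimizes the same objective reaches the same mode.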
Params
- model_train
The Liesel model to optimize.
- params
List of parameter names to optimize. All other parameters of the model are held fixed.
- optimizer
An optimizer from the optax library. If None, optax.adam(learning_rate=1e-2) is used.
- stopper
A Stopper that carries information about the maximum number of iterations and early stopping.
- batch_size
The batch size. If None, the whole dataset is used for each optimization step.
- batch_seed
Batches are assembled randomly in each iteration. This is the seed used for this random shuffling.
- save_position_history
If True, the position history is saved to the results object.
- model_validation
If supplied, this model serves as a validation model, which means that early stopping is based on the negative log likelihood evaluated using the observed data in this model. If None, no early stopping is conducted.
- restore_best_position
If True, the position with the lowest loss within the patience defined by the supplied Stopper is restored as the final position. If False, the last iteration’s position is used.
- prune_history
If True, the history is pruned to the length of the final iteration. This means that the history can be shorter than the maximum number of iterations defined by the supplied Stopper. If False, unused history entries are set to jax.numpy.nan if optimization stops early.
- rtype:
OptimResult
- returns:
A dataclass of type OptimResult, giving access to the results.
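To make the interplay of stopper, restore_best_position, and prune_history concrete, here is a hypothetical plain-Python sketch that replays a fixed sequence of losses. The stopping rule is an assumption made for illustration: an iteration counts as an improvement only if it beats the best loss so far by more than max(atol, rtol * |best|), and optimization stops once patience iterations pass without improvement. The rule implemented by Stopper may differ in its details, and the function name run_with_patience is made up for this sketch.

```python
import math


def run_with_patience(losses, max_iter, patience, atol=1e-3, rtol=1e-12,
                      restore_best_position=True, prune_history=True):
    """Replay per-iteration losses under an assumed patience-based stopping rule.

    Returns the index of the iteration whose position would be reported and the
    (possibly pruned) loss history. Hypothetical helper, not part of liesel.
    """
    n_iter = min(max_iter, len(losses))
    history = [math.nan] * max_iter  # pre-allocated; unused slots stay NaN
    best_loss = math.inf
    best_iter = 0
    stopped_at = n_iter - 1  # last iteration if no early stop happens

    for i in range(n_iter):
        loss = losses[i]
        history[i] = loss
        # Required improvement: the larger of the absolute and relative tolerance.
        # Before any finite loss has been seen, any finite loss is an improvement.
        tol = max(atol, rtol * abs(best_loss)) if math.isfinite(best_loss) else 0.0
        if loss < best_loss - tol:
            best_loss, best_iter = loss, i
        elif i - best_iter >= patience:
            stopped_at = i  # patience exhausted: stop early
            break

    final_iter = best_iter if restore_best_position else stopped_at
    if prune_history:
        history = history[: stopped_at + 1]
    return final_iter, history


losses = [5.0, 4.0, 3.0, 3.0005, 3.0002, 3.0001, 2.9999, 2.5, 2.0, 1.5]
print(run_with_patience(losses, max_iter=10, patience=3))
```

With patience=3, the sketch stops at iteration 5 and, with restore_best_position=True, reports iteration 2 (loss 3.0). Passing restore_best_position=False reports iteration 5 instead, and prune_history=False keeps all 10 history slots with NaN in the unused tail.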
See also
history_to_df
A helper function to turn the OptimResult.history into a pandas.DataFrame, which is convenient for quickly plotting results.
Notes
If you use batching, be aware that the batching functionality implemented here assumes a “flat” model structure. This means that this function assumes that, for all Var objects in your model, it is valid to index their values like this:
var_object.value[batch_indices, ...]
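A hypothetical plain-Python sketch of what this assumption means in practice: observation indices are shuffled with a seed, split into batches, and every batched value is indexed along its first axis with the same batch indices, so rows of the covariate matrix and entries of the response stay aligned. The helper names flat_batches and take_rows are made up for illustration, nested lists stand in for arrays, and the library's handling of a final incomplete batch is not described here; this sketch simply keeps it.

```python
import random


def flat_batches(n_obs, batch_size, seed):
    """Shuffle observation indices reproducibly and split them into batches."""
    indices = list(range(n_obs))
    random.Random(seed).shuffle(indices)  # same seed -> same batches
    return [indices[start : start + batch_size] for start in range(0, n_obs, batch_size)]


def take_rows(value, batch_indices):
    """Select entries along the first axis, like value[batch_indices, ...]."""
    return [value[i] for i in batch_indices]


# Toy "flat" data: observation i is row i of x and entry i of y.
x = [[1.0, float(i)] for i in range(10)]
y = [2.0 * i for i in range(10)]

for batch in flat_batches(n_obs=10, batch_size=4, seed=0):
    x_batch, y_batch = take_rows(x, batch), take_rows(y, batch)
    # Both values are indexed along the same first axis,
    # so each y entry still belongs to its x row.
    assert all(2.0 * row[1] == yi for row, yi in zip(x_batch, y_batch))
```

A value that does not store observations along its first axis would be sliced incorrectly by the same indexing, which is why the flat structure is required.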
The batching functionality also assumes that all objects that should be batched are included as Var objects with Var.observed set to True.
Examples
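We show a minimal example. First, import the required modules, including tfd.
>>> import jax
>>> import jax.numpy as jnp
>>> import liesel.goose as gs
>>> import liesel.model as lsl
>>> import tensorflow_probability.substrates.jax.distributions as tfd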
Next, generate some data.
>>> key = jax.random.PRNGKey(42)
>>> key, subkey = jax.random.split(key)
>>> x = jax.random.normal(key, (100,))
>>> y = 0.5 + 1.2 * x + jax.random.normal(subkey, (100,))
Next, set up a linear model. For simplicity, we assume the scale to be fixed to the true value of 1.
>>> coef = lsl.param(jnp.zeros(2), name="coef")
>>> xvar = lsl.obs(jnp.c_[jnp.ones_like(x), x], name="x")
>>> mu = lsl.Var(lsl.Calc(jnp.dot, xvar, coef), name="mu")
>>> ydist = lsl.Dist(tfd.Normal, loc=mu, scale=1.0)
>>> yvar = lsl.obs(y, ydist, name="y")
>>> model = lsl.GraphBuilder().add(yvar).build_model()
Now, we are ready to run the optimization.
>>> result = gs.optim_flat(model, params=["coef"])
>>> {name: jnp.round(value, 2) for name, value in result.position.items()}
{'coef': Array([0.52, 1.29], dtype=float32)}
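As a rough cross-check of this result: coef is created with lsl.param and no distribution, so, assuming it contributes no log-prior term, and with the scale fixed at 1, minimizing the negative log posterior here amounts to minimizing 0.5 * sum((y_i - b0 - b1 * x_i)**2), i.e. ordinary least squares. The following plain-Python sketch (the function name ols_intercept_slope is hypothetical) computes the closed-form least-squares intercept and slope. Applied to x and y from the example, converted to Python lists of floats, it should give values close to result.position["coef"], though not necessarily identical, because the optimizer stops at a finite tolerance.

```python
def ols_intercept_slope(x, y):
    """Closed-form least-squares fit of y = b0 + b1 * x (hypothetical helper)."""
    n = len(x)
    x_mean = sum(x) / n
    y_mean = sum(y) / n
    # Slope = sample covariance of (x, y) divided by sample variance of x.
    sxy = sum((xi - x_mean) * (yi - y_mean) for xi, yi in zip(x, y))
    sxx = sum((xi - x_mean) ** 2 for xi in x)
    slope = sxy / sxx
    intercept = y_mean - slope * x_mean
    return intercept, slope


# Noise-free data from the example's true line is recovered
# up to floating-point rounding.
xs = [-1.0, 0.0, 1.0, 2.0]
ys = [0.5 + 1.2 * xi for xi in xs]
print(ols_intercept_slope(xs, ys))
```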
We can now, for example, use result.model_state in EngineBuilder.set_initial_values() to implement a “warm start” of MCMC sampling.