optim_flat()
liesel.goose.optim_flat(model_train, params, optimizer=None, stopper=None, batch_size=None, batch_seed=None, save_position_history=True, model_validation=None, restore_best_position=True, prune_history=True, progress_bar=True)
Optimize the parameters of a Liesel Model.

Approximates maximum a posteriori (MAP) parameter estimates by minimizing the negative log posterior probability of the model. If you use batching, be aware that the batching functionality implemented here assumes a “flat” model structure. See below for details.
Warning
This function is experimental. The API may change more quickly than in other parts of the library. Check your results carefully. If you encounter puzzling results, try to disable batching.
Parameters
- model_train: The Liesel model to optimize.
- params: List of parameter names to optimize. All other parameters of the model are held fixed.
- optimizer: An optimizer from the optax library. If None, optax.adam(learning_rate=1e-2) is used.
- stopper: A Stopper that carries information about the maximum number of iterations and early stopping.
- batch_size: The batch size. If None, the whole dataset is used for each optimization step. See the sketch after this list for a call that combines batching with a custom optimizer.
- batch_seed: Batches are assembled randomly in each iteration. This is the seed used for shuffling in this step.
- save_position_history: If True, the position history is saved to the results object.
- model_validation: If supplied, this model serves as a validation model, which means that early stopping is based on the negative log likelihood evaluated using the observed data in this model. If None, no early stopping is conducted.
- restore_best_position: If True, the position with the lowest loss within the patience defined by the supplied Stopper is restored as the final position. If False, the last iteration's position is used.
- prune_history: If True, the history is pruned to the length of the final iteration. This means the history can be shorter than the maximum number of iterations defined by the supplied Stopper. If False, unused history entries are set to jax.numpy.nan if optimization stops early.
- progress_bar: Whether to use a progress bar.
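For illustration, here is a hedged sketch of a call that combines batching with a custom optimizer. It assumes a model and parameter name like those in the Examples section below; optax.adamw and all numeric values are illustrative choices, not defaults of this function.

>>> import optax
>>> stopper = gs.Stopper(max_iter=1000, patience=10, atol=0.01)
>>> result = gs.optim_flat(
...     model,                                      # a liesel model, as in the Examples below
...     params=["coef"],
...     optimizer=optax.adamw(learning_rate=1e-3),  # any optax optimizer should fit here
...     stopper=stopper,
...     batch_size=32,                              # mini-batches of 32 observations per step
...     batch_seed=1,                               # seed for shuffling batches each iteration
... )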
Returns
A dataclass of type OptimResult, giving access to the results.
See also
history_to_df
A helper function to turn the OptimResult.history into a pandas.DataFrame, which is convenient for quickly plotting results.
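For instance, a minimal sketch, assuming a result object as returned by optim_flat and that history_to_df is exposed under the gs namespace:

>>> df = gs.history_to_df(result.history)
>>> df.head()  # inspect the available columns before plotting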
Notes
If you use batching, be aware that the batching functionality implemented here assumes a “flat” model structure. This means that this function assumes that, for all Var objects in your model, it is valid to index their values like this:

var_object.value[batch_indices, ...]

The batching functionality also assumes that all objects that should be batched are included as Var objects with Var.observed set to True.
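As a quick plain-JAX illustration of this assumption (not specific to liesel), row-indexing selects observations along the first axis, whatever the trailing shape:

>>> import jax.numpy as jnp
>>> value = jnp.ones((100, 2))             # 100 observations, 2 columns
>>> batch_indices = jnp.array([3, 17, 42])
>>> value[batch_indices, ...].shape        # a batch of 3 observations
(3, 2)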
Examples
We show a minimal example. First, import tfd, along with jax, jnp, lsl, and gs, which the snippets below rely on.

>>> import jax
>>> import jax.numpy as jnp
>>> import liesel.model as lsl
>>> import liesel.goose as gs
>>> import tensorflow_probability.substrates.jax.distributions as tfd
Next, generate some data.
>>> key = jax.random.PRNGKey(42)
>>> key, subkey = jax.random.split(key)
>>> x = jax.random.normal(key, (100,))
>>> y = 0.5 + 1.2 * x + jax.random.normal(subkey, (100,))
Next, set up a linear model. For simplicity, we assume the scale to be fixed to the true value of 1.
>>> coef = lsl.param(jnp.zeros(2), name="coef")
>>> xvar = lsl.obs(jnp.c_[jnp.ones_like(x), x], name="x")
>>> mu = lsl.Var(lsl.Calc(jnp.dot, xvar, coef), name="mu")
>>> ydist = lsl.Dist(tfd.Normal, loc=mu, scale=1.0)
>>> yvar = lsl.obs(y, ydist, name="y")
>>> model = lsl.GraphBuilder().add(yvar).build_model()
Now, we are ready to run the optimization.
>>> stopper = gs.Stopper(max_iter=1000, patience=10, atol=0.01)
>>> result = gs.optim_flat(model, params=["coef"], stopper=stopper)
>>> {name: jnp.round(value, 2) for name, value in result.position.items()}
{'coef': Array([0.52, 1.29], dtype=float32)}
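If held-out data are available, early stopping can instead be based on a validation model via model_validation. The following is a hedged sketch that mirrors the training model above on fresh data; names and seeds are illustrative:

>>> x_val = jax.random.normal(jax.random.PRNGKey(1), (50,))
>>> y_val = 0.5 + 1.2 * x_val + jax.random.normal(jax.random.PRNGKey(2), (50,))
>>> coef_v = lsl.param(jnp.zeros(2), name="coef")
>>> xvar_v = lsl.obs(jnp.c_[jnp.ones_like(x_val), x_val], name="x")
>>> mu_v = lsl.Var(lsl.Calc(jnp.dot, xvar_v, coef_v), name="mu")
>>> ydist_v = lsl.Dist(tfd.Normal, loc=mu_v, scale=1.0)
>>> yvar_v = lsl.obs(y_val, ydist_v, name="y")
>>> model_val = lsl.GraphBuilder().add(yvar_v).build_model()
>>> result_val = gs.optim_flat(model, params=["coef"], stopper=stopper, model_validation=model_val)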
We can now, for example, use result.model_state in EngineBuilder.set_initial_values() to implement a “warm start” of MCMC sampling.
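A hedged sketch of such a warm start, following the usual goose workflow; the exact builder calls (for example LieselInterface and NUTSKernel) are assumptions that may differ across liesel versions:

>>> builder = gs.EngineBuilder(seed=1337, num_chains=4)
>>> builder.set_model(gs.LieselInterface(model))    # assumed model interface wrapper
>>> builder.set_initial_values(result.model_state)  # warm start from the MAP estimate
>>> builder.add_kernel(gs.NUTSKernel(["coef"]))     # assumed NUTS kernel for the coefficients
>>> builder.set_duration(warmup_duration=1000, posterior_duration=1000)
>>> engine = builder.build()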