optim_flat()
- liesel.goose.optim_flat(model_train, params, optimizer=None, stopper=None, batch_size=None, batch_seed=None, save_position_history=True, model_validation=None, restore_best_position=True, prune_history=True, validate_log_prob_decomposition=True, progress_bar=True, progress_n_updates=20, track_keys=None)
Optimize the parameters of a Liesel Model.
Approximates maximum a posteriori (MAP) parameter estimates by minimizing the negative log posterior probability of the model. If you use batching, be aware that the batching functionality implemented here assumes a “flat” model structure. See the Notes below for details.
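To make the MAP objective concrete, here is a minimal NumPy sketch that is independent of Liesel. It assumes a linear model with unit noise scale and a hypothetical Normal(0, prior_scale²) prior on the coefficients, and it uses plain gradient descent rather than the Adam optimizer that optim_flat uses by default. Because this toy posterior is Gaussian, the result can be checked against the closed-form MAP solution.

```python
import numpy as np


def neg_log_posterior(beta, X, y, prior_scale):
    """Negative log posterior up to an additive constant.

    Assumes y ~ Normal(X @ beta, 1) and independent beta_j ~ Normal(0, prior_scale).
    """
    resid = y - X @ beta
    neg_log_lik = 0.5 * resid @ resid
    neg_log_prior = 0.5 * beta @ beta / prior_scale**2
    return neg_log_lik + neg_log_prior


def neg_log_posterior_grad(beta, X, y, prior_scale):
    """Gradient of neg_log_posterior with respect to beta."""
    resid = y - X @ beta
    return -(X.T @ resid) + beta / prior_scale**2


def map_gradient_descent(X, y, prior_scale=10.0, lr=1e-3, n_iter=5000):
    """Approximate the MAP estimate with plain gradient descent."""
    beta = np.zeros(X.shape[1])
    for _ in range(n_iter):
        beta = beta - lr * neg_log_posterior_grad(beta, X, y, prior_scale)
    return beta


# Simulated data similar to the example below (different random numbers).
rng = np.random.default_rng(42)
x = rng.normal(size=100)
y = 0.5 + 1.2 * x + rng.normal(size=100)
X = np.column_stack([np.ones_like(x), x])

beta_gd = map_gradient_descent(X, y)

# Closed-form MAP for this Gaussian model: (X'X + I / prior_scale^2)^{-1} X'y.
prior_scale = 10.0
beta_exact = np.linalg.solve(X.T @ X + np.eye(2) / prior_scale**2, X.T @ y)
```

The two estimates agree to numerical precision here. In optim_flat, the model's own log prior and log likelihood play the role of the hand-written terms above.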
Warning
This function is experimental. The API may change more quickly than in other parts of the library. Check your results carefully. If you encounter puzzling results, try disabling batching.
Params
- model_train
The Liesel model to optimize.
- params
List of parameter names to optimize. All other parameters of the model are held fixed.
- optimizer
An optimizer from the
optaxlibrary. IfNone,optax.adam(learning_rate=1e-2)is used.- stopper
A Stopper that carries information about the maximum number of iterations and early stopping.
- batch_size
The batch size. If None, batching is disabled and each optimization step uses the full model log probability. In this case, the result stores n_train == n_validation == 1 because no observation count is needed for likelihood rescaling.
- batch_seed
Batches are assembled randomly in each iteration. This is the seed used for the random shuffling.
- save_position_history
If True, the position history is saved to the results object.
- model_validation
If supplied, this model serves as a validation model, which means that early stopping is based on the validation loss evaluated using this model. If None, the training model is also used as the validation model, so training and validation losses are identical.
- restore_best_position
If True, the position with the lowest loss within the patience defined by the supplied Stopper is restored as the final position. If False, the position from the last iteration is used.
- prune_history
If True, the history is pruned to the length of the final iteration. This means that the history can be shorter than the maximum number of iterations defined by the supplied Stopper. If False, unused history entries are set to jax.numpy.nan if optimization stops early.
- validate_log_prob_decomposition
Whether to check that the model log probability is equal to the sum of the model log likelihood and model log prior before optimization starts. Disable this only for models whose log probability intentionally cannot be decomposed in this way.
- progress_bar
Whether to use a progress bar.
- progress_n_updates
How many times to update the progress bar in total.
- track_keys
List of position keys to track and include in the history.
- rtype:
OptimResult
- returns:
A dataclass of type OptimResult, giving access to the results.
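The interplay of restore_best_position and prune_history can be illustrated with a generic early-stopping loop. This is a hypothetical NumPy sketch, not Liesel's implementation: the helper early_stop, the precomputed losses that stand in for an optimization run, and the stopping rule (stop after patience consecutive iterations without an improvement larger than atol) are assumptions made for illustration.

```python
import numpy as np


def early_stop(losses, positions, patience, atol,
               restore_best_position=True, prune_history=True):
    """Hypothetical early-stopping loop over precomputed losses and positions.

    Stops once `patience` consecutive iterations fail to improve the best
    loss by more than `atol`.
    """
    max_iter = len(losses)
    history = np.full(max_iter, np.nan)  # entries after stopping stay NaN
    best_loss = np.inf
    n_no_improvement = 0
    stop_iter = max_iter - 1
    for i in range(max_iter):
        history[i] = losses[i]
        if losses[i] < best_loss - atol:
            # An improvement larger than atol resets the patience counter.
            best_loss = losses[i]
            n_no_improvement = 0
        else:
            n_no_improvement += 1
        if n_no_improvement >= patience:
            stop_iter = i
            break
    seen = history[: stop_iter + 1]
    if restore_best_position:
        # Position with the lowest loss among the iterations actually run.
        final_position = positions[int(np.argmin(seen))]
    else:
        final_position = positions[stop_iter]
    if prune_history:
        history = seen
    return final_position, history


# Hypothetical losses whose minimum (1.9) occurs at iteration 3.
losses = np.array([5.0, 3.0, 2.0, 1.9, 1.95, 2.1, 2.2, 2.3, 2.4, 2.5])
positions = np.arange(10) * 10.0  # stand-in "position" per iteration

best_pos, pruned = early_stop(losses, positions, patience=3, atol=0.5)
last_pos, padded = early_stop(losses, positions, patience=3, atol=0.5,
                              restore_best_position=False, prune_history=False)
```

With patience=3 and atol=0.5, the loop stops at iteration 5. The restored position is the one from iteration 3 (30.0), the pruned history has six entries, and the unpruned history ends in four NaN values.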
See also
history_to_df
A helper function to turn the OptimResult.history into a pandas.DataFrame, which is convenient for quickly plotting results.
Notes
If batch_size is None, batching is disabled. If you use batching, be aware that the batching functionality implemented here assumes a “flat” model structure. This means that this function assumes that, for all Var objects in your model, it is valid to index their values like this:
var_object.value[batch_indices, ...]
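The “flat” assumption means that every batched array stores the observations along its leading axis. The following NumPy sketch uses a plain dictionary of arrays in place of Liesel Var objects and a hypothetical make_batches helper to show how one seeded shuffle can be split into index batches and applied with value[batch_indices, ...]. Dropping a trailing partial batch is an assumption of this sketch, not documented behavior of optim_flat.

```python
import numpy as np


def make_batches(n_obs, batch_size, rng):
    """Hypothetical helper: shuffle observation indices and split them into batches.

    A trailing partial batch is dropped so that every batch has batch_size entries.
    """
    perm = rng.permutation(n_obs)
    n_batches = n_obs // batch_size
    return perm[: n_batches * batch_size].reshape(n_batches, batch_size)


# Observed values stored with the observation axis first ("flat" structure).
observed = {
    "y": np.arange(10.0),                 # shape (10,)
    "x": np.arange(20.0).reshape(10, 2),  # shape (10, 2), row i belongs to y[i]
}

rng = np.random.default_rng(0)  # plays the role of batch_seed
batches = make_batches(n_obs=10, batch_size=4, rng=rng)

# Indexing every observed array the same way keeps rows aligned across variables.
first = {name: value[batches[0], ...] for name, value in observed.items()}
```

A value whose leading axis does not index observations would be sliced incorrectly by this indexing, which is why the assumption matters.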
The batching functionality also assumes that all objects that should be batched are included as Var objects with Var.observed set to True. With batching enabled, the training loss rescales the batched log likelihood by n_train / batch_size. The validation loss rescales the validation log likelihood by n_train / n_validation when a separate validation model is supplied.
Examples
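The rescaling factors from the Notes can be checked numerically. This NumPy sketch uses hypothetical per-observation log-likelihood values. Each rescaled batch sum is a noisy estimate of the full-data log likelihood, and over one epoch of equally sized shuffled batches the estimates average to the exact full-data value. How optim_flat combines these terms with the log prior is not shown here.

```python
import numpy as np

rng = np.random.default_rng(1)
n_train, batch_size = 12, 4

# Hypothetical per-observation log-likelihood contributions.
loglik_obs = rng.normal(loc=-1.0, scale=0.3, size=n_train)
full_loglik = loglik_obs.sum()

# One epoch of shuffled batches; batch_size divides n_train here.
batches = rng.permutation(n_train).reshape(-1, batch_size)

# Rescale each batch's log likelihood by n_train / batch_size.
rescaled = [loglik_obs[idx].sum() * n_train / batch_size for idx in batches]

# The mean over the epoch equals the full-data log likelihood.
mean_rescaled = float(np.mean(rescaled))

# Validation: n_validation observations are put on the training scale
# with the factor n_train / n_validation.
n_validation = 6
loglik_val = rng.normal(loc=-1.0, scale=0.3, size=n_validation)
val_rescaled = loglik_val.sum() * n_train / n_validation
```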
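We show a minimal example. First, import the required modules, including tfd.
>>> import jax
>>> import jax.numpy as jnp
>>> import liesel.model as lsl
>>> import liesel.goose as gs
>>> import tensorflow_probability.substrates.jax.distributions as tfd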
Next, generate some data.
>>> key = jax.random.PRNGKey(42)
>>> key, subkey = jax.random.split(key)
>>> x = jax.random.normal(key, (100,))
>>> y = 0.5 + 1.2 * x + jax.random.normal(subkey, (100,))
Next, set up a linear model. For simplicity, we assume the scale to be fixed to the true value of 1.
>>> coef = lsl.Var.new_param(jnp.zeros(2), name="coef")
>>> xvar = lsl.Var.new_obs(jnp.c_[jnp.ones_like(x), x], name="x")
>>> mu = lsl.Var.new_calc(jnp.dot, xvar, coef, name="mu")
>>> ydist = lsl.Dist(tfd.Normal, loc=mu, scale=1.0)
>>> yvar = lsl.Var.new_obs(y, ydist, name="y")
>>> model = lsl.Model([yvar])
Now, we are ready to run the optimization.
>>> stopper = gs.Stopper(max_iter=1000, patience=10, atol=0.01)
>>> result = gs.optim_flat(model, params=["coef"], stopper=stopper)
>>> {name: jnp.round(value, 2) for name, value in result.position.items()}
{'coef': Array([0.38, 1.24], dtype=float32)}
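One way to check the result independently is sketched below, assuming that a parameter declared without a distribution, as coef is here, contributes nothing to the log prior. Under that assumption, and with the scale fixed at 1, the MAP estimate coincides with the ordinary least-squares estimate. The sketch uses only JAX, regenerates the same data, and solves the least-squares problem directly. Small differences from result.position are expected because early stopping with atol=0.01 may halt before full convergence.

```python
import jax
import jax.numpy as jnp

# Regenerate the data exactly as in the example above.
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
x = jax.random.normal(key, (100,))
y = 0.5 + 1.2 * x + jax.random.normal(subkey, (100,))

# Same design matrix as the "x" variable: intercept column plus x.
X = jnp.c_[jnp.ones_like(x), x]

# With a flat prior on coef and the scale fixed at 1, maximizing the log
# posterior is equivalent to minimizing the residual sum of squares.
coef_ols, _, _, _ = jnp.linalg.lstsq(X, y)
print(jnp.round(coef_ols, 2))
```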
We can now, for example, use result.model_state in EngineBuilder.set_initial_values() to implement a “warm start” of MCMC sampling.