optim_flat()

liesel.goose.optim_flat(model_train, params, optimizer=None, stopper=None, batch_size=None, batch_seed=None, save_position_history=True, model_validation=None, restore_best_position=True, prune_history=True, validate_log_prob_decomposition=True, progress_bar=True, progress_n_updates=20, track_keys=None)

Optimize the parameters of a Liesel Model.

Approximates maximum a posteriori (MAP) parameter estimates by minimizing the negative log posterior probability of the model. If you use batching, be aware that the batching functionality implemented here assumes a “flat” model structure. See below for details.

Warning

This function is experimental. The API may change more quickly than in other parts of the library. Check your results carefully. If you encounter puzzling results, try disabling batching.

Parameters

model_train

The Liesel model to optimize.

params

List of parameter names to optimize. All other parameters of the model are held fixed.

optimizer

An optimizer from the optax library. If None, optax.adam(learning_rate=1e-2) is used.

stopper

A Stopper that carries information about the maximum number of iterations and early stopping.

batch_size

The batch size. If None, batching is disabled and each optimization step uses the full model log probability. In this case, the result stores n_train == n_validation == 1 because no observation count is needed for likelihood rescaling.

batch_seed

Batches are assembled randomly in each iteration. This is the seed used for the shuffling that assembles them.

save_position_history

If True, the position history is saved to the results object.

model_validation

If supplied, this model serves as a validation model, which means that early stopping is based on the validation loss evaluated using this model. If None, the training model is also used as the validation model, so training and validation losses are identical.

restore_best_position

If True, the position with the lowest loss within the patience defined by the supplied Stopper is restored as the final position. If False, the last iteration’s position is used.

prune_history

If True, the history is pruned to the length of the final iteration. This means that the history can be shorter than the maximum number of iterations defined by the supplied Stopper. If False, unused history entries are set to jax.numpy.nan if optimization stops early.

validate_log_prob_decomposition

Whether to check that the model log probability is equal to the sum of the model log likelihood and model log prior before optimization starts. Disable this only for models whose log probability intentionally cannot be decomposed in this way.

progress_bar

Whether to use a progress bar.

progress_n_updates

How many times to update the progress bar in total.

track_keys

List of position keys to track and include in the history.

Return type

OptimResult

Returns

A dataclass of type OptimResult, giving access to the results.
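
The effect of prune_history and restore_best_position can be illustrated with a small NumPy sketch. This is not the library's implementation: the loss values, the variable names, and the reading of "within the patience" as a look-back window over the last patience iterations are assumptions made for illustration only.

```python
import numpy as np

# Hypothetical run: at most 10 iterations are allowed, but early stopping
# ends the run after 6 of them.
max_iter = 10
n_done = 6
patience = 4  # assumed size of the look-back window

# The history is preallocated with NaN and filled step by step.
losses = np.full(max_iter, np.nan)
losses[:n_done] = [5.0, 3.0, 2.0, 2.5, 2.6, 2.7]

# prune_history=True: keep only the entries that were actually filled.
pruned = losses[:n_done]

# prune_history=False: keep the full-length array, unused entries stay NaN.
unpruned = losses

# restore_best_position=True: pick the iteration with the lowest loss among
# the last `patience` iterations (assumption, see the text above).
window_start = max(n_done - patience, 0)
best_iter = window_start + int(np.argmin(pruned[window_start:]))

# restore_best_position=False: simply use the last iteration.
last_iter = n_done - 1
```

In this sketch, the pruned history has length 6 instead of 10, and the restored position would come from iteration 2 rather than from the last iteration 5.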

See also

history_to_df

A helper function that turns OptimResult.history into a pandas.DataFrame, which is convenient for quickly plotting results.

Notes

If batch_size is None, batching is disabled. If you use batching, be aware that the batching functionality implemented here assumes a “flat” model structure. This means that this function assumes that, for all Var objects in your model, it is valid to index their values like this:

var_object.value[batch_indices, ...]

The batching functionality also assumes that all objects that should be batched are included as Var objects with Var.observed set to True. With batching enabled, the training loss rescales the batched log likelihood by n_train / batch_size. The validation loss rescales the validation log likelihood by n_train / n_validation when a separate validation model is supplied.
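
The flat indexing and the rescaling described above can be sketched in plain NumPy. This is an illustration under stated assumptions, not the code used by optim_flat: the helper log_lik, the simulated data, and the choice to shuffle once and cut the permutation into equally sized batches are all hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)  # plays the role of batch_seed
n_train, batch_size = 100, 20

# "Flat" structure: the first axis of every observed array indexes observations.
x = rng.normal(size=(n_train, 2))
y = rng.normal(size=(n_train,))
coef = np.array([0.5, 1.2])


def log_lik(x, y, coef):
    """Sum of Normal(x @ coef, 1) log densities over the given observations."""
    resid = y - x @ coef
    return np.sum(-0.5 * resid**2 - 0.5 * np.log(2.0 * np.pi))


# Shuffle the observation indices and cut them into batches (one pass).
permutation = rng.permutation(n_train)
batches = permutation.reshape(-1, batch_size)

# Index every observed array with value[batch_indices, ...] and rescale the
# batch log likelihood by n_train / batch_size.
rescaled = [
    (n_train / batch_size) * log_lik(x[idx, ...], y[idx, ...], coef)
    for idx in batches
]

# Full-data log likelihood for comparison.
full = log_lik(x, y, coef)
```

Averaged over the batches of one pass, the rescaled batch log likelihoods equal the full log likelihood, which is why the rescaling keeps the likelihood on the same scale as the prior. The same indexing would silently give wrong results for an array whose first axis does not index observations, which is what the "flat" assumption rules out.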

Examples

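We show a minimal example. First, import the required libraries.

>>> import jax
>>> import jax.numpy as jnp
>>> import liesel.model as lsl
>>> import liesel.goose as gs
>>> import tensorflow_probability.substrates.jax.distributions as tfd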

Next, generate some data.

>>> key = jax.random.PRNGKey(42)
>>> key, subkey = jax.random.split(key)
>>> x = jax.random.normal(key, (100,))
>>> y = 0.5 + 1.2 * x + jax.random.normal(subkey, (100,))

Next, set up a linear model. For simplicity, we assume the scale to be fixed to the true value of 1.

>>> coef = lsl.Var.new_param(jnp.zeros(2), name="coef")
>>> xvar = lsl.Var.new_obs(jnp.c_[jnp.ones_like(x), x], name="x")
>>> mu = lsl.Var.new_calc(jnp.dot, xvar, coef, name="mu")
>>> ydist = lsl.Dist(tfd.Normal, loc=mu, scale=1.0)
>>> yvar = lsl.Var.new_obs(y, ydist, name="y")
>>> model = lsl.Model([yvar])

Now, we are ready to run the optimization.

>>> stopper = gs.Stopper(max_iter=1000, patience=10, atol=0.01)
>>> result = gs.optim_flat(model, params=["coef"], stopper=stopper)
>>> {name: jnp.round(value, 2) for name, value in result.position.items()}
{'coef': Array([0.38, 1.24], dtype=float32)}
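
As a sanity check for a model like the one above: coef has no prior distribution and the scale is fixed at 1, so the negative log posterior is, up to an additive constant, half the residual sum of squares, and its minimizer is the least-squares solution. The NumPy sketch below shows this reference computation. It draws its own data with NumPy's random generator, so its numbers will not match the JAX-based output above; all names in it are chosen for illustration.

```python
import numpy as np

# Simulate data with the same structure as in the example (different draws).
rng = np.random.default_rng(42)
x = rng.normal(size=100)
y = 0.5 + 1.2 * x + rng.normal(size=100)
X = np.column_stack([np.ones_like(x), x])


def neg_log_post(coef):
    """Negative log posterior up to an additive constant (flat prior, scale 1)."""
    resid = y - X @ coef
    return 0.5 * np.sum(resid**2)


# Closed-form reference: the least-squares solution.
coef_ols, *_ = np.linalg.lstsq(X, y, rcond=None)

# The gradient of the negative log posterior vanishes at the optimum.
grad = -X.T @ (y - X @ coef_ols)
```

An optimizer run with a loose tolerance, such as atol=0.01 above, can stop slightly short of this reference value, so small differences between the two are expected.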

We can now, for example, use result.model_state in EngineBuilder.set_initial_values() to implement a “warm start” of MCMC sampling.