optim_flat()

liesel.goose.optim_flat(model_train, params, optimizer=None, stopper=Stopper(max_iter=10000, patience=10, atol=0.001, rtol=1e-12), batch_size=None, batch_seed=None, save_position_history=True, model_validation=None, restore_best_position=True, prune_history=True)

Optimize the parameters of a Liesel Model.

Approximates maximum a posteriori (MAP) parameter estimates by minimizing the negative log posterior probability of the model. If you use batching, be aware that the batching functionality implemented here assumes a “flat” model structure. See the Notes section below for details.
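To make the objective concrete, here is a minimal NumPy sketch of MAP estimation by gradient-based minimization of the negative log posterior. It is not Liesel's implementation. It assumes a linear model with a fixed noise scale of 1 and, as an illustrative assumption, an independent Normal(0, 10²) prior on each coefficient (the example further down attaches no prior to coef). Instead of calling optax, it hand-codes the standard Adam update with a learning rate of 1e-2 to mirror the documented default optimizer; the helper names neg_log_posterior_and_grad and adam_map are hypothetical.

```python
import numpy as np


def neg_log_posterior_and_grad(beta, X, y, prior_sd=10.0):
    """Negative log posterior (up to an additive constant) and its gradient.

    Assumed model: y ~ Normal(X @ beta, 1) and, independently for each
    coefficient, beta_j ~ Normal(0, prior_sd**2).
    """
    resid = y - X @ beta
    nlp = 0.5 * resid @ resid + 0.5 * beta @ beta / prior_sd**2
    grad = -X.T @ resid + beta / prior_sd**2
    return nlp, grad


def adam_map(X, y, lr=1e-2, n_iter=5000, b1=0.9, b2=0.999, eps=1e-8):
    """Approximate the MAP estimate by running Adam on the negative log posterior."""
    beta = np.zeros(X.shape[1])
    m = np.zeros_like(beta)  # running mean of gradients
    v = np.zeros_like(beta)  # running mean of squared gradients
    for t in range(1, n_iter + 1):
        _, g = neg_log_posterior_and_grad(beta, X, y)
        m = b1 * m + (1 - b1) * g
        v = b2 * v + (1 - b2) * g**2
        m_hat = m / (1 - b1**t)  # bias correction of the first moment
        v_hat = v / (1 - b2**t)  # bias correction of the second moment
        beta = beta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return beta


# Synthetic data resembling the example further down (NumPy RNG, so the
# numbers differ from the JAX-generated data used there).
rng = np.random.default_rng(0)
x = rng.normal(size=100)
X = np.column_stack([np.ones_like(x), x])
y = 0.5 + 1.2 * x + rng.normal(size=100)

beta_adam = adam_map(X, y)
# For this Gaussian model the MAP estimate also has a closed form,
# which serves as a reference for the iterative solution.
beta_exact = np.linalg.solve(X.T @ X + np.eye(2) / 10.0**2, X.T @ y)
```

With a constant learning rate, Adam typically settles close to, but not exactly at, the minimizer, which is one reason a stopping rule with tolerances is needed.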

Params

model_train

The Liesel model to optimize.

params

List of parameter names to optimize. All other parameters of the model are held fixed.

optimizer

An optimizer from the optax library. If None, optax.adam(learning_rate=1e-2) is used.

stopper

A Stopper that sets the maximum number of iterations and controls early stopping (patience and tolerances).

batch_size

The batch size. If None, the whole dataset is used for each optimization step.

batch_seed

Batches are assembled randomly in each iteration; this is the seed used for the shuffling.

save_position_history

If True, the position history is saved to the results object.

model_validation

If supplied, this model serves as a validation model, which means that early stopping is based on the negative log likelihood evaluated using the observed data in this model. If None, no early stopping is conducted.

restore_best_position

If True, the position with the lowest loss within the patience defined by the supplied Stopper is restored as the final position. If False, the last iteration’s position is used.

prune_history

If True, the history is pruned to the length of the final iteration. This means that the history can be shorter than the maximum number of iterations defined by the supplied Stopper. If False, unused history entries are set to jax.numpy.nan if optimization stops early.
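The interplay of the stopper, restore_best_position and prune_history parameters can be illustrated with a simplified, hypothetical helper, early_stop_sketch, that replays a precomputed sequence of validation losses. This is a sketch under assumptions, not the Stopper's actual logic: it uses only a patience counter and a single absolute tolerance atol, whereas the real Stopper also takes rtol and may apply its criteria differently. NumPy's nan stands in for jax.numpy.nan.

```python
import numpy as np


def early_stop_sketch(val_losses, max_iter, patience, atol=1e-3,
                      restore_best=True, prune=True):
    """Hypothetical early-stopping loop over precomputed validation losses.

    Returns the index of the iteration whose position would be used as the
    final result, and the (possibly pruned) loss history.
    """
    history = np.full(max_iter, np.nan)  # unused entries stay NaN
    best_loss, best_iter = np.inf, -1
    last_iter = -1
    for i in range(max_iter):
        loss = val_losses[i]
        history[i] = loss
        last_iter = i
        # Count as an improvement only if the loss drops by more than atol.
        if loss < best_loss - atol:
            best_loss, best_iter = loss, i
        elif i - best_iter >= patience:
            break  # no improvement for `patience` iterations
    final_iter = best_iter if restore_best else last_iter
    if prune:
        history = history[: last_iter + 1]
    return final_iter, history


# Validation loss improves for four iterations, then starts rising.
losses = [5.0, 4.0, 3.0, 2.5] + [2.6 + 0.1 * k for k in range(16)]
final_iter, history = early_stop_sketch(losses, max_iter=20, patience=3)
```

Here the loop stops at iteration 6, three iterations after the best loss at iteration 3; with restore_best=True the position from iteration 3 would be kept, and with prune=True the history has 7 entries instead of 20.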

rtype:

OptimResult

returns:

A dataclass of type OptimResult, giving access to the results.

See also

history_to_df

A helper function that turns OptimResult.history into a pandas.DataFrame, which is convenient for quickly plotting results.

Notes

If you use batching, be aware that the batching functionality implemented here assumes a “flat” model structure. This means that this function assumes that, for all Var objects in your model, it is valid to index their values like this:

var_object.value[batch_indices, ...]

The batching functionality also assumes that all objects that should be batched are included as Var objects with Var.observed set to True.
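The batching scheme described above can be sketched in NumPy as follows: observation indices are shuffled with a seeded random number generator, split into batches, and every observed value is indexed along its first axis with the same batch indices, exactly the value[batch_indices, ...] pattern. The helper iterate_batches and the observed dictionary are hypothetical stand-ins for the observed Var objects, and this sketch keeps a smaller final batch, which may differ from what optim_flat does.

```python
import numpy as np


def iterate_batches(n_obs, batch_size, seed):
    """Yield index arrays covering all observations once, in random order."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_obs)
    for start in range(0, n_obs, batch_size):
        yield perm[start:start + batch_size]


# Hypothetical observed values; the leading axis indexes observations.
observed = {
    "x": np.column_stack([np.ones(10), np.arange(10.0)]),  # shape (10, 2)
    "y": np.arange(10.0) * 2.0,                             # shape (10,)
}

batches = []
for batch_indices in iterate_batches(n_obs=10, batch_size=4, seed=1):
    # Rows stay aligned across variables because every observed value
    # is indexed with the same batch_indices along its first axis.
    batches.append(
        {name: value[batch_indices, ...] for name, value in observed.items()}
    )
```

Parameters such as coef are not part of observed and are therefore never indexed, which is why only Var objects with Var.observed set to True should carry a leading observation axis.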

Examples

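We show a minimal example, which assumes that jax, jax.numpy (as jnp), liesel.model (as lsl) and liesel.goose (as gs) have already been imported. First, import the TensorFlow Probability distributions for JAX as tfd.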

>>> import tensorflow_probability.substrates.jax.distributions as tfd

Next, generate some data.

>>> key = jax.random.PRNGKey(42)
>>> key, subkey = jax.random.split(key)
>>> x = jax.random.normal(key, (100,))
>>> y = 0.5 + 1.2 * x + jax.random.normal(subkey, (100,))

Next, set up a linear model. For simplicity, we assume the scale to be fixed to the true value of 1.

>>> coef = lsl.param(jnp.zeros(2), name="coef")
>>> xvar = lsl.obs(jnp.c_[jnp.ones_like(x), x], name="x")
>>> mu = lsl.Var(lsl.Calc(jnp.dot, xvar, coef), name="mu")
>>> ydist = lsl.Dist(tfd.Normal, loc=mu, scale=1.0)
>>> yvar = lsl.obs(y, ydist, name="y")
>>> model = lsl.GraphBuilder().add(yvar).build_model()

Now, we are ready to run the optimization.

>>> result = gs.optim_flat(model, params=["coef"])
>>> {name: jnp.round(value, 2) for name, value in result.position.items()}
{'coef': Array([0.52, 1.29], dtype=float32)}

We can now, for example, use result.model_state in EngineBuilder.set_initial_values() to implement a “warm start” of MCMC sampling.