liesel.goose.optim_flat(model_train, params, optimizer=None, stopper=Stopper(max_iter=10000, patience=10, atol=0.001, rtol=1e-12), batch_size=None, batch_seed=None, save_position_history=True, model_validation=None, restore_best_position=True, prune_history=True)

Optimize the parameters of a Liesel Model.

Approximates maximum a posteriori (MAP) parameter estimates by minimizing the negative log posterior probability of the model. If you use batching, be aware that the batching functionality implemented here assumes a “flat” model structure. See below for details.



Parameters

model_train: The Liesel model to optimize.

params: List of parameter names to optimize. All other parameters of the model are held fixed.

optimizer: An optimizer from the optax library. If None, optax.adam(learning_rate=1e-2) is used.

stopper: A Stopper that carries information about the maximum number of iterations and early stopping.

batch_size: The batch size. If None, the whole dataset is used for each optimization step.

batch_seed: Batches are assembled randomly in each iteration. This is the seed used for shuffling in this step.

save_position_history: If True, the position history is saved to the results object.

model_validation: If supplied, this model serves as a validation model, which means that early stopping is based on the negative log likelihood evaluated using the observed data in this model. If None, no early stopping is conducted.

restore_best_position: If True, the position with the lowest loss within the patience defined by the supplied Stopper is restored as the final position. If False, the last iteration’s position is used.

prune_history: If True, the history is pruned to the length of the final iteration. This means the history can be shorter than the maximum number of iterations defined by the supplied Stopper. If False, unused history entries are set to jax.numpy.nan if optimization stops early.

Returns

A dataclass of type OptimResult, giving access to the results.
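The interplay of early stopping, best-position restoration, and history pruning can be illustrated with a small NumPy sketch. This is not Liesel's implementation: the helper early_stopping_sketch and its simple patience rule (stop once the loss has not improved by more than atol for patience iterations) are assumptions made for illustration, and the actual Stopper may apply atol and rtol differently.

```python
import numpy as np


def early_stopping_sketch(val_losses, max_iter, patience, atol):
    """Hypothetical helper, not part of Liesel.

    Walks through precomputed validation losses and stops once the loss has
    not improved by more than ``atol`` for ``patience`` iterations.
    """
    history = np.full(max_iter, np.nan)  # preallocated; unused slots stay NaN
    best_loss, best_iter, last_iter = np.inf, -1, -1
    for i, loss in enumerate(val_losses[:max_iter]):
        history[i] = loss
        last_iter = i
        if loss < best_loss - atol:
            best_loss, best_iter = loss, i  # new best position
        elif i - best_iter >= patience:
            break  # no improvement within the patience window
    return history, best_iter, last_iter


val_losses = [5.0, 3.0, 2.0, 2.1, 2.2, 2.3, 2.4, 2.5]
history, best_iter, last_iter = early_stopping_sketch(
    val_losses, max_iter=10, patience=3, atol=1e-3
)

# Analogue of prune_history=True: drop the unused NaN entries.
pruned = history[: last_iter + 1]

# Analogue of restore_best_position=True: report the iteration with the
# lowest loss instead of the last one.
print(best_iter, last_iter, len(pruned))
```

In this sketch, the unpruned history keeps its full length of max_iter, with NaN in every slot after the stopping iteration.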

See also


A helper function to turn the OptimResult.history into a pandas.DataFrame, which is convenient for quickly plotting results.


If you use batching, be aware that the batching functionality implemented here assumes a “flat” model structure. This means that this function assumes that, for all Var objects in your model, it is valid to index their values like this:

var_object.value[batch_indices, ...]

The batching functionality also assumes that all objects that should be batched are included as Var objects with Var.observed set to True.
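To make the "flat" assumption concrete, the following NumPy sketch shows what indexing along the leading axis means for two observed arrays that share the same number of rows. It is an illustration only: the variable names are hypothetical, the seeded generator merely plays the role of batch_seed, and dropping the incomplete final batch is a choice made for this sketch, not a statement about Liesel's behavior.

```python
import numpy as np

rng = np.random.default_rng(0)  # stands in for batch_seed (assumption)
n, batch_size = 10, 4

# Two "observed" arrays with the same leading dimension n.
x = np.arange(2 * n, dtype=float).reshape(n, 2)  # x[i, 0] == 2 * i
y = np.arange(n, dtype=float)                    # y[i] == i

# Shuffle the row indices once, then cut them into full batches.
shuffled = rng.permutation(n)
n_full = n // batch_size
batches = shuffled[: n_full * batch_size].reshape(n_full, batch_size)

for batch_indices in batches:
    # The same indices are applied to the leading axis of every observed
    # array, so rows of x and y stay aligned within a batch.
    x_batch = x[batch_indices, ...]
    y_batch = y[batch_indices, ...]
    print(x_batch.shape, y_batch.shape)
```

If an observed variable stored its observations along a different axis, or had a different number of rows, this indexing would silently mix up or fail to align the observations, which is why the flat structure is required.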


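We show a minimal example. First, import the required modules.

>>> import jax
>>> import jax.numpy as jnp
>>> import liesel.model as lsl
>>> import liesel.goose as gs
>>> import tensorflow_probability.substrates.jax.distributions as tfd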

Next, generate some data.

>>> key = jax.random.PRNGKey(42)
>>> key, subkey = jax.random.split(key)
>>> x = jax.random.normal(key, (100,))
>>> y = 0.5 + 1.2 * x + jax.random.normal(subkey, (100,))

Next, set up a linear model. For simplicity, we assume the scale to be fixed to the true value of 1.

>>> coef = lsl.param(jnp.zeros(2), name="coef")
>>> xvar = lsl.obs(jnp.c_[jnp.ones_like(x), x], name="x")
>>> mu = lsl.Var(lsl.Calc(jnp.dot, xvar, coef), name="mu")
>>> ydist = lsl.Dist(tfd.Normal, loc=mu, scale=1.0)
>>> yvar = lsl.obs(y, ydist, name="y")
>>> model = lsl.GraphBuilder().add(yvar).build_model()

Now, we are ready to run the optimization.

>>> result = gs.optim_flat(model, params=["coef"])
>>> {name: jnp.round(value, 2) for name, value in result.position.items()}
{'coef': Array([0.52, 1.29], dtype=float32)}

We can now, for example, use result.model_state in EngineBuilder.set_initial_values() to implement a “warm start” of MCMC sampling.
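As a rough cross-check of what the optimization computes: the coef parameter in the example has no prior, so the negative log posterior reduces to the negative log-likelihood, and with the scale fixed at 1 its minimizer is the ordinary least squares solution. The NumPy sketch below illustrates this with plain gradient descent, which is swapped in here for the optax optimizer. The data are drawn with NumPy's generator, so the numbers will differ from the JAX example above.

```python
import numpy as np

rng = np.random.default_rng(42)
x = rng.normal(size=100)
y = 0.5 + 1.2 * x + rng.normal(size=100)
X = np.column_stack([np.ones_like(x), x])  # design matrix with intercept


def neg_log_post(coef):
    # Negative Gaussian log-likelihood with scale 1, up to an additive
    # constant. No prior term, since coef has no prior in the example.
    resid = y - X @ coef
    return 0.5 * np.sum(resid**2)


def grad(coef):
    # Gradient of neg_log_post with respect to coef.
    return -X.T @ (y - X @ coef)


coef = np.zeros(2)
learning_rate = 1e-3
for _ in range(2000):
    coef = coef - learning_rate * grad(coef)

# Closed-form least squares solution for comparison.
ols, *_ = np.linalg.lstsq(X, y, rcond=None)
print(np.round(coef, 2), np.round(ols, 2))
```

Both printed vectors should agree, which is the behavior one would expect from optim_flat on this model once the optimizer has converged.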