liesel.goose.pytree.squeeze_leaves

liesel.goose.pytree.squeeze_leaves(pytree, axis=0)

Squeezes all leaves in a pytree.

The function applies jax.numpy.squeeze() to every leaf of the pytree, using the given axis (default: 0).
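The behavior described above can be approximated with plain JAX. The following is a minimal sketch, not liesel's actual implementation: `squeeze_leaves_sketch` is a hypothetical stand-in that assumes the `axis` argument is forwarded unchanged to `jax.numpy.squeeze()` for each leaf, and the example tree is made up for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp


def squeeze_leaves_sketch(pytree, axis=0):
    """Hypothetical stand-in: squeeze the given axis of every leaf."""
    # tree_map applies the function to each leaf and keeps the tree structure.
    return jax.tree_util.tree_map(partial(jnp.squeeze, axis=axis), pytree)


# Illustrative pytree whose leaves all have a leading axis of length one.
tree = {"a": jnp.zeros((1, 3)), "b": {"c": jnp.ones((1, 2, 4))}}

squeezed = squeeze_leaves_sketch(tree)

print(squeezed["a"].shape)       # (3,)
print(squeezed["b"]["c"].shape)  # (2, 4)
```

In this sketch the nesting of the dictionary is preserved and only the leaf shapes change.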