squeeze_leaves()
liesel.goose.pytree.squeeze_leaves(pytree, axis=0)
Squeezes all leaves in a pytree.
The function applies jax.numpy.squeeze() to all leaves.
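The following is a minimal sketch of the behavior described above, built on `jax.tree_util.tree_map`. It is an illustration, not Liesel's actual source. The helper name `squeeze_leaves_sketch` and the example pytree `samples` are hypothetical. The sketch assumes that `axis` is forwarded unchanged to `jax.numpy.squeeze()` for every leaf.

```python
import jax
import jax.numpy as jnp


def squeeze_leaves_sketch(pytree, axis=0):
    """Hypothetical stand-in for squeeze_leaves: applies jnp.squeeze to every leaf.

    Assumption: `axis` is passed straight through to jnp.squeeze.
    """
    return jax.tree_util.tree_map(lambda leaf: jnp.squeeze(leaf, axis=axis), pytree)


# Hypothetical pytree whose leaves all have a leading axis of size 1.
samples = {
    "mu": jnp.zeros((1, 100)),
    "sigma": jnp.ones((1, 100, 3)),
}

squeezed = squeeze_leaves_sketch(samples, axis=0)

# The pytree structure (dict keys) is preserved; only the leaf shapes change.
print({name: leaf.shape for name, leaf in squeezed.items()})
```

Because `jax.numpy.squeeze()` is applied to each leaf, every leaf must have size 1 along `axis`. Otherwise the call raises an error. In this sketch, passing `axis=None` removes every size-1 axis instead.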