liesel.goose.pytree.stack_leaves
- liesel.goose.pytree.stack_leaves(pytrees, axis=0)
Stacks the leaves of all pytrees in the list along the given axis.
The function applies jax.numpy.stack() to the corresponding leaves of the pytrees. As with any stack operation, this inserts a new axis, of length len(pytrees), at position axis.
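To make the behavior concrete, here is a minimal sketch of the same leaf-wise stacking written directly with jax.tree_util.tree_map and jax.numpy.stack. The name stack_leaves_sketch is hypothetical and this is not claimed to be how liesel implements stack_leaves. It assumes that all pytrees share the same tree structure and that corresponding leaves have the same shape, which jax.numpy.stack requires.

```python
import jax
import jax.numpy as jnp


def stack_leaves_sketch(pytrees, axis=0):
    """Hypothetical illustration of leaf-wise stacking (not liesel's code).

    Assumes every pytree in `pytrees` has the same structure and that
    corresponding leaves have equal shapes, as jnp.stack requires.
    """
    # tree_map walks the first pytree and passes the matching leaf from
    # every other pytree, so `leaves` holds one leaf per input pytree.
    return jax.tree_util.tree_map(
        lambda *leaves: jnp.stack(leaves, axis=axis), *pytrees
    )


# Two pytrees with identical structure (e.g. two draws of some parameters).
samples = [
    {"beta": jnp.zeros(3), "sigma": jnp.array(1.0)},
    {"beta": jnp.ones(3), "sigma": jnp.array(2.0)},
]

# Default axis=0: a new leading axis of length len(samples) == 2.
stacked = stack_leaves_sketch(samples)
print(stacked["beta"].shape)   # (2, 3)
print(stacked["sigma"].shape)  # (2,)

# axis=-1: the new axis is appended at the end of each leaf instead.
stacked_last = stack_leaves_sketch(samples, axis=-1)
print(stacked_last["beta"].shape)  # (3, 2)
```

Note that the new axis is inserted into every leaf, so an axis that is valid for one leaf may be out of range for a lower-dimensional leaf. For example, axis=1 would fail for the scalar leaf "sigma", while a negative axis such as -1 works for leaves of any rank.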