stack_leaves()
liesel.goose.pytree.stack_leaves(pytrees, axis=0)
Stacks the corresponding leaves of all pytrees in the list along the given axis.
The function applies jax.numpy.stack() to each group of corresponding leaves, so the stack operation creates a new axis.
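The behavior described above can be illustrated with a minimal sketch built on jax.tree_util.tree_map. This is not necessarily how liesel implements stack_leaves(). It assumes that all pytrees share the same tree structure and that corresponding leaves have the same shape. The name stack_leaves_sketch is hypothetical.

```python
import jax
import jax.numpy as jnp


def stack_leaves_sketch(pytrees, axis=0):
    """Hypothetical re-implementation for illustration only.

    Assumes every pytree in `pytrees` has the same structure and that
    corresponding leaves have identical shapes.
    """
    # tree_map walks the first pytree and passes the matching leaf of every
    # pytree to the lambda, which stacks them along a new axis.
    return jax.tree_util.tree_map(
        lambda *leaves: jnp.stack(leaves, axis=axis), *pytrees
    )


# Two pytrees with the same structure: a dict of arrays.
p1 = {"mu": jnp.array([1.0, 2.0]), "sigma": jnp.array(0.5)}
p2 = {"mu": jnp.array([3.0, 4.0]), "sigma": jnp.array(1.5)}

# Default axis=0: the new leading axis indexes the input pytrees.
stacked = stack_leaves_sketch([p1, p2])
print(stacked["mu"].shape)     # (2, 2)
print(stacked["sigma"].shape)  # (2,)

# axis=-1: the new axis is appended at the end of each leaf instead.
stacked_last = stack_leaves_sketch([p1, p2], axis=-1)
print(stacked_last["mu"].tolist())  # [[1.0, 3.0], [2.0, 4.0]]
```

Because jax.numpy.stack() always adds a dimension, a scalar leaf becomes a vector of length len(pytrees), and an array leaf gains one extra axis at the requested position.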