liesel.goose.pytree.stack_leaves

liesel.goose.pytree.stack_leaves(pytrees, axis=0)

Stacks the corresponding leaves of all pytrees in the list along the given axis.

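The function applies jax.numpy.stack() to each group of corresponding leaves, so the pytrees are expected to share the same structure. Because the stack operation creates a new axis, each leaf of the result has one more dimension than the input leaves, with size len(pytrees) at position axis.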
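The following is a minimal sketch of the behavior described above. It assumes that the function is equivalent to mapping jax.numpy.stack over corresponding leaves with jax.tree_util.tree_map. The name stack_leaves_sketch and the example pytrees are hypothetical and are not part of the Liesel API.

```python
import jax
import jax.numpy as jnp


def stack_leaves_sketch(pytrees, axis=0):
    # Hypothetical stand-in for liesel.goose.pytree.stack_leaves (assumed
    # equivalent): tree_map passes the corresponding leaves of every pytree
    # to the lambda, and jnp.stack joins them along a new axis.
    return jax.tree_util.tree_map(
        lambda *leaves: jnp.stack(leaves, axis=axis), *pytrees
    )


# Three hypothetical "samples" with identical structure:
# a dict holding a scalar "mu" and a length-2 vector "beta".
samples = [
    {"mu": jnp.array(0.0), "beta": jnp.array([1.0, 2.0])},
    {"mu": jnp.array(1.0), "beta": jnp.array([3.0, 4.0])},
    {"mu": jnp.array(2.0), "beta": jnp.array([5.0, 6.0])},
]

# Stacking along axis 0 puts the new axis (size 3) in front.
stacked = stack_leaves_sketch(samples, axis=0)
print(stacked["mu"].shape)    # (3,)
print(stacked["beta"].shape)  # (3, 2)

# Stacking along axis -1 puts the new axis at the end instead.
stacked_last = stack_leaves_sketch(samples, axis=-1)
print(stacked_last["beta"].shape)  # (2, 3)
```

The output keeps the dict structure of each input sample. Only the leaves gain the extra axis, which is why a list of per-iteration states can be turned into a single pytree of arrays this way.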