liesel.goose.pytree.concatenate_leaves

liesel.goose.pytree.concatenate_leaves(pytrees, axis=0)

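Concatenates the corresponding leaves of the pytrees in the list along the given axis. The pytrees must share the same tree structure, and leaves at the same position must have matching shapes in every dimension except axis.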

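The function applies jax.numpy.concatenate() to each group of corresponding leaves. Because concatenation joins arrays along an existing axis, no new axis is created: each resulting leaf has the same number of dimensions as the input leaves.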