concatenate_leaves()
- liesel.goose.pytree.concatenate_leaves(pytrees, axis=0)
Concatenates all leaves in the list of pytrees along the given axis.
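The function applies jax.numpy.concatenate() leaf by leaf, joining the corresponding leaves of all pytrees in the list, so the pytrees are expected to share the same tree structure. Because this is a concatenation rather than a stack, no new axis is created: each output leaf has the same number of dimensions as the input leaves, and only the size along `axis` grows.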
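The leaf-wise behaviour can be illustrated with a minimal sketch. It assumes the function maps `jnp.concatenate` over matching leaves with `jax.tree_util.tree_map`; this is an illustration, not necessarily liesel's actual implementation. The name `concatenate_leaves_sketch` and the example dictionaries are hypothetical.

```python
import jax
import jax.numpy as jnp


def concatenate_leaves_sketch(pytrees, axis=0):
    """Hypothetical re-implementation: concatenate corresponding leaves.

    Assumes all pytrees share the same tree structure; tree_map raises an
    error if the structures differ.
    """
    # tree_map passes the matching leaf from every pytree to the lambda.
    return jax.tree_util.tree_map(
        lambda *leaves: jnp.concatenate(leaves, axis=axis), *pytrees
    )


# Two hypothetical chunks of values with identical structure.
chunk1 = {"mu": jnp.zeros((2, 3)), "sigma": jnp.ones((2,))}
chunk2 = {"mu": jnp.ones((4, 3)), "sigma": jnp.zeros((4,))}

result = concatenate_leaves_sketch([chunk1, chunk2], axis=0)

# Leaves are joined along axis 0; the number of dimensions is unchanged.
print(result["mu"].shape)     # (6, 3)
print(result["sigma"].shape)  # (6,)
```

Going by the signature documented above, the equivalent library call would be `concatenate_leaves([chunk1, chunk2], axis=0)`. Using `jnp.stack` instead would add a new leading axis, giving shapes `(2, ...)` per leaf, which is exactly what this function avoids.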