split_leaves()

liesel.goose.pytree.split_leaves(pytree, indices_or_sections, axis=0)

Splits all leaves in a pytree into multiple sub-arrays.

The function applies jax.numpy.split() to every leaf, passing indices_or_sections and axis through to it.
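The per-leaf behaviour can be illustrated with plain JAX. The following is a minimal sketch, not the Liesel implementation: the helper name `split_each_leaf` and the example pytree are hypothetical, and the exact container returned by `split_leaves()` is not stated here, so the sketch only shows what applying `jax.numpy.split()` to each leaf produces.

```python
import jax
import jax.numpy as jnp

# Hypothetical example pytree: a dict of arrays that share the leading axis.
pytree = {
    "a": jnp.arange(6.0),                 # shape (6,)
    "b": jnp.arange(12.0).reshape(6, 2),  # shape (6, 2)
}


def split_each_leaf(tree, indices_or_sections, axis=0):
    # Illustrative helper (an assumption, not part of Liesel):
    # apply jnp.split to every leaf, so each leaf becomes a list of sub-arrays.
    return jax.tree_util.tree_map(
        lambda leaf: jnp.split(leaf, indices_or_sections, axis=axis), tree
    )


# An integer asks for that many equally sized sections along the axis.
parts = split_each_leaf(pytree, 3)
shapes_a = [p.shape for p in parts["a"]]
shapes_b = [p.shape for p in parts["b"]]

# A list of indices gives the positions at which each leaf is cut.
uneven = split_each_leaf(pytree, [1, 4])
lengths_a = [p.shape[0] for p in uneven["a"]]
```

As with `jax.numpy.split()`, an integer `indices_or_sections` must divide the length of the chosen axis evenly for every leaf, while a list of indices allows sections of different sizes.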