liesel.goose.pytree.split_leaves
liesel.goose.pytree.split_leaves(pytree, indices_or_sections, axis=0)
Splits all leaves in a pytree into multiple sub-arrays. The function applies jax.numpy.split() to all leaves.
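As a rough illustration of the idea, the sketch below applies `jax.numpy.split()` to every leaf and regroups the chunks. It is a minimal sketch, not liesel's implementation. The function name `split_leaves_sketch` is hypothetical, and so is its return structure: it returns a list of pytrees, one per chunk, each with the same structure as the input. The real `split_leaves` may organize its output differently.

```python
import jax
import jax.numpy as jnp


def split_leaves_sketch(pytree, indices_or_sections, axis=0):
    """Hypothetical stand-in for split_leaves (not the liesel implementation).

    Assumption: the result is a list of pytrees, where the i-th pytree holds
    the i-th chunk of every leaf.
    """
    leaves, treedef = jax.tree_util.tree_flatten(pytree)
    # Split every leaf along `axis` with the same indices_or_sections.
    split = [jnp.split(leaf, indices_or_sections, axis=axis) for leaf in leaves]
    # All leaves yield the same number of chunks; an empty pytree yields none.
    n_chunks = len(split[0]) if split else 0
    # Regroup: rebuild one pytree per chunk index from the original treedef.
    return [
        jax.tree_util.tree_unflatten(treedef, [chunks[i] for chunks in split])
        for i in range(n_chunks)
    ]


tree = {"a": jnp.arange(6.0), "b": jnp.ones((6, 2))}
parts = split_leaves_sketch(tree, 3)  # three equal chunks along axis 0
```

Here each of the three resulting pytrees holds a length-2 slice of `"a"` and a `(2, 2)` slice of `"b"`. Because `jax.numpy.split()` is applied leaf by leaf, every leaf must be divisible in the same way along `axis`. Passing a list such as `[2]` instead of an integer splits at the given indices.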