API Reference
Splits each leaf of a pytree along one axis and transposes the result, so that the output is a list of pytrees, each with the same structure as the input.
It assumes that all leaves have the same size along the chosen axis.
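The split-and-transpose behaviour described above can be sketched with `jax.tree_util`. This is a minimal illustration under stated assumptions, not this library's implementation: the helper name `split_and_transpose` is hypothetical, and it assumes each output leaf is a single slice with the split axis removed (an "unstack"), rather than a chunk that keeps a size-1 axis.

```python
import jax
import jax.numpy as jnp


def split_and_transpose(tree, axis=0):
    """Hypothetical helper: split every leaf along `axis` and return a list of pytrees.

    Assumption: the i-th output pytree holds the i-th slice of every leaf,
    with the split axis removed.
    """
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    if not leaves:
        return []

    # The documented precondition: every leaf has the same size along `axis`.
    n = leaves[0].shape[axis]
    for leaf in leaves:
        if leaf.shape[axis] != n:
            raise ValueError(
                f"all leaves must have size {n} along axis {axis}, got shape {leaf.shape}"
            )

    # Move the split axis to the front so that indexing with [i] takes slice i.
    moved = [jnp.moveaxis(leaf, axis, 0) for leaf in leaves]

    # "Transpose": one pytree per slice, rebuilt with the original tree structure.
    return [
        jax.tree_util.tree_unflatten(treedef, [leaf[i] for leaf in moved])
        for i in range(n)
    ]


if __name__ == "__main__":
    tree = {"a": jnp.arange(6).reshape(3, 2), "b": (jnp.ones(3), jnp.zeros((3, 4)))}
    parts = split_and_transpose(tree)  # three pytrees, each shaped like `tree`
    print(len(parts), parts[0]["a"], parts[1]["b"][1].shape)
```

Moving the split axis to the front first keeps the per-slice code identical for any `axis`, including negative values.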