slice_leaves()
- liesel.goose.pytree.slice_leaves(pytree, idx)
Performs the same slice operation on every leaf.
idx can be constructed with jax.numpy.s_ or numpy.s_, for example:
>>> jnp.s_[0]
0
>>> jnp.s_[0:3, :, 2]
(slice(0, 3, None), slice(None, None, None), 2)
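To make the effect on a whole pytree concrete, here is a minimal sketch. It does not call liesel itself: `slice_leaves_sketch` is a hypothetical stand-in built on `jax.tree_util.tree_map`, assuming that "the same slice operation on every leaf" means `leaf[idx]`. The dictionary `samples` and its leaf names are made up for illustration.

```python
import jax
import jax.numpy as jnp


def slice_leaves_sketch(pytree, idx):
    """Hypothetical stand-in: index every leaf of the pytree with idx."""
    return jax.tree_util.tree_map(lambda leaf: leaf[idx], pytree)


# Illustrative pytree: both leaves share their first two axes.
samples = {
    "beta": jnp.arange(24.0).reshape(2, 3, 4),
    "sigma": jnp.ones((2, 3)),
}

# Keep all entries along axis 0 and the first two entries along axis 1.
idx = jnp.s_[:, 0:2]
sliced = slice_leaves_sketch(samples, idx)

# Collect the resulting leaf shapes for inspection.
shapes = {name: leaf.shape for name, leaf in sliced.items()}
print(shapes)
```

Because the same `idx` is applied to every leaf, it has to be valid for all of them: an index that addresses a third axis, such as `jnp.s_[0:3, :, 2]`, would fail on the two-dimensional `sigma` leaf in this sketch.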