liesel.goose.pytree.slice_leaves

liesel.goose.pytree.slice_leaves(pytree, idx)

Performs the same slice operation, given by idx, on every leaf of pytree.

idx can be constructed with jax.numpy.s_ or numpy.s_, for example:

>>> jnp.s_[0]
0
>>> jnp.s_[0:3, :, 2]
(slice(0, 3, None), slice(None, None, None), 2)
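To illustrate the behaviour, here is a minimal sketch. It assumes that slicing every leaf amounts to mapping `leaf[idx]` over the pytree with `jax.tree_util.tree_map`. The helper `slice_leaves_sketch` and the example pytree are hypothetical and are not taken from Liesel's source.

```python
import jax
import jax.numpy as jnp


def slice_leaves_sketch(pytree, idx):
    # Hypothetical stand-in for slice_leaves: index every leaf with the same idx.
    return jax.tree_util.tree_map(lambda leaf: leaf[idx], pytree)


# Example pytree: a dict whose leaves share the first and last axis sizes.
pytree = {"a": jnp.zeros((4, 3, 5)), "b": jnp.ones((4, 2, 5))}

# Same index as in the example above: first three entries of axis 0,
# all of axis 1, position 2 of axis 2.
out = slice_leaves_sketch(pytree, jnp.s_[0:3, :, 2])

print(out["a"].shape)  # (3, 3)
print(out["b"].shape)  # (3, 2)
```

The pytree structure (here the dict keys) is preserved, and only the leaves change shape. Every leaf must support the given index. With `jnp.s_[0:3, :, 2]`, each leaf needs at least three dimensions.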