liesel.goose.pytree.slice_leaves
liesel.goose.pytree.slice_leaves(pytree, idx)
Performs the same slice operation on every leaf.
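To make "the same slice on every leaf" concrete, here is a minimal sketch of equivalent behavior built on jax.tree_util.tree_map. The helper slice_leaves_sketch is hypothetical and is not liesel's actual implementation. The leaf names and shapes (a leading chain axis and a draw axis, as in MCMC samples) are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


def slice_leaves_sketch(pytree, idx):
    """Hypothetical stand-in: index every leaf of `pytree` with the same `idx`."""
    # tree_map calls the function on each leaf and rebuilds the same structure.
    return jax.tree_util.tree_map(lambda leaf: leaf[idx], pytree)


# Assumed example pytree: two leaves sharing the leading axes (chain, draw).
samples = {
    "mu": jnp.zeros((4, 100)),
    "beta": jnp.zeros((4, 100, 3)),
}

# Keep only the first chain of every leaf.
first_chain = slice_leaves_sketch(samples, jnp.s_[0])
print(first_chain["mu"].shape)    # (100,)
print(first_chain["beta"].shape)  # (100, 3)
```

Any index works as long as it is valid for every leaf; here jnp.s_[0] only touches the leading axis that both leaves share, so leaves of different rank can be sliced together.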
idx can be constructed with jax.numpy.s_ or numpy.s_, for example:

>>> jnp.s_[0]
0
>>> jnp.s_[0:3, :, 2]
(slice(0, 3, None), slice(None, None, None), 2)
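Because numpy.s_ (and jax.numpy.s_) only builds the index object without slicing anything, the result can be stored once and applied later, which is what makes it usable as the idx argument. The plain-NumPy sketch below uses arbitrary array shapes chosen for illustration. It checks that indexing with a stored s_ object matches writing the slice inline, and that an index addressing only shared leading axes works on arrays of different rank.

```python
import numpy as np

# np.s_ only builds the index object; nothing is sliced yet.
idx = np.s_[0:3, :, 2]
assert idx == (slice(0, 3, None), slice(None, None, None), 2)

x = np.arange(4 * 5 * 6).reshape(4, 5, 6)

# Indexing with the stored object is identical to writing the slice inline.
same = np.array_equal(x[idx], x[0:3, :, 2])
print(same)          # True
print(x[idx].shape)  # (3, 5)

# An index that touches only the leading axes can be reused on arrays of
# different rank (illustrative shapes: chain x draw, with extra trailing dims).
drop_burnin = np.s_[:, 50:]
a = np.zeros((4, 100))
b = np.zeros((4, 100, 3))
print(a[drop_burnin].shape)  # (4, 50)
print(b[drop_burnin].shape)  # (4, 50, 3)
```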