liesel.goose.pytree

Pytree utilities.

Functions

as_strong_pytree(pytree)

Converts every leaf in a pytree to a strongly typed (non-weak) jax.numpy.DeviceArray.
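
To illustrate what "non-weak" means here, the sketch below uses plain JAX rather than liesel itself. The helper name `to_strong` is hypothetical and this is not necessarily how the library implements the conversion. It assumes that passing an explicit dtype to `jnp.asarray` yields a strongly typed array.

```python
import jax
import jax.numpy as jnp


def to_strong(tree):
    """Hypothetical helper: rebuild each leaf with its own dtype made explicit."""

    def convert(leaf):
        leaf = jnp.asarray(leaf)  # Python scalars become weakly typed arrays
        return jnp.asarray(leaf, dtype=leaf.dtype)  # explicit dtype drops the weak flag

    return jax.tree_util.tree_map(convert, tree)


tree = {"step_size": 0.1, "position": jnp.zeros(3)}
strong = to_strong(tree)

print(jnp.asarray(0.1).weak_type)
print(strong["step_size"].weak_type)
```

Weakly typed values take on the dtype of whatever they are combined with, so making leaves strongly typed keeps dtypes stable across updates of a state pytree.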

concatenate_leaves(pytrees[, axis])

Concatenates all leaves in the list of pytrees along the given axis.
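
The idea can be sketched with plain JAX, assuming all pytrees in the list share the same structure. The helper name `concat_leaves` is hypothetical and only illustrates the technique, not the library's code.

```python
import jax
import jax.numpy as jnp


def concat_leaves(trees, axis=0):
    """Hypothetical helper: concatenate matching leaves across a list of pytrees."""
    return jax.tree_util.tree_map(
        lambda *leaves: jnp.concatenate(leaves, axis=axis), *trees
    )


# Two batches with the same structure but different lengths along axis 0.
first = {"beta": jnp.zeros((2, 3)), "sigma": jnp.ones((2,))}
second = {"beta": jnp.ones((4, 3)), "sigma": jnp.zeros((4,))}

joined = concat_leaves([first, second], axis=0)
print(joined["beta"].shape, joined["sigma"].shape)
```

Unlike stacking, concatenation joins along an existing axis, so the leaves may differ in length along that axis but must agree on all others.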

register_dataclass_as_pytree(cls)

Decorator for registering dataclasses as pytrees.
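
As a rough sketch of what such a decorator might do, the example below registers a dataclass with `jax.tree_util.register_pytree_node`, treating every field as a child. The decorator `register_as_pytree` and the class `State` are hypothetical names, and the assumption that all fields are leaves passed positionally to the constructor may not match the library's behavior.

```python
import dataclasses
from typing import Any

import jax
import jax.numpy as jnp


def register_as_pytree(cls):
    """Hypothetical decorator: expose all dataclass fields as pytree children."""
    names = [field.name for field in dataclasses.fields(cls)]

    def flatten(obj):
        # Children in field order, no auxiliary data.
        return tuple(getattr(obj, name) for name in names), None

    def unflatten(aux, children):
        return cls(*children)

    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    return cls


@register_as_pytree
@dataclasses.dataclass
class State:
    position: Any
    momentum: Any


state = State(position=jnp.ones(2), momentum=jnp.zeros(2))
doubled = jax.tree_util.tree_map(lambda x: 2 * x, state)
print(type(doubled).__name__, doubled.position)
```

Once registered, instances pass through `tree_map` and other JAX transformations without being flattened by hand.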

slice_leaves(pytree, idx)

Performs the same slice operation on every leaf.
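
In plain JAX, the same idea amounts to a `tree_map` that indexes each leaf. The helper name `slice_each` and the example data are hypothetical.

```python
import jax
import jax.numpy as jnp


def slice_each(tree, idx):
    """Hypothetical helper: apply the same index expression to every leaf."""
    return jax.tree_util.tree_map(lambda leaf: leaf[idx], tree)


# Four draws of a vector parameter and a scalar parameter.
chain = {"beta": jnp.arange(12.0).reshape(4, 3), "sigma": jnp.arange(4.0)}

first_two = slice_each(chain, slice(0, 2))
print(first_two["beta"].shape, first_two["sigma"].shape)
```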

split_and_transpose(pytree[, axis])

Splits the leaves in a pytree along one axis and transposes the tree so that the result is a list of pytrees.
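
The "transpose" turns one pytree of long arrays into a list of pytrees of short arrays. The sketch below uses plain JAX with a hypothetical helper `split_transpose`. It assumes every leaf has the same length along the chosen axis and that each piece keeps that axis with length one, which may differ from what the library returns.

```python
import jax
import jax.numpy as jnp


def split_transpose(tree, axis=0):
    """Hypothetical helper: one pytree of arrays -> list of pytrees of pieces."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    n = leaves[0].shape[axis]
    # Split every leaf into n pieces; each piece keeps `axis` with length 1.
    pieces = [jnp.split(leaf, n, axis=axis) for leaf in leaves]
    # Regroup: the i-th piece of every leaf forms the i-th output pytree.
    return [
        jax.tree_util.tree_unflatten(treedef, [p[i] for p in pieces])
        for i in range(n)
    ]


tree = {"x": jnp.arange(6.0).reshape(3, 2), "y": jnp.arange(3.0)}
rows = split_transpose(tree, axis=0)
print(len(rows), rows[0]["x"].shape, rows[0]["y"].shape)
```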

split_leaves(pytree, indices_or_sections[, axis])

Splits all leaves in a pytree into multiple sub-arrays.

squeeze_leaves(pytree[, axis])

Squeezes all leaves in a pytree.

stack_leaves(pytrees[, axis])

Stacks all leaves in the list of pytrees along the given axis.
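
A minimal sketch of the stacking idea in plain JAX, assuming all pytrees share the same structure and leaf shapes. The helper name `stack_each` and the example draws are hypothetical.

```python
import jax
import jax.numpy as jnp


def stack_each(trees, axis=0):
    """Hypothetical helper: stack matching leaves, creating a new axis."""
    return jax.tree_util.tree_map(
        lambda *leaves: jnp.stack(leaves, axis=axis), *trees
    )


# Five per-iteration states collected in a Python list.
draws = [
    {"beta": jnp.full((2,), float(i)), "sigma": jnp.asarray(float(i))}
    for i in range(5)
]

stacked = stack_each(draws, axis=0)
print(stacked["beta"].shape, stacked["sigma"].shape)
```

Stacking inserts a new axis, so scalar leaves become vectors and vector leaves become matrices.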