Source code for liesel.goose.pytree

"""
Pytree utilities.
"""

import dataclasses
from typing import TypeVar

import jax
import jax.numpy as jnp
import jax.tree_util

T = TypeVar("T")


[docs]def register_dataclass_as_pytree(cls): """Decorator for registering dataclasses as pytrees.""" if not dataclasses.is_dataclass(cls): raise TypeError(f"{cls} must be a dataclass") def flatten(cls_instance): # don't use dataclasses.asdict() here, because it converts nested dataclasses # to dicts recursively return jax.tree_util.tree_flatten(cls_instance.__dict__) def unflatten(aux_data, children): d = jax.tree_util.tree_unflatten(aux_data, children) rv = cls.__new__(cls) rv.__dict__.update(d) return rv jax.tree_util.register_pytree_node(cls, flatten, unflatten) return cls
[docs]def slice_leaves(pytree, idx): """ Performs the same slice operation on every leaf. ``idx`` can be constructed with :obj:`jax.numpy.s_` or :obj:`numpy.s_`, for example: >>> jnp.s_[0] 0 >>> jnp.s_[0:3, :, 2] (slice(0, 3, None), slice(None, None, None), 2) """ return jax.tree_util.tree_map(lambda x: x[idx], pytree)
[docs]def stack_leaves(pytrees, axis=0): """ Stacks all leaves in the list of pytrees along the given axis. The function applies :func:`jax.numpy.stack` to all leaves. The stack operation creates a new axis. """ return jax.tree_util.tree_map( lambda *xs: jnp.stack(xs, axis=axis), *pytrees, )
[docs]def concatenate_leaves(pytrees, axis=0): """ Concatenates all leaves in the list of pytrees along the given axis. The function applies :func:`jax.numpy.concatenate` to all leaves. The concatenate operation does not create a new axis. """ return jax.tree_util.tree_map(lambda *xs: jnp.concatenate(xs, axis=axis), *pytrees)
[docs]def split_leaves(pytree, indices_or_sections, axis=0): """ Splits all leaves in a pytree into multiple sub-arrays. The function applies :func:`jax.numpy.split` to all leaves. """ return jax.tree_util.tree_map( lambda x: jnp.split(x, indices_or_sections, axis), pytree )
[docs]def squeeze_leaves(pytree, axis=0): """ Squeezes all leaves in a pytree. The function applies :func:`jax.numpy.squeeze` to all leaves. """ return jax.tree_util.tree_map(lambda x: jnp.squeeze(x, axis), pytree)
[docs]def split_and_transpose(pytree, axis=0): """ Splits the leaves in a pytree along one axis and transposes the tree such that it's a list of pytrees. It assumes that all leaves have the same dimensionality along the chosen axis. """ dim = jax.tree_util.tree_leaves(pytree)[0].shape[axis] spytree = split_leaves(pytree, dim, axis=axis) td_inner = jax.tree_util.tree_structure([0 for _ in range(dim)]) td_outer = jax.tree_util.tree_structure(pytree) return jax.tree_util.tree_transpose(td_outer, td_inner, spytree)
[docs]def as_strong_pytree(pytree: T) -> T: """ Converts every leaf in a pytree to a non-weak :obj:`jax.numpy.DeviceArray`. See https://jax.readthedocs.io/en/latest/type_promotion.html. """ return jax.tree_util.tree_map( lambda x: jnp.asarray(x, dtype=jnp.asarray(x).dtype), pytree )