liesel.goose.pytree.as_strong_pytree

liesel.goose.pytree.as_strong_pytree(pytree)

Converts every leaf in a pytree to a strongly typed (non-weak) jax.numpy.DeviceArray, leaving the pytree structure unchanged.
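The conversion can be sketched as a `jax.tree_util.tree_map` over the leaves. This is a minimal sketch built only on public JAX calls (`jnp.asarray`, `jax.lax.convert_element_type`), not Liesel's actual implementation; the name `as_strong_pytree_sketch` is hypothetical. Current JAX releases use `jax.Array` where older ones used `DeviceArray`.

```python
import jax
import jax.numpy as jnp


def as_strong_pytree_sketch(pytree):
    """Hypothetical stand-in for as_strong_pytree: map every leaf to a
    strongly typed JAX array while keeping the pytree structure."""

    def to_strong(leaf):
        # Python scalars become weakly typed arrays at this step.
        arr = jnp.asarray(leaf)
        # lax.convert_element_type returns a non-weak array, even when
        # the dtype itself does not change.
        return jax.lax.convert_element_type(arr, arr.dtype)

    return jax.tree_util.tree_map(to_strong, pytree)


tree = {"a": 1.0, "b": [2, jnp.asarray(3.0)]}
strong = as_strong_pytree_sketch(tree)
```

Because the converted leaves no longer defer to the other operand's dtype, adding `strong["a"]` to a `bfloat16` array yields a `float32` result instead of `bfloat16`.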

See https://jax.readthedocs.io/en/latest/type_promotion.html for how JAX treats weakly typed values during type promotion.
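The difference between weak and strong types shows up when a value meets a lower-precision array. A minimal illustration using plain `jax.numpy`, assuming default JAX settings:

```python
import jax.numpy as jnp

x_bf16 = jnp.ones(3, dtype=jnp.bfloat16)

# Created from a Python float without a dtype: weakly typed.
weak = jnp.asarray(2.0)
# Created with an explicit dtype: strongly typed float32.
strong = jnp.asarray(2.0, dtype=jnp.float32)

# The weak value adopts the dtype of the other operand.
weak_result = weak * x_bf16      # dtype: bfloat16
# The strong float32 value takes part in promotion and widens the result.
strong_result = strong * x_bf16  # dtype: float32
```

Once the leaves are strong, their own dtypes decide the result of promotion rather than whichever operand they happen to meet.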

Return type

TypeVar(T)