Var.transform()

Var.transform(bijector=None, *bijector_args, inference=None, name=None, **bijector_kwargs)

Transforms the variable, making it a function of a new variable.

Creates a new variable on the unconstrained space ℝⁿ with the appropriate transformed distribution, turning the original variable into a weak variable without an associated distribution. The transformation is performed using TFP’s bijector classes; see the TFP bijectors documentation.
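For reference, the density of the new variable follows from the standard change-of-variables formula. Writing f for the forward transformation of the bijector, so that the original variable is x = f(y) for the new variable y, the log density of the transformed variable is:

```latex
\log p_Y(y) = \log p_X\big(f(y)\big) + \log \left| \det \frac{\partial f(y)}{\partial y} \right|
```

For a scalar variable with f = exp, the Jacobian term reduces to y.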

Parameters:
  • bijector (type[Bijector] | Bijector | None) – The bijector used to map the new transformed variable to this variable (forward transformation). If None, the distribution’s experimental default event space bijector (see the TensorFlow Probability documentation) is used. If a bijector class is passed, it is instantiated with the arguments bijector_args and bijector_kwargs. If a bijector instance is passed, it is used directly. (default: None)

  • bijector_args – Positional arguments passed to the bijector’s __init__ method when bijector is a bijector class.

  • inference (Union[TypeAliasType, Literal['drop']]) – Additional information that can be used to set up inference algorithms for the new, transformed variable. If "drop", the inference information will be dropped from the original variable. The new variable will have no inference information. If None (default), the new variable will likewise have no inference information, but an error will be raised if there is inference information on the original variable. (default: None)

  • name (str | None) – Name for the new, transformed variable. If None (default), the new name will be <old_name>_transformed, where <old_name> is a placeholder for the current variable’s name. (default: None)

  • bijector_kwargs – Keyword arguments passed to the bijector’s __init__ method when bijector is a bijector class.
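The three ways of passing bijector, and the default naming rule, can be sketched in plain Python. This is a minimal illustration under stated assumptions, not liesel’s implementation: Bijector, Exp and Shift are hypothetical stand-ins for the TFP classes, and default_bijector is an assumed zero-argument callable standing in for the distribution’s default event space bijector.

```python
from typing import Any, Callable, Optional, Type, Union


class Bijector:
    """Hypothetical stand-in for tfb.Bijector (illustration only)."""


class Exp(Bijector):
    """Hypothetical stand-in for tfb.Exp."""


class Shift(Bijector):
    """Hypothetical stand-in for a bijector that takes a constructor argument."""

    def __init__(self, shift: float) -> None:
        self.shift = shift


def resolve_bijector(
    bijector: Union[Type[Bijector], Bijector, None],
    default_bijector: Callable[[], Bijector],
    *bijector_args: Any,
    **bijector_kwargs: Any,
) -> Bijector:
    # None: fall back to the (assumed) default event space bijector
    if bijector is None:
        return default_bijector()
    # A class: instantiate it with the extra positional and keyword arguments
    if isinstance(bijector, type):
        return bijector(*bijector_args, **bijector_kwargs)
    # An instance: use it as-is
    return bijector


def resolve_name(old_name: str, name: Optional[str]) -> str:
    # Without an explicit name, the new variable is called "<old_name>_transformed"
    return f"{old_name}_transformed" if name is None else name
```

For example, resolve_bijector(Shift, Exp, 2.0) returns a freshly constructed Shift(2.0), while an already constructed instance is returned unchanged.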

Return type:

Var

Returns:

The new transformed variable which acts as an input to this variable.

Raises:
  • RuntimeError – If the variable is weak or if the variable has no distribution.

  • ValueError – If the argument bijector is None but the distribution does not have a default event space bijector. Also raised if inference=None is passed to transform() but the variable’s inference attribute is not None.
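The conditions listed under Raises can be sketched as a single validation function. VarSketch and check_transform_args are hypothetical names, the order in which the checks run is an assumption, and has_default_bijector stands in for asking the distribution whether it provides a default event space bijector.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class VarSketch:
    """Hypothetical stand-in holding only the attributes the checks inspect."""

    name: str
    weak: bool = False
    dist_node: Any = None
    inference: Any = None


def check_transform_args(
    var: VarSketch,
    bijector: Any,
    has_default_bijector: bool,
    inference: Any = None,
) -> None:
    # Weak variables are computed from other nodes and cannot be transformed
    if var.weak:
        raise RuntimeError(f"{var.name} is weak")
    # Without a distribution there is nothing to transform
    if var.dist_node is None:
        raise RuntimeError(f"{var.name} has no distribution")
    # bijector=None needs a default event space bijector to fall back on
    if bijector is None and not has_default_bijector:
        raise ValueError("no bijector given and no default event space bijector")
    # Existing inference information must be handled explicitly
    if inference is None and var.inference is not None:
        raise ValueError(f"{var.name} has inference information")
```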

Notes

This is a simplified pseudo-code illustration of what this method does:

import liesel.model as lsl
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd

def transform(original_var: lsl.Var, bijector: tfb.Bijector) -> lsl.Var:
    # the Dist node stores a distribution class plus its input nodes
    original_dist = original_var.dist_node.distribution
    dist_inputs = original_var.dist_node.inputs
    dist_kwinputs = original_var.dist_node.kwinputs

    # transform the distribution: instantiate the original distribution
    # from its inputs, then map it through the inverse bijector
    def new_dist(*args, **kwargs):
        return tfd.TransformedDistribution(
            original_dist(*args, **kwargs), tfb.Invert(bijector)
        )

    # transform initial value
    new_value = bijector.inverse(original_var.value)

    # initialise the new variable
    new_var = lsl.Var(
        new_value,
        lsl.Dist(new_dist, *dist_inputs, **dist_kwinputs),
        name=f"{original_var.name}_transformed",
    )
    new_var.parameter = original_var.parameter

    # define the original variable as a function of the new variable
    original_var.value_node = lsl.Calc(bijector.forward, new_var)
    original_var.parameter = False

    # return the new variable
    return new_var

The value of the parameter attribute is transferred to the transformed variable and set to False on the original variable. The attributes observed and role take their default values on the transformed variable and remain unchanged on the original variable.

Examples

>>> import liesel.model as lsl
>>> import tensorflow_probability.substrates.jax.distributions as tfd
>>> import tensorflow_probability.substrates.jax.bijectors as tfb

Assume we have a variable scale that is constrained to be positive, and we want to include the log-transformation of this variable in the model. We first set up the parameter variable with its distribution:

>>> prior = lsl.Dist(tfd.HalfCauchy, loc=0.0, scale=25.0)
>>> scale = lsl.Var.new_param(1.0, prior, name="scale")

Then we transform the variable to the log scale:

>>> log_scale = scale.transform(tfb.Exp())
>>> log_scale
Var(name="scale_transformed")

Now log_scale has a log probability, and scale does not:

>>> log_scale.update().log_prob
Array(-3.6720574, dtype=float32)
>>> scale.update().log_prob
0.0
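The log probability of log_scale can be reproduced by hand with the change-of-variables formula. The sketch below uses only the Python standard library; half_cauchy_log_prob and log_scale_log_prob are hypothetical helper names, and the computation runs in float64, so it agrees with the float32 value above only up to rounding.

```python
import math


def half_cauchy_log_prob(x: float, loc: float, scale: float) -> float:
    # Half-Cauchy density for x >= loc: 2 / (pi * scale * (1 + ((x - loc) / scale) ** 2))
    z = (x - loc) / scale
    return math.log(2.0) - math.log(math.pi * scale) - math.log1p(z * z)


def log_scale_log_prob(y: float, loc: float = 0.0, scale: float = 25.0) -> float:
    # Forward bijector is exp: x = exp(y) and log|dx/dy| = y,
    # so log p_Y(y) = log p_X(exp(y)) + y
    return half_cauchy_log_prob(math.exp(y), loc, scale) + y


# Initial value of log_scale: the inverse of Exp applied to scale = 1.0
y0 = math.log(1.0)
print(round(log_scale_log_prob(y0), 5))  # -3.67206
```

At y = 0 the Jacobian term vanishes, so the value equals the Half-Cauchy log density of scale at 1.0.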