DistRegBuilder.transform()


DistRegBuilder.transform(var, bijector=None, *args, **kwargs)

Transforms a variable by adding a new transformed variable as an input.

Creates a new variable on the unconstrained space ℝⁿ with the appropriate transformed distribution, turning the original variable into a weak variable without an associated distribution. The transformation is performed using TFP’s bijector classes.
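Why the new variable needs a *transformed* distribution follows from the standard change-of-variables formula. Write y for the new unconstrained variable, x = f(y) for the original variable and f for the forward bijector. The density that y must carry is then given below. This is the general textbook result, not a description of liesel's internals:

```latex
p_Y(y) = p_X\bigl(f(y)\bigr)\,\bigl|\det J_f(y)\bigr|,
\qquad
\log p_Y(y) = \log p_X\bigl(f(y)\bigr) + \log\bigl|\det J_f(y)\bigr|
```

In the scalar case the Jacobian determinant reduces to the derivative f'(y). For the Exp bijector, f'(y) = e^y, so the log-Jacobian term is simply y.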

The value of the parameter attribute is transferred to the transformed variable and set to False on the original variable. The observed and role attributes are set to their default values on the transformed variable and remain unchanged on the original variable.
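The attribute bookkeeping described above can be sketched with a stripped-down stand-in for a variable. MockVar and transfer_attributes are hypothetical names used only for illustration and are not part of liesel. The default values chosen for observed and role are also assumptions:

```python
from dataclasses import dataclass


# Hypothetical stand-in for a model variable. Only the three attributes
# discussed in the text are modelled; the defaults are assumptions.
@dataclass
class MockVar:
    name: str
    parameter: bool = False
    observed: bool = False
    role: str = ""


def transfer_attributes(original: MockVar) -> MockVar:
    """Sketch of the attribute bookkeeping, not liesel's implementation."""
    # The transformed variable starts with default observed/role values ...
    transformed = MockVar(name=f"{original.name}_transformed")
    # ... and takes over the parameter flag from the original variable,
    transformed.parameter = original.parameter
    # which is then switched off on the original variable.
    original.parameter = False
    # observed and role on the original variable are left untouched.
    return transformed


# "hyperparam" is an arbitrary, hypothetical role value for illustration.
scale = MockVar("scale", parameter=True, role="hyperparam")
log_scale = transfer_attributes(scale)
print(log_scale.parameter, scale.parameter)  # True False
```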

Parameters:
  • var (Var) – The variable to transform (and add to the graph).

  • bijector (Optional[type[Union[Bijector, Bijector]]]) – The bijector class used to map the new transformed variable to the original variable (forward transformation). If None, the distribution’s experimental default event space bijector (see the TFP documentation) is used. (default: None)

  • args – Positional arguments passed on to the bijector’s constructor (__init__).

  • kwargs – Keyword arguments passed on to the bijector’s constructor (__init__).
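The way bijector, args and kwargs fit together can be sketched as follows. The names make_bijector, default_factory and the toy Scale class are hypothetical. In TFP, a distribution's default would come from its experimental_default_event_space_bijector() method. This is an illustration of the documented behaviour, not liesel's code:

```python
from typing import Any, Callable, Optional


class Scale:
    """Hypothetical toy bijector with forward(y) = factor * y."""

    def __init__(self, factor: float = 1.0) -> None:
        self.factor = factor

    def forward(self, y: float) -> float:
        return self.factor * y


def make_bijector(
    default_factory: Optional[Callable[[], Any]],
    bijector: Optional[type] = None,
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Sketch: choose and instantiate the bijector as the parameters describe."""
    if bijector is None:
        # Fall back to the distribution's default event space bijector,
        # mirroring one of the RuntimeError cases listed under Raises.
        if default_factory is None:
            raise RuntimeError("no default event space bijector available")
        return default_factory()
    # Otherwise instantiate the given class, forwarding *args and **kwargs.
    return bijector(*args, **kwargs)


b = make_bijector(None, Scale, 3.0)  # positional argument forwarded
print(b.forward(2.0))  # 6.0
```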

Return type:

Var

Returns:

The new transformed variable which acts as an input to this variable.

Raises:

RuntimeError – If the variable is weak, if it has no TFP distribution, if bijector is None and the distribution has no default event space bijector, or if the local model for the variable cannot be built.

Examples

>>> import liesel.model as lsl
>>> import tensorflow_probability.substrates.jax.distributions as tfd
>>> import tensorflow_probability.substrates.jax.bijectors as tfb

Assume we have a variable scale that is constrained to be positive, and we want to include its log-transformation in the model. We first set up the parameter variable with its prior distribution:

>>> prior = lsl.Dist(tfd.HalfCauchy, loc=0.0, scale=25.0)
>>> scale = lsl.param(1.0, prior, name="scale")

Then we create a GraphBuilder, from which DistRegBuilder inherits this method, and use transform to transform the scale variable.

>>> gb = lsl.GraphBuilder()
>>> log_scale = gb.transform(scale, bijector=tfb.Exp)
>>> log_scale
Var(name="scale_transformed")

Now log_scale has a log probability, while scale no longer has one:

>>> log_scale.update().log_prob
Array(-3.6720574, dtype=float32)
>>> scale.update().log_prob
0.0
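The value -3.6720574 can be checked by hand. Since scale = 1.0, log_scale has the value log(1.0) = 0.0. The density of y = log(x) for a half-Cauchy x follows from the change-of-variables formula with forward bijector exp. The sketch below uses only the Python standard library and does not call TFP. The function names are illustrative:

```python
import math


def half_cauchy_log_prob(x: float, loc: float = 0.0, scale: float = 1.0) -> float:
    """Log density of a half-Cauchy distribution, valid for x >= loc."""
    z = (x - loc) / scale
    return math.log(2.0) - math.log(math.pi * scale) - math.log1p(z * z)


def log_scale_log_prob(y: float, loc: float = 0.0, scale: float = 1.0) -> float:
    """Log density of y = log(x) when x is half-Cauchy (forward bijector: exp)."""
    # Change of variables: log p_Y(y) = log p_X(exp(y)) + log|d exp(y)/dy|,
    # and log|d exp(y)/dy| = y.
    return half_cauchy_log_prob(math.exp(y), loc, scale) + y


# log_scale = log(1.0) = 0.0 with the prior HalfCauchy(loc=0.0, scale=25.0).
print(log_scale_log_prob(0.0, loc=0.0, scale=25.0))  # ≈ -3.67206
```

At y = 0 the log-Jacobian term vanishes, so the result equals the half-Cauchy log density at 1.0. It agrees with the float32 value above up to rounding.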