Var.transform()
- Var.transform(bijector=None, *bijector_args, inference=None, name=None, **bijector_kwargs)
Transforms the variable, making it a function of a new variable.
Creates a new variable on the unconstrained space $\mathbb{R}^n$ with the appropriate transformed distribution, turning the original variable into a weak variable without an associated distribution. The transformation is performed using TFP’s bijector classes; see the TFP bijectors documentation.
- Parameters:
  - bijector (type[Bijector] | Bijector | None) – The bijector used to map the new, transformed variable to this variable (forward transformation). If `None`, the distribution's experimental default event space bijector (`experimental_default_event_space_bijector()`, see the TensorFlow Probability documentation) is used. If a bijector class is passed, it is instantiated with the arguments `bijector_args` and `bijector_kwargs`. If a bijector instance is passed, it is used directly. (default: `None`)
  - bijector_args – The positional arguments passed on to the bijector's `__init__` method.
  - inference (Union[TypeAliasType, Literal['drop']]) – Additional information that can be used to set up inference algorithms for the new, transformed variable. If `"drop"`, the inference information is dropped from the original variable, and the new variable has no inference information. If `None` (default), the new variable likewise has no inference information, but an error is raised if the original variable has inference information. (default: `None`)
  - name (str | None) – Name for the new, transformed variable. If `None` (default), the new name is `<old_name>_transformed`, where `<old_name>` is a placeholder for the current variable’s name. (default: `None`)
  - bijector_kwargs – The keyword arguments passed on to the bijector's `__init__` method.
- Return type:
  Var
- Returns:
  The new, transformed variable, which acts as an input to this variable.
- Raises:
  - RuntimeError – If the variable is weak or has no distribution.
  - ValueError – If the argument `bijector` is `None` but the distribution does not have a default event space bijector. Also raised if `transform()` is called with `inference=None` but the variable's `inference` attribute is not `None`.
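The rules for the `bijector` argument listed above can be summarised in a small helper. This is a hypothetical sketch, not Liesel's implementation: `resolve_bijector` is an illustrative name, and it assumes only that `dist` is a TFP distribution instance exposing `experimental_default_event_space_bijector()`. Because we are not certain whether a distribution without a default bijector returns `None` or raises `NotImplementedError`, the sketch treats both as "no default bijector".

```python
import inspect
from typing import Any


def resolve_bijector(
    dist: Any, bijector: Any = None, *bijector_args: Any, **bijector_kwargs: Any
) -> Any:
    """Hypothetical helper mirroring the documented rules for ``bijector``.

    ``dist`` is assumed to be a TFP distribution *instance*.
    """
    if bijector is None:
        # No bijector given: fall back to the distribution's default event
        # space bijector. Assumption: a distribution without one either
        # returns None or raises NotImplementedError.
        try:
            default = dist.experimental_default_event_space_bijector()
        except NotImplementedError:
            default = None
        if default is None:
            raise ValueError("The distribution has no default event space bijector.")
        return default

    if inspect.isclass(bijector):
        # A bijector class: instantiate it with the extra arguments.
        return bijector(*bijector_args, **bijector_kwargs)

    # Already a bijector instance: use it as-is (this sketch ignores any
    # extra arguments in that case).
    return bijector
```

With TFP, `resolve_bijector(tfd.HalfCauchy(0.0, 25.0), tfb.Exp)` and `resolve_bijector(tfd.HalfCauchy(0.0, 25.0), tfb.Exp())` would both yield an `Exp` bijector instance, while `resolve_bijector(tfd.HalfCauchy(0.0, 25.0))` falls back to the distribution's default event space bijector.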
Notes
This is a simplified pseudo-code illustration of what this method does:
```python
import liesel.model as lsl
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd


def transform(original_var: lsl.Var, bijector: tfb.Bijector):
    original_dist = original_var.dist_node.distribution
    dist_inputs = original_var.dist_node.inputs

    # transform the distribution
    new_dist = tfd.TransformedDistribution(
        original_dist, tfb.Invert(bijector)
    )

    # transform initial value
    new_value = bijector.inverse(original_var.value)

    # initialise the new variable
    new_var = lsl.Var(
        new_value,
        lsl.Dist(new_dist, *dist_inputs),
        name=f"{original_var.name}_transformed",
    )
    new_var.parameter = original_var.parameter

    # define the original variable as a function of the new variable
    original_var.value_node = lsl.Calc(bijector.forward, new_var)
    original_var.parameter = False

    # return the new variable
    return new_var
```
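The pseudo-code builds the new distribution as `tfd.TransformedDistribution(original_dist, tfb.Invert(bijector))`. Writing $f$ for the forward transformation of `bijector`, so that the original variable is $x = f(y)$ for the new variable $y \in \mathbb{R}^n$, the resulting log-density follows the standard change-of-variables formula:

```latex
\log p_Y(y) = \log p_X\bigl(f(y)\bigr) + \log\bigl|\det J_f(y)\bigr|
```

The second term is the forward log-determinant of the Jacobian of `bijector` (TFP's `forward_log_det_jacobian`). For the scalar `tfb.Exp()` bijector used in the examples, $f(y) = e^{y}$ and $\log|f'(y)| = y$, so $\log p_Y(y) = \log p_X(e^{y}) + y$.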
The value of the attribute `parameter` is transferred to the transformed variable and set to `False` on the original variable. The attributes `observed` and `role` have their default values on the transformed variable and remain unchanged on the original variable.
Examples
```python
>>> import liesel.model as lsl
>>> import tensorflow_probability.substrates.jax.distributions as tfd
>>> import tensorflow_probability.substrates.jax.bijectors as tfb
```
Assume we have a variable `scale` that is constrained to be positive, and we want to include the log-transformation of this variable in the model. We first set up the parameter variable with its distribution:
```python
>>> prior = lsl.Dist(tfd.HalfCauchy, loc=0.0, scale=25.0)
>>> scale = lsl.Var.new_param(1.0, prior, name="scale")
```
Then we transform the variable to the log scale:
```python
>>> log_scale = scale.transform(tfb.Exp())
>>> log_scale
Var(name="scale_transformed")
```
Now `log_scale` has a log probability, and the `scale` variable does not:
```python
>>> log_scale.update().log_prob
Array(-3.6720574, dtype=float32)
>>> scale.update().log_prob
0.0
```
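As a sanity check, the log probability of `log_scale` at its initial value can be reproduced by hand from the change-of-variables formula. This is a minimal sketch in plain Python that calls neither Liesel nor TFP; it assumes the standard HalfCauchy density $p(x) = 2 / \bigl(\pi s\,(1 + ((x - \mu)/s)^2)\bigr)$ for $x \ge \mu$, with location $\mu$ and scale $s$, and the function names are hypothetical.

```python
import math


def half_cauchy_log_prob(x: float, loc: float = 0.0, scale: float = 25.0) -> float:
    """Log-density of HalfCauchy(loc, scale) at x >= loc."""
    z = (x - loc) / scale
    # p(x) = 2 / (pi * scale * (1 + z**2))
    return math.log(2.0 / (math.pi * scale)) - math.log1p(z * z)


def log_scale_log_prob(y: float, loc: float = 0.0, scale: float = 25.0) -> float:
    """Log-density of y = log(x) for x ~ HalfCauchy(loc, scale).

    The forward map is f(y) = exp(y), so the Jacobian term log|f'(y)| is y.
    """
    return half_cauchy_log_prob(math.exp(y), loc, scale) + y


# `scale` starts at 1.0, so `log_scale` starts at log(1.0) = 0.0.
print(log_scale_log_prob(math.log(1.0)))  # ≈ -3.672057
```

The result agrees with `Array(-3.6720574, dtype=float32)` up to float32 rounding, while `scale` itself contributes `0.0` because, after the transformation, it no longer carries a distribution.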