Var.transform()
- Var.transform(bijector=None, *bijector_args, inference=None, name=None, **bijector_kwargs)
Transforms the variable, making it a function of a new variable.
Creates a new variable on the unconstrained space R^n with the appropriate transformed distribution, turning the original variable into a weak variable without an associated distribution. The transformation is performed using TFP’s bijector classes; see the TFP bijectors documentation.
- Stored transformed variable: After the transformation, you can access the transformed variable via bijected_var.
- Parameters:
bijector (type[tfp.substrates.jax.bijectors.Bijector] | tfp.substrates.jax.bijectors.Bijector | None, default: None) – The bijector used to map the new transformed variable to this variable (forward transformation). If None, the experimental default event space bijector (see the TensorFlow Probability documentation) is used. If a bijector class is passed, it is instantiated with the arguments bijector_args and bijector_kwargs. If a bijector instance is passed, it is used directly.
bijector_args – The positional arguments passed on to the init function of the bijector.
inference (TypeAliasType | Literal['drop'], default: None) – Additional information that can be used to set up inference algorithms for the new, transformed variable. If "drop", the inference information is dropped from the original variable, and the new variable has no inference information. If None (default), the new variable likewise has no inference information, but an error is raised if the original variable carries inference information.
name (str | None, default: None) – Name for the new, transformed variable. If None (default), the new name is <old_name>_transformed, where <old_name> is a placeholder for the current variable’s name.
bijector_kwargs – The keyword arguments passed on to the init function of the bijector.
- Return type:
- Returns:
The new transformed variable which acts as an input to this variable.
- Raises:
RuntimeError – If the variable is weak or if the variable has no distribution.
ValueError – If the argument bijector is None, but the distribution does not have a default event space bijector. Also raised if transform() is called with inference=None while the variable’s inference attribute is not None.
See also
biject – Similar method, but with a slightly different API; it returns self instead of the transformed variable.
Notes
This is a simplified pseudo-code illustration of what this method does:
import liesel.model as lsl
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd


def transform(original_var: lsl.Var, bijector: tfb.Bijector):
    original_dist = original_var.dist_node.distribution
    dist_inputs = original_var.dist_node.inputs

    # transform the distribution
    new_dist = tfd.TransformedDistribution(original_dist, tfb.Invert(bijector))

    # transform initial value
    new_value = bijector.inverse(original_var.value)

    # initialise the new variable
    new_var = lsl.Var(
        new_value,
        lsl.Dist(new_dist, *dist_inputs),
        name=f"{original_var.name}_transformed",
    )
    new_var.parameter = original_var.parameter

    # define the original variable as a function of the new variable
    original_var.value_node = lsl.Calc(bijector.forward, new_var)
    original_var.parameter = False

    # return the new variable
    return new_var
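The distribution step in the pseudo-code can be read as the usual change-of-variables formula. The following is a sketch in notation introduced here for illustration (it is not taken from the library): f is the bijector's forward map, x = f(z) is the original variable, z is the new variable, p_X is the original density, and J_f is the Jacobian of f.

```latex
\log p_Z(z) = \log p_X\bigl(f(z)\bigr) + \log \bigl|\det J_f(z)\bigr|

% Special case f(z) = \exp(z), i.e. z = \log x for a positive scalar x:
\log p_Z(z) = \log p_X\bigl(e^{z}\bigr) + z
```

Wrapping the original distribution with the inverted bijector, as the pseudo-code does, yields a distribution for z = f^{-1}(x), which is why the forward map f appears inside p_X.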
The value of the attribute parameter is transferred to the transformed variable and set to False on the original variable. The attributes observed and role have the default values for the transformed variable and remain unchanged on the original variable.
Examples
>>> import liesel.model as lsl
>>> import tensorflow_probability.substrates.jax.distributions as tfd
>>> import tensorflow_probability.substrates.jax.bijectors as tfb
Assume we have a variable scale that is constrained to be positive, and we want to include the log-transformation of this variable in the model. We first set up the parameter variable with its distribution:
>>> prior = lsl.Dist(tfd.HalfCauchy, loc=0.0, scale=25.0)
>>> scale = lsl.Var.new_param(1.0, prior, name="scale")
Then we transform the variable to the log scale:
>>> log_scale = scale.transform(tfb.Exp())
>>> log_scale
Var(name="scale_transformed")
Now log_scale has a log probability, and the scale variable does not:
>>> log_scale.update().log_prob
Array(-3.6720574, dtype=float32)
>>> scale.update().log_prob
0.0
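The log probability reported for log_scale can be checked by hand. The following is a minimal sketch using only the Python standard library; it does not call Liesel or TFP. The helper names half_cauchy_log_prob and transformed_log_prob are hypothetical and introduced only for this illustration. It assumes the standard half-Cauchy density 2 / (pi * scale * (1 + ((x - loc) / scale)^2)) for x >= loc and the log-Jacobian of the exponential map, which is z itself.

```python
import math


def half_cauchy_log_prob(x, loc=0.0, scale=25.0):
    """Log density of a half-Cauchy distribution, defined for x >= loc."""
    if x < loc:
        return -math.inf
    z = (x - loc) / scale
    return math.log(2.0) - math.log(math.pi) - math.log(scale) - math.log1p(z * z)


def transformed_log_prob(log_scale_value):
    """Log density of log(scale) via the change-of-variables formula."""
    # forward transformation of the Exp bijector: scale = exp(log_scale)
    scale_value = math.exp(log_scale_value)
    # log|d exp(z) / dz| = z
    log_abs_det_jacobian = log_scale_value
    return half_cauchy_log_prob(scale_value) + log_abs_det_jacobian


# The initial value 1.0 of scale is mapped through the inverse transformation.
initial_log_scale = math.log(1.0)
result = transformed_log_prob(initial_log_scale)
print(result)
```

Because the computation here runs in double precision, the result agrees with the float32 value in the example above only up to rounding.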