
Var.new_calc()

classmethod Var.new_calc(function, *inputs, distribution=None, name='', _needs_seed=False, _update_on_init=True, **kwinputs)

Initializes a weak variable that is a function of other variables.

A calculating variable can wrap arbitrary calculations, as long as they are expressed as pure JAX functions.

Tip

The wrapped function must be jit-compilable by JAX. This mainly means that it must be a pure function: it must not have any side effects, and, given the same input, it must always return the same output. Loops and conditionals also need special care, because Python control flow that depends on input values cannot be traced; JAX offers alternatives such as jnp.where, jax.lax.cond, and jax.lax.fori_loop.

Please consult the JAX docs for details.
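To illustrate the point about conditionals, the following sketch uses plain JAX only (no liesel) and hypothetical function names. It contrasts a Python if statement, which fails under jax.jit because the input is an abstract tracer, with two jit-compatible alternatives. Either of the latter could be passed as the function argument of Var.new_calc.

```python
import jax
import jax.numpy as jnp


def relu_python_if(x):
    # A Python `if` needs a concrete bool, but under jax.jit `x` is an abstract tracer.
    if x > 0:
        return x
    return 0.0


def relu_jnp_where(x):
    # jnp.where selects elementwise without Python control flow, so it traces fine.
    return jnp.where(x > 0, x, 0.0)


def relu_lax_cond(x):
    # jax.lax.cond traces both branches and picks one at run time;
    # both branches must return the same shape and dtype.
    return jax.lax.cond(x > 0, lambda v: v, lambda v: jnp.zeros_like(v), x)


try:
    jax.jit(relu_python_if)(jnp.array(-1.5))
except jax.errors.ConcretizationTypeError as err:
    print("not jit-compilable:", type(err).__name__)

print(jax.jit(relu_jnp_where)(jnp.array(-1.5)))  # 0.0
print(jax.jit(relu_lax_cond)(jnp.array(2.0)))    # 2.0
```

jnp.where evaluates both branches and is the simplest choice for elementwise selection; jax.lax.cond is preferable when the branches are expensive or structurally different.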

Parameters:
  • function (Callable[..., Any]) – The function to be wrapped. Must be jit-compilable by JAX.

  • *inputs (Any) – Non-keyword inputs. Any inputs that are not already nodes or Var instances will be converted to Value nodes. The values of these inputs will be passed to the wrapped function in the same order they are entered here.

  • distribution (Dist | None) – The probability distribution of the variable. (default: None)

  • name (str) – The name of the node. If you do not specify a name, a unique name will be automatically generated upon initialization of a Model. (default: '')

  • _needs_seed (bool) – Whether the node needs a seed / PRNG key. (default: False)

  • _update_on_init (bool) – If True, the calculator will try to evaluate its function upon initialization. (default: True)

  • **kwinputs (Any) – Keyword inputs. Any inputs that are not already nodes or Var instances will be converted to Value nodes. The values of these inputs will be passed to the wrapped function as keyword arguments.

Return type:

Var

Notes

Internally, this constructor initializes and wraps a Calc node.

See also

Var.new_param

Initializes a strong variable that acts as a model parameter.

Var.new_obs

Initializes a strong variable that holds observed data.

Var.new_value

Initializes a strong variable without a distribution.

Calc

The calculator node class.

Examples

A simple calculator node, taking the exponential value of an input parameter.

>>> import jax.numpy as jnp
>>> import liesel.model as lsl
>>> log_scale = lsl.Var.new_param(0.0, name="log_scale")
>>> scale = lsl.Var.new_calc(jnp.exp, log_scale, name="scale")
>>> print(scale.value)
1.0

You can also use your own functions as long as they are jit-compilable by JAX.

>>> def compute_variance(x):
...     return jnp.exp(x)**2
>>> log_scale = lsl.Var.new_param(0.0, name="log_scale")
>>> variance = lsl.Var.new_calc(compute_variance, log_scale, name="variance")
>>> print(variance.value)
1.0

The value of a calculating variable is not recomputed automatically when its inputs change; it is only updated when update() is called.

>>> log_scale = lsl.Var.new_param(0.0, name="log_scale")
>>> scale = lsl.Var.new_calc(jnp.exp, log_scale, name="scale")
>>> print(scale.value)
1.0
>>> log_scale.value = 1.0
>>> print(scale.value)
1.0
>>> print(scale.update().value)
2.7182817