Var.new_calc()
- classmethod Var.new_calc(function, *inputs, distribution=None, name='', _needs_seed=False, _update_on_init=True, **kwinputs)
Initializes a weak variable that is a function of other variables.
A calculating variable can wrap arbitrary calculations in pure JAX functions.
Tip
The wrapped function must be jit-compilable by JAX. This mainly means that it must be a pure function, i.e. it must not have any side effects and, given the same input, it must always return the same output. Some special consideration is also required for loops and conditionals.
Please consult the JAX docs for details.
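As a rough illustration of the caveat about loops and conditionals, the sketch below contrasts a Python `if` on a traced value, which fails under `jax.jit`, with `jnp.where`, and uses `jax.lax.fori_loop` for a fixed-length loop. The function names are invented for this example and are not part of Liesel.

```python
import jax
import jax.numpy as jnp


def clip_negative_python_if(x):
    # A Python `if` needs a concrete boolean, which a traced value cannot provide.
    if x < 0:
        return jnp.zeros_like(x)
    return x


def clip_negative(x):
    # jnp.where selects element-wise and stays traceable.
    return jnp.where(x < 0, 0.0, x)


def repeated_square(x):
    # Square the input three times with a loop that JAX can trace.
    x = jnp.asarray(x, dtype=jnp.float32)
    return jax.lax.fori_loop(0, 3, lambda i, acc: acc * acc, x)


# The Python-if version works eagerly but raises when jit-compiled.
try:
    jax.jit(clip_negative_python_if)(jnp.array(-1.0))
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True

print(jit_failed)
print(jax.jit(clip_negative)(-1.0))
print(jax.jit(repeated_square)(2.0))
```

Functions written in the style of `clip_negative` and `repeated_square` are the kind that can be expected to work as the `function` argument.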
- Parameters:
  - function (Callable[..., Any]) – The function to be wrapped. Must be jit-compilable by JAX.
  - *inputs (Any) – Non-keyword inputs. Any inputs that are not already nodes or Var will be converted to Value nodes. The values of these inputs will be passed to the wrapped function in the same order they are entered here.
  - distribution (Dist | None) – The probability distribution of the variable. (default: None)
  - name (str) – The name of the node. If you do not specify a name, a unique name will be automatically generated upon initialization of a Model. (default: '')
  - _needs_seed (bool) – Whether the node needs a seed / PRNG key. (default: False)
  - _update_on_init (bool) – If True, the calculator will try to evaluate its function upon initialization. (default: True)
  - **kwinputs (Any) – Keyword inputs. Any inputs that are not already nodes or Var objects will be converted to Data nodes. The values of these inputs will be passed to the wrapped function as keyword arguments.
- Return type:
Notes
Internally, this constructor initializes and wraps a Calc node.
See also
Var.new_param
Initializes a strong variable that acts as a model parameter.
Var.new_obs
Initializes a strong variable that holds observed data.
Var.new_value
Initializes a strong variable without a distribution.
Calc
The calculator node class.
Examples
A simple calculator node, taking the exponential value of an input parameter.
>>> import jax.numpy as jnp
>>> import liesel.model as lsl
>>> log_scale = lsl.Var.new_param(0.0, name="log_scale")
>>> scale = lsl.Var.new_calc(jnp.exp, log_scale, name="scale")
>>> print(scale.value)
1.0
You can also use your own functions as long as they are jit-compilable by JAX.
>>> def compute_variance(x):
...     return jnp.exp(x)**2
>>> log_scale = lsl.Var.new_param(0.0, name="log_scale")
>>> variance = lsl.Var.new_calc(compute_variance, log_scale, name="variance")
>>> print(variance.value)
1.0
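The parameter descriptions state that non-node inputs are wrapped in nodes, that positional inputs reach the function in the order given, and that keyword inputs are passed as keyword arguments. The plain-Python sketch below mimics that calling convention. `ToyValue`, `as_node`, and `evaluate` are hypothetical stand-ins written for this illustration, not Liesel classes or functions.

```python
class ToyValue:
    """Hypothetical stand-in for a node that simply stores a value."""

    def __init__(self, value):
        self.value = value


def as_node(x):
    # Wrap raw inputs; leave existing nodes untouched.
    return x if isinstance(x, ToyValue) else ToyValue(x)


def evaluate(function, *inputs, **kwinputs):
    # Convert every input to a node first, then pass the node values on:
    # positional inputs keep their order, keyword inputs keep their names.
    nodes = [as_node(x) for x in inputs]
    kwnodes = {key: as_node(x) for key, x in kwinputs.items()}
    args = [node.value for node in nodes]
    kwargs = {key: node.value for key, node in kwnodes.items()}
    return function(*args, **kwargs)


def affine(x, slope, intercept=0.0):
    return slope * x + intercept


# One existing node, one raw positional input, one raw keyword input.
result = evaluate(affine, ToyValue(3.0), 2.0, intercept=1.0)
print(result)  # 7.0
```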
The value of the calculating variable is updated when update() is called.
>>> log_scale = lsl.Var.new_param(0.0, name="log_scale")
>>> scale = lsl.Var.new_calc(jnp.exp, log_scale, name="scale")
>>> print(scale.value)
1.0
>>> log_scale.value = 1.0
>>> print(scale.value)
1.0
>>> print(scale.update().value)
2.7182817
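The behavior in the last example, where the stored value stays stale until an explicit update, can be pictured with a small plain-Python toy. `ToyVar` and `ToyCalc` are hypothetical classes invented for this sketch; they only imitate the observable behavior described on this page and say nothing about how Liesel implements it.

```python
import math


class ToyVar:
    """Hypothetical variable that just holds a value."""

    def __init__(self, value):
        self.value = value


class ToyCalc:
    """Hypothetical calculator that caches the result of its function."""

    def __init__(self, function, *inputs, update_on_init=True):
        self.function = function
        self.inputs = inputs
        self.value = None
        if update_on_init:
            self.update()

    def update(self):
        # Recompute from the current input values and return self,
        # so that calls can be chained as in `calc.update().value`.
        self.value = self.function(*(inp.value for inp in self.inputs))
        return self


log_scale = ToyVar(0.0)
scale = ToyCalc(math.exp, log_scale)
initial = scale.value          # evaluated once on initialization

log_scale.value = 1.0
stale = scale.value            # still the cached result for the old input

fresh = scale.update().value   # recomputed from the new input
print(initial, stale, fresh)
```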