Var.new_calc()
- classmethod Var.new_calc(function, *inputs, dist=None, name='', _needs_seed=False, _update_on_init=True, convert_inputs='default', cache=True, distribution=None, **kwinputs)
Initializes a weak variable that is a function of other variables.
A calculating variable can wrap arbitrary calculations, as long as they are expressed as pure JAX functions.
Tip
The wrapped function must be jit-compilable by JAX. This mainly means that it must be a pure function, i.e. it must not have any side effects and, given the same input, it must always return the same output. Loops and conditionals also need special care, because Python control flow cannot depend on the values of traced arrays.
Please consult the JAX documentation for details.
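The remark on loops and conditionals can be made concrete with plain JAX, independently of liesel. The sketch below uses hypothetical function names; it shows that a Python if on a traced value fails under jax.jit, while jnp.where and jax.lax.fori_loop express the same kind of logic in a jit-compatible way.

```python
import jax
import jax.numpy as jnp


def shrink_python_if(x):
    # A Python `if` needs a concrete bool. Under jax.jit, `x` is an abstract
    # tracer, so tracing this function raises a concretization error.
    if x > 0.0:
        return x - 1.0
    return x + 1.0


def shrink_where(x):
    # jnp.where computes both branches and selects elementwise, so it works
    # with traced values and with arrays of any shape.
    return jnp.where(x > 0.0, x - 1.0, x + 1.0)


def sum_of_powers(x, n=5):
    # Computes x**0 + x**1 + ... + x**(n - 1). The loop runs inside the
    # compiled computation via lax.fori_loop; `n` stays a Python constant.
    def body(i, carry):
        total, term = carry
        return total + term, term * x

    init = (jnp.zeros_like(x), jnp.ones_like(x))
    total, _ = jax.lax.fori_loop(0, n, body, init)
    return total


# Hypothetical demonstration: the Python-if version cannot be jit-compiled.
try:
    jax.jit(shrink_python_if)(1.0)
    python_if_failed = False
except jax.errors.ConcretizationTypeError:
    python_if_failed = True

print(python_if_failed)  # True
print(jax.jit(shrink_where)(jnp.array([-2.0, 0.5, 3.0])))  # values -1.0, -0.5, 2.0
print(jax.jit(sum_of_powers)(jnp.float32(2.0)))  # 31.0
```

Functions written in the style of shrink_where or sum_of_powers are the kind that can be passed as function to a calculating variable.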
- Parameters:
  - function (Callable[..., Any]) – The function to be wrapped. Must be jit-compilable by JAX.
  - *inputs (Any) – Non-keyword inputs. Any inputs that are not already nodes or Var will be converted to Value nodes. The values of these inputs are passed to the wrapped function in the same order in which they are entered here.
  - dist (Dist | None, default: None) – The probability distribution of the variable.
  - name (str, default: '') – The name of the node. If you do not specify a name, a unique name is generated automatically upon initialization of a Model.
  - _needs_seed (bool, default: False) – Whether the node needs a seed / PRNG key.
  - _update_on_init (bool, default: True) – If True, the calculator tries to evaluate its function upon initialization.
  - convert_inputs (Callable[[Any], Any] | Literal['default'], default: 'default') – A function used to process the values of this variable's inputs. The default uses the function stored in convert_value, which is jax.numpy.asarray.
  - cache (bool, default: True) – If False, this variable does not store a cache of its value, so function is evaluated every time the value of this variable is requested. This can save memory if the computation is trivial (such as prepending an axis to an array), but it can greatly slow down computations otherwise (such as when the function performs a matrix inversion). Internally, if cache=True, this variable wraps a Calc, and if cache=False, it wraps a TransientCalc.
  - distribution (Dist | None, default: None) – Deprecated name of dist, kept for backwards compatibility. Please use dist instead.
  - **kwinputs (Any) – Keyword inputs. Any inputs that are not already nodes or Var will be converted to Value nodes. The values of these inputs are passed to the wrapped function as keyword arguments.
- Return type:
  Var
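The effect of cache, and of calling update(), can be sketched in plain Python. The classes Input, SketchCachedCalc and SketchTransientCalc below are hypothetical stand-ins written only for illustration; they are not liesel's Value, Calc or TransientCalc. They mimic the behavior described above: non-keyword input values are passed in order, keyword inputs as keyword arguments, a cached calculator evaluates its function once on initialization (as with _update_on_init=True) and afterwards only on update(), and a transient calculator evaluates its function on every access.

```python
import math
from typing import Any, Callable


class Input:
    """Hypothetical stand-in for an input node that just holds a value."""

    def __init__(self, value: Any) -> None:
        self.value = value


class SketchCachedCalc:
    """Hypothetical cached calculator: stores its value until update()."""

    def __init__(self, function: Callable[..., Any], *inputs: Input, **kwinputs: Input) -> None:
        self.function = function
        self.inputs = inputs
        self.kwinputs = kwinputs
        self.n_evaluations = 0
        self._value: Any = None
        self.update()  # mimics _update_on_init=True

    def _evaluate(self) -> Any:
        self.n_evaluations += 1
        args = [node.value for node in self.inputs]  # same order as entered
        kwargs = {key: node.value for key, node in self.kwinputs.items()}
        return self.function(*args, **kwargs)

    def update(self) -> "SketchCachedCalc":
        self._value = self._evaluate()
        return self

    @property
    def value(self) -> Any:
        return self._value  # cached: reading does not re-evaluate


class SketchTransientCalc(SketchCachedCalc):
    """Hypothetical transient calculator: stores nothing."""

    def update(self) -> "SketchTransientCalc":
        return self  # nothing to store

    @property
    def value(self) -> Any:
        return self._evaluate()  # re-evaluated on every access


x = Input(0.0)
cached = SketchCachedCalc(math.exp, x)
transient = SketchTransientCalc(math.exp, x)
for _ in range(3):
    cached.value
    transient.value
print(cached.n_evaluations, transient.n_evaluations)  # 1 3

x.value = 1.0
print(cached.value)  # 1.0 (stale until update() is called)
print(cached.update().value)  # 2.718281828459045
```

The transient variant saves storing a value at the cost of one evaluation per access, which is why cache=False mainly pays off for cheap functions.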
Notes
Internally, this constructor initializes and wraps a Calc node (or a TransientCalc node if cache=False).
See also
Var.new_param – Initializes a strong variable that acts as a model parameter.
Var.new_obs – Initializes a strong variable that holds observed data.
Var.new_value – Initializes a strong variable without a distribution.
Calc – The calculator node class.
Examples
A simple calculating variable that takes the exponential of an input parameter:
>>> import jax.numpy as jnp
>>> import liesel.model as lsl
>>> log_scale = lsl.Var.new_param(0.0, name="log_scale")
>>> scale = lsl.Var.new_calc(jnp.exp, log_scale, name="scale")
>>> print(scale.value)
1.0
You can also use your own functions as long as they are jit-compilable by JAX.
>>> def compute_variance(x):
...     return jnp.exp(x)**2
>>> log_scale = lsl.Var.new_param(0.0, name="log_scale")
>>> variance = lsl.Var.new_calc(compute_variance, log_scale, name="variance")
>>> print(variance.value)
1.0
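Before wrapping your own function, it can be worth checking that it compiles under jax.jit when called the way a calculating variable passes its inputs: non-keyword input values in order and keyword input values as keyword arguments. The function scaled_variance and its factor argument are hypothetical, and the check below uses only JAX, not liesel.

```python
import jax
import jax.numpy as jnp


def scaled_variance(log_scale, factor=1.0):
    # Hypothetical pure function: no side effects, and the output depends
    # only on the inputs.
    return factor * jnp.exp(log_scale) ** 2


# Compile and call with representative values: log_scale positionally,
# factor as a keyword argument. Python control flow that depends on the
# input values would typically fail at this point.
jitted = jax.jit(scaled_variance)
result = jitted(jnp.asarray(0.0), factor=jnp.asarray(3.0))
print(result)  # 3.0
```

If the check passes, such a function could be wrapped like the examples above, with factor supplied as a keyword input, for example lsl.Var.new_calc(scaled_variance, log_scale, factor=3.0, name="variance").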
Changing an input does not by itself change the stored value of a calculating variable; the value is updated when update() is called:
>>> log_scale = lsl.Var.new_param(0.0, name="log_scale")
>>> scale = lsl.Var.new_calc(jnp.exp, log_scale, name="scale")
>>> print(scale.value)
1.0
>>> log_scale.value = 1.0
>>> print(scale.value)
1.0
>>> print(scale.update().value)
2.7182817