Calc
- class liesel.model.Calc(function, *inputs, _name='', _needs_seed=False, _update_on_init=True, convert_inputs='default', **kwinputs)
Bases: Node

A Node subclass that calculates its value based on its input nodes. Calculator nodes are a central element of the Liesel graph-building toolkit. They wrap arbitrary calculations expressed as pure JAX functions.
By default, calculator nodes appear in the node graph created by viz.plot_nodes(), but not in the model graph created by viz.plot_vars(). Use Var.new_calc() instead if you want your calculation to be treated as a model variable and thus be shown in viz.plot_vars().
Tip
The wrapped function must be jit-compilable by JAX. This mainly means that it must be a pure function, i.e. it must not have any side effects and, given the same input, it must always return the same output. Some special consideration is also required for loops and conditionals.
Please consult the JAX docs for details.
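As a minimal sketch of what this means in practice, the two functions below (soft_clip and power_sum are hypothetical names, not part of Liesel) avoid Python control flow on traced values. A Python if on a traced array fails under jax.jit, so the conditional uses jax.numpy.where; a Python for loop over a range would be unrolled at trace time, so the loop uses jax.lax.fori_loop to keep it inside the compiled program. Either function could serve as the function argument of Calc.

```python
import jax
import jax.numpy as jnp
from jax import lax


def soft_clip(x, limit):
    # A Python "if x > limit:" would fail under jit because x is traced;
    # jnp.where selects element-wise instead.
    return jnp.where(x > limit, limit, x)


def power_sum(x, n):
    # Computes x**0 + x**1 + ... + x**(n - 1) with lax.fori_loop.
    # The carry holds the running sum and the current power of x.
    def body(i, carry):
        acc, term = carry
        return acc + term, term * x

    acc, _ = lax.fori_loop(0, n, body, (jnp.zeros_like(x), jnp.ones_like(x)))
    return acc


clipped = jax.jit(soft_clip)(jnp.array([0.5, 2.0]), 1.0)  # holds 0.5 and 1.0
total = jax.jit(power_sum)(jnp.asarray(2.0), 3)           # 1 + 2 + 4
print(clipped, total)
```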
- Parameters:
  - function (Callable[..., Any]) – The function to be wrapped. Must be jit-compilable by JAX.
  - *inputs (Any) – Non-keyword inputs. Any inputs that are not already nodes or Var objects will be converted to Value nodes. The values of these inputs will be passed to the wrapped function in the same order they are entered here.
  - _name (str, default: '') – The name of the node. If you do not specify a name, a unique name will be automatically generated upon initialization of a Model.
  - _needs_seed (bool, default: False) – Whether the node needs a seed / PRNG key.
  - _update_on_init (bool, default: True) – If True, the calculator will try to evaluate its function upon initialization.
  - convert_inputs (Callable[[Any], Any] | Literal['default'], default: 'default') – A function used to process the values of this node's inputs. The default uses the function stored in convert_value, which is jax.numpy.asarray.
  - **kwinputs (Any) – Keyword inputs. Any inputs that are not already nodes or Var objects will be converted to Value nodes. The values of these inputs will be passed to the wrapped function as keyword arguments.
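To make the interplay of these parameters concrete, here is a deliberately simplified, hypothetical model of the documented behavior. ToyValue, ToyCalc, and _as_node are invented for illustration and are not Liesel's implementation; they ignore _name, _needs_seed, and everything related to the model graph. The sketch only shows that non-node inputs get wrapped in a value node, that positional input values are passed in order and keyword input values by keyword, that convert_inputs is applied to every input value (jax.numpy.asarray by default), and that _update_on_init decides whether the function runs at construction.

```python
from typing import Any, Callable, Literal, Union

import jax.numpy as jnp


class ToyValue:
    """Hypothetical stand-in for a Value node: it only holds data."""

    def __init__(self, value: Any):
        self.value = value


class ToyCalc:
    """Hypothetical, simplified mimic of the documented Calc parameters.

    Not Liesel's implementation; naming, seeds, and graph handling are omitted.
    """

    def __init__(
        self,
        function: Callable[..., Any],
        *inputs: Any,
        _update_on_init: bool = True,
        convert_inputs: Union[Callable[[Any], Any], Literal["default"]] = "default",
        **kwinputs: Any,
    ):
        self.function = function
        # Inputs that are not already node-like are wrapped in ToyValue.
        self.inputs = [_as_node(x) for x in inputs]
        self.kwinputs = {key: _as_node(x) for key, x in kwinputs.items()}
        # "default" stands for jax.numpy.asarray, as documented above.
        self.convert = jnp.asarray if convert_inputs == "default" else convert_inputs
        self.value = None
        if _update_on_init:
            self.update()

    def update(self) -> "ToyCalc":
        # Read the current input values, convert them, and call the function:
        # positional inputs in order, keyword inputs as keyword arguments.
        args = [self.convert(node.value) for node in self.inputs]
        kwargs = {key: self.convert(node.value) for key, node in self.kwinputs.items()}
        self.value = self.function(*args, **kwargs)
        return self


def _as_node(x: Any) -> Any:
    # Hypothetical helper: leave node-like objects alone, wrap everything else.
    return x if isinstance(x, (ToyValue, ToyCalc)) else ToyValue(x)


log_scale = ToyValue(0.0)
scale = ToyCalc(jnp.exp, log_scale)  # evaluated on init: exp(0) = 1
shifted = ToyCalc(lambda x, *, shift: x + shift, scale, shift=2.0)  # keyword input

log_scale.value = 1.0        # changing an input recomputes nothing...
stale = float(scale.value)   # ...so this is still 1.0
scale.update()               # until update() is called
```

Keeping update() explicit in the sketch, rather than recomputing whenever an input changes, mirrors the statement in the examples that a calculator's value is refreshed when update() is called.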
See also
- Var.new_calc – Initializes a weak variable that is a function of other variables.
- Var – A variable in a statistical model, typically with a probability distribution.
- Var.new_param – Initializes a strong variable that acts as a model parameter.
- Var.new_obs – Initializes a strong variable that holds observed data.
- Var.new_value – Initializes a strong variable without a distribution.
- Value – A node representing some static data.
- Dist – A node representing a tensorflow_probability Distribution.
Examples
A simple calculator node that computes the exponential of an input parameter.
>>> import jax.numpy as jnp
>>> import liesel.model as lsl
>>> log_scale = lsl.Var.new_param(0.0, name="log_scale")
>>> scale = lsl.Calc(jnp.exp, log_scale)
>>> print(scale.value)
1.0
The value of the calculator node is recomputed from its inputs when Calc.update() is called. Since log_scale has not changed, the value is still 1.0.
>>> scale.update()
Calc(name="")
>>> print(scale.value)
1.0
You can also use your own functions, as long as they are jit-compilable by JAX.
>>> def compute_variance(x):
...     return jnp.exp(x)**2
>>> log_scale = lsl.Var.new_param(0.0, name="log_scale")
>>> variance = lsl.Calc(compute_variance, log_scale).update()
>>> print(variance.value)
1.0
Methods
update() – Updates the value of the node.
Attributes