Calc

class liesel.model.nodes.Calc(function, *inputs, _name='', _needs_seed=False, update_on_init=True, **kwinputs)

Bases: Node

A Node subclass that calculates its value based on its input nodes.

Calculator nodes are a central building block of the Liesel graph building toolkit. They wrap arbitrary computations expressed as pure JAX functions.

  • By default, calculator nodes will appear in the node graph created by viz.plot_nodes(), but they will not appear in the model graph created by viz.plot_vars().

  • You can wrap a calculator node in a Var to make it appear in the model graph.

Tip

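The wrapped function must be jit-compilable by JAX. This mainly means that it must be a pure function, i.e. it must not have any side effects and, given the same input, it must always return the same output. Loops and conditionals also need special care: Python control flow that depends on the values of the inputs generally fails under jit and has to be expressed with JAX primitives such as jnp.where, jax.lax.cond, or jax.lax.fori_loop.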

Please consult the JAX docs for details.
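To illustrate the Tip, the plain JAX sketch below defines two hypothetical functions (soft_threshold and geometric_sum, not part of Liesel) written to stay jit-compilable: a value-dependent branch uses jnp.where instead of a Python if, and a loop uses jax.lax.fori_loop. Either function could be passed to Calc as the wrapped function. Compiling a function yourself with jax.jit is a quick way to check that it is jit-compilable before wrapping it.

```python
import jax
import jax.numpy as jnp


def soft_threshold(x, threshold):
    # Hypothetical example function. A Python `if abs(x) > threshold:` fails
    # under jax.jit, because x is a tracer without a concrete value.
    # jnp.where evaluates both branches and selects elementwise instead.
    return jnp.where(jnp.abs(x) > threshold, x - jnp.sign(x) * threshold, 0.0)


def geometric_sum(ratio, n_terms):
    # Hypothetical example function: sum of ratio**i for i = 0, ..., n_terms - 1.
    # The loop is expressed with jax.lax.fori_loop; casting to float32 keeps
    # the dtype of the loop carry stable across iterations.
    ratio = jnp.asarray(ratio, dtype=jnp.float32)

    def body(i, acc):
        return acc + jnp.power(ratio, i)

    return jax.lax.fori_loop(0, n_terms, body, jnp.float32(0.0))


# Compiling explicitly checks that both functions are jit-compilable.
soft_threshold_jit = jax.jit(soft_threshold)
geometric_sum_jit = jax.jit(geometric_sum, static_argnames="n_terms")

print(soft_threshold_jit(jnp.array([-2.0, 0.5, 3.0]), 1.0))
print(geometric_sum_jit(0.5, n_terms=4))
```

Marking n_terms as static keeps the loop length a plain Python integer at trace time; changing it triggers recompilation.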

Parameters:
  • function (Callable[..., Any]) – The function to be wrapped. Must be jit-compilable by JAX.

  • *inputs (Any) – Non-keyword inputs. Any inputs that are not already nodes or Var objects will be converted to Data nodes. The values of these inputs will be passed to the wrapped function in the same order they are entered here.

  • _name (str) – The name of the node. If you do not specify a name, a unique name will be automatically generated upon initialization of a Model. (default: '')

  • _needs_seed (bool) – Whether the node needs a seed / PRNG key. (default: False)

  • update_on_init (bool) – If True, the calculator will try to evaluate its function upon initialization. (default: True)

  • **kwinputs (Any) – Keyword inputs. Any inputs that are not already nodes or Var objects will be converted to Data nodes. The values of these inputs will be passed to the wrapped function as keyword arguments.

See also

Data

A node representing some static data.

Dist

A node representing a tensorflow_probability Distribution.

Var

A variable in a statistical model, typically with a probability distribution.

param

A helper function to initialize a Var as a model parameter.

obs

A helper function to initialize a Var as an observed variable.

Notes

By default (update_on_init=True), a calculator node tries to evaluate its function upon initialization. With update_on_init=False, it computes its value only when Calc.update() is called. Commonly, the first such call then happens during the initialization of a Model, which can make it hard to spot errors in the wrapped computation. To compute the value immediately, call Calc.update() manually.

Examples

A simple calculator node, taking the exponential value of an input parameter.

>>> import jax.numpy as jnp
>>> import liesel.model as lsl
>>> log_scale = lsl.param(0.0, name="log_scale")
>>> scale = lsl.Calc(jnp.exp, log_scale)
>>> print(scale.value)
1.0

The value of the calculator node is updated when Calc.update() is called.

>>> scale.update()
Calc(name="")
>>> print(scale.value)
1.0

You can also call Calc.update() directly on the newly created node, because update() returns the node itself.

>>> log_scale = lsl.param(0.0, name="log_scale")
>>> scale = lsl.Calc(jnp.exp, log_scale).update()
>>> print(scale.value)
1.0

You can also use your own functions as long as they are jit-compilable by JAX.

>>> def compute_variance(x):
...     return jnp.exp(x)**2
>>> log_scale = lsl.param(0.0, name="log_scale")
>>> variance = lsl.Calc(compute_variance, log_scale).update()
>>> print(variance.value)
1.0

You can wrap a calculator node in a Var to declare its role as a statistical model variable and make it appear in the model graph.

>>> log_scale = lsl.param(0.0, name="log_scale")
>>> scale = lsl.Var(lsl.Calc(jnp.exp, log_scale).update())
>>> print(scale.value)
1.0

Methods

add_inputs(*inputs, **kwinputs)

Adds non-keyword and keyword input nodes to the existing ones.

all_input_nodes()

Returns all non-keyword and keyword input nodes as a unique tuple.

all_output_nodes()

Returns all output nodes as a unique tuple.

clear_state()

Clears the state of the node.

flag_outdated()

Flags the node and its recursive outputs as outdated.

set_inputs(*inputs, **kwinputs)

Sets the non-keyword and keyword input nodes.

update()

Updates the value of the node.

Attributes

function

The wrapped function.

groups

The groups that this node is a part of.

inputs

The non-keyword input nodes.

kwinputs

The keyword input nodes.

model

The model the node is part of.

name

The name of the node.

needs_seed

Whether the node needs a seed / PRNG key.

outdated

Whether the node is outdated.

outputs

The output nodes.

state

The state of the node.

value

The value of the node.

var

The variable the node is part of.

monitor

Whether the node should be monitored by an inference algorithm.