Node

class liesel.model.nodes.Node(*inputs, _name='', _needs_seed=False, **kwinputs)

Bases: ABC

A node of a computational graph that can cache its value.

Liesel represents statistical models as directed acyclic graphs (DAGs) of random variables (see Var) and computational nodes. The graph of random variables is built on top of the computational graph. The nodes of the computational graph typically express computations in JAX that return arrays or pytrees, but in general they can represent arbitrary Python operations.

Nodes can cache the result of the operations they represent, so a value is recomputed only when the node is updated rather than every time it is accessed. The cached values are part of the model state (see Model.state) and can be stored in a chain by Liesel’s MCMC engine Goose.
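The caching idea can be illustrated with a small, self-contained sketch. It does not use Liesel: `CachedNode`, its `compute` callable, and the `n_evaluations` counter are hypothetical simplifications that only show how a cached value avoids repeated work until the node is flagged as outdated.

```python
from typing import Any, Callable


class CachedNode:
    """Hypothetical stand-in for a caching node (not Liesel's implementation)."""

    def __init__(self, compute: Callable[[], Any]) -> None:
        self._compute = compute
        self._value: Any = None
        self.outdated = True  # nothing cached yet
        self.n_evaluations = 0  # how often compute() actually ran

    @property
    def value(self) -> Any:
        # Recompute only when the cached value is stale.
        if self.outdated:
            self._value = self._compute()
            self.n_evaluations += 1
            self.outdated = False
        return self._value

    def flag_outdated(self) -> None:
        # Mark the cache as stale; the next read recomputes.
        self.outdated = True


node = CachedNode(lambda: 2 + 3)
first, second = node.value, node.value  # second read is served from the cache
node.flag_outdated()
third = node.value  # recomputed after being flagged
print(first, second, third, node.n_evaluations)
```

After the two initial reads the counter is 1; flagging the node as outdated forces exactly one more evaluation.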

Note

This is an abstract class: it cannot be instantiated directly, only through a subclass that implements the update() method. See below for the most important concrete node classes.
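The abstract-class requirement follows the standard `abc` pattern. The sketch below is a hypothetical, stdlib-only illustration (`AbstractNode` and `SumNode` are invented names, and inputs are plain numbers rather than nodes): instantiating the base class fails until a subclass defines `update()`.

```python
from abc import ABC, abstractmethod
from typing import Any


class AbstractNode(ABC):
    """Hypothetical stand-in for an abstract node base class."""

    def __init__(self, *inputs: Any) -> None:
        self.inputs = inputs
        self.value: Any = None

    @abstractmethod
    def update(self) -> "AbstractNode":
        """Subclasses must define how the value is computed."""


class SumNode(AbstractNode):
    """Concrete subclass: its value is the sum of its (plain numeric) inputs."""

    def update(self) -> "SumNode":
        self.value = sum(self.inputs)
        return self


try:
    AbstractNode(1, 2)  # abstract method update() is not implemented
    instantiation_failed = False
except TypeError:
    instantiation_failed = True

node = SumNode(1, 2, 3).update()
print(instantiation_failed, node.value)
```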

Parameters:
  • inputs (Any) – Non-keyword inputs. Any inputs that are not already nodes or Var will be converted to Data nodes.

  • _name (str) – The name of the node. If you do not specify a name, a unique name will be automatically generated upon initialization of a Model. (default: '')

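  • _needs_seed (bool) – Whether the node needs a seed / PRNG key. (default: False)

  • kwinputs (Any) – Keyword inputs. Any inputs that are not already nodes or Var will be converted to Data nodes.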
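The conversion of plain inputs into data nodes can be sketched as follows. This is a hypothetical illustration, not Liesel's code: `ToyNode`, `ToyData`, and `as_node` are invented names, and the special handling of Var inputs is omitted.

```python
from typing import Any


class ToyNode:
    """Hypothetical minimal node type used only for this illustration."""

    def __init__(self, *inputs: Any, **kwinputs: Any) -> None:
        # Wrap every input that is not already a node, mirroring the
        # documented conversion of plain values to Data nodes.
        self.inputs = tuple(as_node(x) for x in inputs)
        self.kwinputs = {k: as_node(v) for k, v in kwinputs.items()}


class ToyData(ToyNode):
    """Hypothetical stand-in for a Data node holding a static value."""

    def __init__(self, value: Any) -> None:
        super().__init__()
        self.value = value


def as_node(x: Any) -> ToyNode:
    # Existing nodes pass through unchanged; anything else is wrapped.
    return x if isinstance(x, ToyNode) else ToyData(x)


existing = ToyData(1.0)
node = ToyNode(existing, 2.0, scale=3.0)
print([type(n).__name__ for n in node.inputs], node.kwinputs["scale"].value)
```

The existing node is reused as-is, while the raw values `2.0` and `3.0` are each wrapped in a new data node.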

See also

Calc

A node representing a general calculation/operation in JAX or Python.

Data

A node representing some static data.

Dist

A node representing a tensorflow_probability Distribution.

Var

A variable in a statistical model, typically with a probability distribution.

param

A helper function to initialize a Var as a model parameter.

obs

A helper function to initialize a Var as an observed variable.

Methods

add_inputs(*inputs, **kwinputs)

Adds non-keyword and keyword input nodes to the existing ones.

all_input_nodes()

Returns all non-keyword and keyword input nodes as a tuple without duplicates.

all_output_nodes()

Returns all output nodes as a tuple without duplicates.

clear_state()

Clears the state of the node.

flag_outdated()

Flags the node and all nodes downstream of it (its outputs, recursively) as outdated.

set_inputs(*inputs, **kwinputs)

Sets the non-keyword and keyword input nodes.

update()

Updates the value of the node.
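The bookkeeping behind all_input_nodes() and flag_outdated() can be sketched with a toy graph. `GraphNode` is a hypothetical class that borrows the documented method names; it is not Liesel's implementation. Deduplication keeps first-seen order, and flagging walks the outputs recursively, which terminates because the graph is acyclic.

```python
from __future__ import annotations


class GraphNode:
    """Hypothetical node used to illustrate input/output bookkeeping."""

    def __init__(self, name: str, *inputs: GraphNode, **kwinputs: GraphNode) -> None:
        self.name = name
        self.inputs = inputs
        self.kwinputs = kwinputs
        self.outputs: list[GraphNode] = []
        self.outdated = False
        # Register this node as an output of each distinct input.
        for node in self.all_input_nodes():
            node.outputs.append(self)

    def all_input_nodes(self) -> tuple[GraphNode, ...]:
        # dict.fromkeys drops duplicates while keeping first-seen order.
        return tuple(dict.fromkeys((*self.inputs, *self.kwinputs.values())))

    def flag_outdated(self) -> None:
        # Flag this node, then everything downstream of it.
        self.outdated = True
        for node in self.outputs:
            node.flag_outdated()


a = GraphNode("a")
b = GraphNode("b", a)
c = GraphNode("c", a, b, other=a)  # `a` is passed twice
d = GraphNode("d")  # not connected to `a`

print([n.name for n in c.all_input_nodes()])  # ['a', 'b']
a.flag_outdated()
print({n.name: n.outdated for n in (a, b, c, d)})
# {'a': True, 'b': True, 'c': True, 'd': False}
```

Because `c` receives `a` twice, deduplicating the inputs also keeps `a.outputs` free of repeated entries.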

Attributes

groups

The groups that this node is a part of.

inputs

The non-keyword input nodes.

kwinputs

The keyword input nodes.

model

The model the node is part of.

name

The name of the node.

needs_seed

Whether the node needs a seed / PRNG key.

outdated

Whether the node is outdated.

outputs

The output nodes.

state

The state of the node.

value

The value of the node.

var

The variable the node is part of.

monitor

Whether the node should be monitored by an inference algorithm.