NUTSKernel.tune()


NUTSKernel.tune(prng_key, kernel_state, model_state, epoch, history)

This method can automatically tune the kernel. It is called after each adaptation epoch.

To tune the kernel, the method can return an altered kernel state.

Must be jittable, i.e., traceable and compilable with jax.jit.
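A minimal sketch of the contract described above, assuming a hypothetical kernel state that holds only a step size and a hypothetical "shrink by 10%" rule. It is not the NUTSKernel tuning logic, and the real method wraps its result in a TuningOutcome rather than returning the bare state. The point is that a tune function which only applies JAX operations to array leaves can be passed through jax.jit.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class KernelState(NamedTuple):
    # Hypothetical kernel state: only a step size is tuned in this sketch.
    step_size: jnp.ndarray


def tune_step_size(prng_key, kernel_state, model_state, epoch, history):
    # Hypothetical tuning rule: shrink the step size by 10%.
    # A real tuner would inspect `history` or acceptance statistics instead.
    del prng_key, model_state, epoch, history  # unused in this sketch
    return kernel_state._replace(step_size=kernel_state.step_size * 0.9)


# Only JAX array operations on pytree leaves, so the function traces under jit.
tune_jit = jax.jit(tune_step_size)

# None and {} are valid (empty) pytrees, so they can stand in for unused arguments.
new_state = tune_jit(jax.random.PRNGKey(0), KernelState(jnp.asarray(1.0)), {}, None, None)
```

Returning a new state via `_replace` instead of mutating the input keeps the function pure, which is what jit tracing requires.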

Parameters:
  • prng_key (Any) – The key for JAX's pseudo-random number generator.

  • kernel_state (TypeVar(TKernelState, bound= Any)) – Current kernel state.

  • model_state (Any) – Current model state.

  • epoch (EpochState) – Current epoch state.

  • history (Optional[NewType(Position, dict[str, Any])]) – The history of the position during the current epoch: a pytree with the same structure as the position, where each leaf has an additional leading dimension (axis 0) that indexes time, i.e., the MCMC iteration. The value may be None if the class variable needs_history is False.
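To illustrate the layout of history, the sketch below stacks three hypothetical positions (the parameter names "beta" and "sigma" and all values are made up) leaf by leaf, so that every leaf gains a leading axis 0 indexing the iteration. It then reduces over that axis, for example to obtain per-parameter variances such as a tuner might use for a diagonal mass matrix estimate; whether and how NUTSKernel uses the history this way is not stated here.

```python
import jax
import jax.numpy as jnp

# Hypothetical positions from three MCMC iterations of one epoch.
positions = [
    {"beta": jnp.array([0.0, 1.0]), "sigma": jnp.array(1.0)},
    {"beta": jnp.array([0.5, 1.5]), "sigma": jnp.array(2.0)},
    {"beta": jnp.array([1.0, 2.0]), "sigma": jnp.array(3.0)},
]

# Stack leaf-wise: each leaf gains a leading axis 0 that indexes the iteration.
history = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *positions)

# Reduce over the time axis, e.g. per-parameter variances across iterations.
variances = jax.tree_util.tree_map(lambda leaf: jnp.var(leaf, axis=0), history)
```

Here history["beta"] has shape (3, 2) and history["sigma"] has shape (3,): the pytree structure of the position is preserved and only the leaves grow by one dimension.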

Return type:

TuningOutcome[TypeVar(TKernelState, bound= Any), TypeVar(TTuningInfo, bound= TuningInfo)]