KernelSequence.transition()

KernelSequence.transition(prng_key, kernel_states, model_state, epoch)

Performs one transition of the kernel sequence. Must be jittable, i.e. it must be traceable and compilable with `jax.jit`.

Return type:

KerSeqTransitionOutput
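A minimal sketch of what "jittable" means for a function with this signature. Everything in it is hypothetical: `ToyKernelState`, `ToyTransitionOutput`, `toy_transition` and `jitted_transition` are illustrative stand-ins, not the library's actual kernel classes, and `ToyTransitionOutput` only mimics the role of `KerSeqTransitionOutput`. It also assumes `epoch` is passed as an integer array. Two points carry over to a real implementation. First, the number of kernels is fixed when the function is traced, so a Python loop over them unrolls cleanly. Second, epoch-dependent behaviour is expressed with `jnp.where` rather than a Python `if`, because a Python `if` on a traced value fails under `jax.jit`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyKernelState(NamedTuple):
    """Hypothetical per-kernel state: a random-walk step size."""
    step_size: jnp.ndarray


class ToyTransitionOutput(NamedTuple):
    """Hypothetical stand-in for KerSeqTransitionOutput."""
    kernel_states: tuple
    model_state: jnp.ndarray


def toy_transition(prng_key, kernel_states, model_state, epoch):
    """Apply each (hypothetical) kernel once, in order. Jittable."""
    # One subkey per kernel. len(kernel_states) is a Python int known at
    # trace time, so this loop is unrolled during tracing.
    keys = jax.random.split(prng_key, len(kernel_states))
    new_states = []
    for key, state in zip(keys, kernel_states):
        # Random-walk move using the kernel's current step size.
        noise = jax.random.normal(key, model_state.shape)
        model_state = model_state + state.step_size * noise
        # Epoch-dependent logic uses array ops, not a Python `if`:
        # the step size grows by 10% only while epoch == 0.
        step_size = jnp.where(epoch == 0, state.step_size * 1.1, state.step_size)
        new_states.append(ToyKernelState(step_size))
    return ToyTransitionOutput(tuple(new_states), model_state)


jitted_transition = jax.jit(toy_transition)

# Usage: two toy kernels acting on a 3-dimensional model state.
key = jax.random.PRNGKey(0)
states = (ToyKernelState(jnp.array(0.5)), ToyKernelState(jnp.array(1.0)))
out = jitted_transition(key, states, jnp.zeros(3), jnp.array(0))
```

Because the sketch is a pure function of its inputs and routes all branching through array operations, the same key and inputs always give the same output, and `jax.jit` only recompiles when the shapes, dtypes or pytree structure of the arguments change.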