Kernel.start_epoch()

abstractmethod Kernel.start_epoch(prng_key, kernel_state, model_state, epoch)

Called at the beginning of each epoch. Must be jittable, that is, traceable by jax.jit. Returns the (possibly updated) kernel state.

Return type:

TypeVar(TKernelState, bound=Any)
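A minimal sketch of a concrete subclass, assuming the library is built on JAX. The Kernel base class, DecayState and DecayingKernel below are hypothetical stand-ins written for illustration; only the start_epoch signature, its TKernelState return type, and the jittability requirement come from this page. The sketch decays a step size by epoch and draws fresh per-epoch noise from prng_key, using only operations that trace under jax.jit (no Python branching on traced values). model_state is ignored here.

```python
from abc import ABC, abstractmethod
from typing import Any, Generic, NamedTuple, TypeVar

import jax
import jax.numpy as jnp

TKernelState = TypeVar("TKernelState", bound=Any)


class Kernel(ABC, Generic[TKernelState]):
    # Hypothetical stand-in for the library's base class; only the
    # start_epoch signature mirrors the documented API.
    @abstractmethod
    def start_epoch(self, prng_key, kernel_state: TKernelState,
                    model_state, epoch) -> TKernelState:
        """Called at the beginning of each epoch. Must be jittable."""


class DecayState(NamedTuple):
    # Hypothetical kernel state: a NamedTuple of arrays is a valid JAX pytree.
    step_size: jax.Array
    noise: jax.Array


class DecayingKernel(Kernel[DecayState]):
    # Hypothetical kernel whose step size halves (for decay=0.5) every epoch.
    def __init__(self, base_step_size: float, decay: float):
        self.base_step_size = base_step_size
        self.decay = decay

    def start_epoch(self, prng_key, kernel_state, model_state, epoch):
        # Pure array arithmetic on the (possibly traced) epoch number,
        # so this method traces cleanly under jax.jit.
        step_size = self.base_step_size * self.decay ** epoch
        # noise.shape is static at trace time, so it can size the draw.
        noise = jax.random.normal(prng_key, kernel_state.noise.shape)
        # Return a new state of the same pytree structure as the input.
        return kernel_state._replace(step_size=step_size, noise=noise)


if __name__ == "__main__":
    kernel = DecayingKernel(base_step_size=0.1, decay=0.5)
    state = DecayState(step_size=jnp.asarray(0.1), noise=jnp.zeros(3))
    start_epoch = jax.jit(kernel.start_epoch)
    master_key = jax.random.PRNGKey(0)
    for epoch in range(3):
        # Derive a distinct key for each epoch from one master key.
        key = jax.random.fold_in(master_key, epoch)
        state = start_epoch(key, state, None, epoch)
        print(epoch, float(state.step_size))
```

Because epoch is passed as an ordinary (traced) argument rather than through static_argnums, a change in the epoch number alone does not force jax.jit to compile a new version of start_epoch.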