Training Callbacks

Overview

A callback is a set of functions applied at given stages of the training procedure. Callbacks let you inspect the internal states and statistics of the model during training. Pass a list of callbacks through the callbacks argument of the fit() function; the relevant methods of each callback are then called at the corresponding stage of training.

For example:
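The sketch below assumes the keras R package with a working TensorFlow backend. The toy data, the two-layer model and the temporary log file are hypothetical, chosen only so that fit() has something to train. It passes two built-in callbacks, early stopping on the validation loss and a CSV logger, as a list through the callbacks argument.

```r
library(keras)

# Hypothetical toy regression data: 100 samples, 4 features
x_train <- matrix(runif(400), ncol = 4)
y_train <- rowSums(x_train)

# Small illustrative model
model <- keras_model_sequential() %>%
  layer_dense(units = 8, activation = "relu", input_shape = 4) %>%
  layer_dense(units = 1)
model %>% compile(loss = "mse", optimizer = "adam")

log_file <- tempfile(fileext = ".csv")

# Callbacks are supplied as a list via the `callbacks` argument of fit()
history <- model %>% fit(
  x_train, y_train,
  epochs = 10, batch_size = 16, verbose = 0,
  validation_split = 0.2,
  callbacks = list(
    callback_early_stopping(monitor = "val_loss", patience = 2),
    callback_csv_logger(log_file)
  )
)

# The CSV logger writes one row per completed epoch
read.csv(log_file)
```

Because early stopping may end training before the tenth epoch, the log can contain fewer than 10 rows.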

Built-in Callbacks

The following built-in callbacks are available as part of Keras:

callback_progbar_logger()

Callback that prints metrics to stdout.

callback_model_checkpoint()

Save the model after every epoch.

callback_early_stopping()

Stop training when a monitored quantity has stopped improving.

callback_remote_monitor()

Callback used to stream events to a server.

callback_learning_rate_scheduler()

Learning rate scheduler: sets the learning rate at the start of each epoch from a user-supplied schedule function.

callback_tensorboard()

Write logs for basic TensorBoard visualizations.

callback_reduce_lr_on_plateau()

Reduce learning rate when a metric has stopped improving.

callback_csv_logger()

Callback that streams epoch results to a CSV file.

callback_lambda()

Create a simple custom callback on the fly from anonymous functions.
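As a sketch of two of the built-ins above, the following assumes the keras R package with a TensorFlow 2 backend, where callback_learning_rate_scheduler() calls the schedule with the epoch index (from 0) and the current learning rate. The halving schedule, the toy data and the epoch_losses vector are hypothetical. callback_lambda() records the loss at the end of each epoch without defining a full class.

```r
library(keras)

# Hypothetical toy regression data: 100 samples, 4 features
x_train <- matrix(runif(400), ncol = 4)
y_train <- rowSums(x_train)

model <- keras_model_sequential() %>%
  layer_dense(units = 1, input_shape = 4)
model %>% compile(loss = "mse", optimizer = "sgd")

# Illustrative schedule: keep the initial rate for epoch 0, then halve it
# (assumes the two-argument form: epoch index, current learning rate)
halve_lr <- function(epoch, lr) {
  if (epoch == 0) lr else lr * 0.5
}

# Anonymous on_epoch_end function that appends each epoch's loss
epoch_losses <- numeric(0)
record_loss <- callback_lambda(
  on_epoch_end = function(epoch, logs) {
    epoch_losses <<- c(epoch_losses, logs$loss)
  }
)

model %>% fit(
  x_train, y_train,
  epochs = 3, batch_size = 20, verbose = 0,
  callbacks = list(
    callback_learning_rate_scheduler(halve_lr),
    record_loss
  )
)

epoch_losses
```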

Custom Callbacks

You can create a custom callback by creating a new R6 class that inherits from the KerasCallback class.

Here’s a simple example saving a list of losses over each batch during training:
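A minimal sketch of such a class, assuming the keras R package (version 2.x, which provides KerasCallback) and the R6 package. The class name LossHistory, its losses field, and the toy data and model are illustrative. The on_batch_end method appends the loss reported in logs after every batch.

```r
library(keras)

# Custom callback: an R6 class inheriting from KerasCallback
LossHistory <- R6::R6Class("LossHistory",
  inherit = KerasCallback,
  public = list(
    losses = NULL,
    # Append the loss reported at the end of each batch
    on_batch_end = function(batch, logs = list()) {
      self$losses <- c(self$losses, logs[["loss"]])
    }
  )
)

# Hypothetical toy data and model, for illustration only
x_train <- matrix(runif(400), ncol = 4)
y_train <- rowSums(x_train)
model <- keras_model_sequential() %>%
  layer_dense(units = 1, input_shape = 4)
model %>% compile(loss = "mse", optimizer = "sgd")

# Create the callback object and pass it to fit()
history <- LossHistory$new()
model %>% fit(
  x_train, y_train,
  batch_size = 20, epochs = 2, verbose = 0,
  callbacks = list(history)
)

# 100 samples / 20 per batch = 5 batches per epoch, so 10 values in total
history$losses
```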

Printing the accumulated losses after training yields one value per batch, for example:

[1] 0.6604760 0.3547246 0.2595316 0.2590170 ...

Fields

Custom callback objects have access to the model being trained and its training parameters via the following fields:

self$params

Named list of training parameters (e.g., verbosity, batch size, number of epochs).

self$model

Reference to the Keras model being trained.
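To illustrate both fields, here is a sketch of a hypothetical FieldsInspector callback, assuming the keras R package 2.x on a TensorFlow backend. It reads the epochs entry of self$params and calls the model's count_params() method through self$model when training begins. The single dense layer with 4 inputs has 4 weights plus 1 bias.

```r
library(keras)

# Hypothetical callback that reads its fields when training starts
FieldsInspector <- R6::R6Class("FieldsInspector",
  inherit = KerasCallback,
  public = list(
    epochs = NULL,
    n_params = NULL,
    on_train_begin = function(logs = list()) {
      # self$params: named list of training parameters
      self$epochs <- self$params$epochs
      # self$model: reference to the Keras model being trained
      self$n_params <- self$model$count_params()
    }
  )
)

# Toy data and model: one dense unit on 4 inputs (4 weights + 1 bias)
x_train <- matrix(runif(400), ncol = 4)
y_train <- rowSums(x_train)
model <- keras_model_sequential() %>%
  layer_dense(units = 1, input_shape = 4)
model %>% compile(loss = "mse", optimizer = "sgd")

inspector <- FieldsInspector$new()
model %>% fit(x_train, y_train, epochs = 3, batch_size = 20, verbose = 0,
              callbacks = list(inspector))

inspector$epochs    # 3
inspector$n_params  # 5
```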

Methods

Custom callback objects can implement one or more of the following methods:

on_epoch_begin(epoch, logs)

Called at the beginning of each epoch.

on_epoch_end(epoch, logs)

Called at the end of each epoch.

on_batch_begin(batch, logs)

Called at the beginning of each batch.

on_batch_end(batch, logs)

Called at the end of each batch.

on_train_begin(logs)

Called at the beginning of training.

on_train_end(logs)

Called at the end of training.
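Several of these methods can be combined in one class. The sketch below is a hypothetical EpochTimer, assuming the keras R package 2.x with a TensorFlow backend: on_epoch_begin stores a start time, on_epoch_end records the elapsed seconds, and on_train_end sets a flag once training finishes. The toy data and model are illustrative only.

```r
library(keras)

# Hypothetical callback timing each epoch
EpochTimer <- R6::R6Class("EpochTimer",
  inherit = KerasCallback,
  public = list(
    start = NULL,
    durations = NULL,
    train_ended = FALSE,
    on_epoch_begin = function(epoch, logs = list()) {
      self$start <- Sys.time()
    },
    on_epoch_end = function(epoch, logs = list()) {
      elapsed <- as.numeric(difftime(Sys.time(), self$start, units = "secs"))
      self$durations <- c(self$durations, elapsed)
    },
    on_train_end = function(logs = list()) {
      self$train_ended <- TRUE
    }
  )
)

# Toy data and model, for illustration only
x_train <- matrix(runif(400), ncol = 4)
y_train <- rowSums(x_train)
model <- keras_model_sequential() %>%
  layer_dense(units = 1, input_shape = 4)
model %>% compile(loss = "mse", optimizer = "sgd")

timer <- EpochTimer$new()
model %>% fit(x_train, y_train, epochs = 3, batch_size = 20, verbose = 0,
              callbacks = list(timer))

# Seconds spent in each of the 3 epochs
timer$durations
```

Methods a class does not define fall back to the empty defaults inherited from KerasCallback, so only the stages of interest need to be implemented.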