Training Callbacks

    Overview

    A callback is a set of functions that are applied at given stages of the training procedure. You can use callbacks to inspect the internal states and statistics of the model during training. Pass a list of callbacks to the fit() function through its callbacks argument; the relevant methods of each callback are then called at the corresponding stage of training.

    For example:

    library(keras)
    
    # generate dummy training data: 1000 samples with 784 features each
    data <- matrix(rexp(1000*784), nrow = 1000, ncol = 784)
    # one-hot encoded dummy labels for 10 classes
    labels <- to_categorical(sample(0:9, 1000, replace = TRUE), num_classes = 10)
    
    # create model
    model <- keras_model_sequential() 
    
    # add layers and compile
    model %>%
      layer_dense(units = 32, input_shape = c(784)) %>%
      layer_activation('relu') %>%
      layer_dense(units = 10) %>%
      layer_activation('softmax') %>% 
      compile(
        loss = 'categorical_crossentropy',
        optimizer = optimizer_sgd(),
        metrics = 'accuracy'
      )
      
    # fit with callbacks; validation_split holds out data so that the
    # val_loss monitored by callback_reduce_lr_on_plateau() is available
    model %>% fit(data, labels, validation_split = 0.2, callbacks = list(
      callback_model_checkpoint("checkpoints.h5"),
      callback_reduce_lr_on_plateau(monitor = "val_loss", factor = 0.1)
    ))

    Built-in Callbacks

    The following built-in callbacks are available as part of Keras:

    callback_progbar_logger()

    Callback that prints metrics to stdout.

    callback_model_checkpoint()

    Save the model after every epoch.

    callback_early_stopping()

    Stop training when a monitored quantity has stopped improving.

    callback_remote_monitor()

    Callback used to stream events to a server.

    callback_learning_rate_scheduler()

    Set the learning rate for each epoch using a user-supplied schedule function.

    callback_tensorboard()

    TensorBoard basic visualizations.

    callback_reduce_lr_on_plateau()

    Reduce learning rate when a metric has stopped improving.

    callback_csv_logger()

    Callback that streams epoch results to a CSV file.

    callback_lambda()

    Create a simple custom callback from anonymous functions.
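
    Several built-in callbacks can be combined in a single call to fit(). The following is a minimal sketch rather than a recommended configuration: the dummy data, the tiny model, the patience value, and the file name training_log.csv are all assumptions made for illustration.

```r
library(keras)

# dummy data (illustrative only): 1000 samples, 784 features, 10 classes
x <- matrix(runif(1000 * 784), nrow = 1000, ncol = 784)
y <- to_categorical(sample(0:9, 1000, replace = TRUE), num_classes = 10)

# small model, built the same way as in the overview example
model <- keras_model_sequential()
model %>%
  layer_dense(units = 10, activation = "softmax", input_shape = c(784)) %>%
  compile(loss = "categorical_crossentropy", optimizer = "rmsprop")

callbacks <- list(
  # stop once val_loss has failed to improve for 3 consecutive epochs
  callback_early_stopping(monitor = "val_loss", patience = 3),
  # write one row of metrics per epoch (hypothetical file name)
  callback_csv_logger("training_log.csv"),
  # print the training loss at the end of each epoch
  callback_lambda(
    on_epoch_end = function(epoch, logs) {
      cat("epoch", epoch, "loss:", logs[["loss"]], "\n")
    }
  )
)

model %>% fit(
  x, y,
  epochs = 20,
  validation_split = 0.2,  # hold out data so that val_loss is computed
  verbose = 0,
  callbacks = callbacks
)
```

    Because early stopping monitors val_loss here, the call to fit() needs validation data; validation_split is one way to provide it.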

    Custom Callbacks

    You can create a custom callback by creating a new R6 class that inherits from the KerasCallback class.

    Here’s a simple example that records the loss at the end of each batch during training:

    library(keras)
    
    # define custom callback class
    LossHistory <- R6::R6Class("LossHistory",
      inherit = KerasCallback,
      
      public = list(
        
        losses = NULL,
         
        on_batch_end = function(batch, logs = list()) {
          self$losses <- c(self$losses, logs[["loss"]])
        }
    ))
    
    # define model
    model <- keras_model_sequential() 
    
    # add layers and compile
    model %>% 
      layer_dense(units = 10, input_shape = c(784)) %>% 
      layer_activation(activation = 'softmax') %>% 
      compile(
        loss = 'categorical_crossentropy', 
        optimizer = 'rmsprop'
      )
    
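    # create history callback object and use it during training
    # (X_train and Y_train are assumed to already hold the training inputs,
    # with 784 columns, and the one-hot encoded labels, with 10 columns)
    history <- LossHistory$new()
    model %>% fit(
      X_train, Y_train,
      batch_size = 128, epochs = 20, verbose = 0,
      callbacks = list(history)
    )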
    
    # print the accumulated losses
    history$losses
    [1] 0.6604760 0.3547246 0.2595316 0.2590170 ...

    Fields

    Custom callback objects have access to the current model and its training parameters via the following fields:

    self$params

    Named list with training parameters (e.g., verbosity, batch size, number of epochs…).

    self$model

    Reference to the Keras model being trained.
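
    As a rough sketch of how these fields might be used, the callback below reads both of them when training starts. The class name ParamsReporter and its n_epochs and n_params fields are hypothetical, and the exact entries of self$params can vary between Keras versions; epochs is assumed to be present.

```r
library(keras)

# hypothetical callback that inspects self$params and self$model
ParamsReporter <- R6::R6Class("ParamsReporter",
  inherit = KerasCallback,

  public = list(

    n_epochs = NULL,
    n_params = NULL,

    on_train_begin = function(logs = list()) {
      # number of epochs requested in the call to fit()
      self$n_epochs <- self$params$epochs
      # total number of parameters of the model being trained
      self$n_params <- count_params(self$model)
      cat("training for", self$n_epochs, "epochs;",
          "model has", self$n_params, "parameters\n")
    }
  )
)

reporter <- ParamsReporter$new()
```

    Passing reporter in the callbacks list of fit() fills in both fields at the start of training, after which they can be read with reporter$n_epochs and reporter$n_params.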

    Methods

    Custom callback objects can implement one or more of the following methods:

    on_epoch_begin(epoch, logs)

    Called at the beginning of each epoch.

    on_epoch_end(epoch, logs)

    Called at the end of each epoch.

    on_batch_begin(batch, logs)

    Called at the beginning of each batch.

    on_batch_end(batch, logs)

    Called at the end of each batch.

    on_train_begin(logs)

    Called at the beginning of training.

    on_train_end(logs)

    Called at the end of training.

    on_train_batch_begin

    Called at the beginning of every training batch in the fit methods.

    on_train_batch_end

    Called at the end of every training batch in the fit methods.

    on_predict_batch_begin

    Called at the beginning of a batch in predict methods.

    on_predict_batch_end

    Called at the end of a batch in predict methods.

    on_predict_begin

    Called at the beginning of prediction.

    on_predict_end

    Called at the end of prediction.

    on_test_batch_begin

    Called at the beginning of a batch in evaluate methods. Also called at the beginning of a validation batch in the fit methods, if validation data is provided.

    on_test_batch_end

    Called at the end of a batch in evaluate methods. Also called at the end of a validation batch in the fit methods, if validation data is provided.

    on_test_begin

    Called at the beginning of evaluation or validation.

    on_test_end

    Called at the end of evaluation or validation.
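
    A single callback can implement several of these methods and share state between them through its own fields. The sketch below times each epoch; the class name EpochTimer and its start_time and durations fields are hypothetical names chosen for illustration.

```r
library(keras)

# hypothetical callback that measures how long each epoch takes
EpochTimer <- R6::R6Class("EpochTimer",
  inherit = KerasCallback,

  public = list(

    start_time = NULL,
    durations = NULL,

    on_epoch_begin = function(epoch, logs = list()) {
      # remember when the epoch started
      self$start_time <- Sys.time()
    },

    on_epoch_end = function(epoch, logs = list()) {
      # elapsed wall-clock time for this epoch, in seconds
      elapsed <- as.numeric(difftime(Sys.time(), self$start_time, units = "secs"))
      self$durations <- c(self$durations, elapsed)
    },

    on_train_end = function(logs = list()) {
      cat("mean epoch time (s):", mean(self$durations), "\n")
    }
  )
)

timer <- EpochTimer$new()
```

    Passing timer in the callbacks list of fit() leaves one entry per completed epoch in timer$durations once training finishes.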