Writing your own callbacks

Guide to writing Keras callbacks for customizing the behavior during model training, evaluation, or inference.

Introduction

A callback is a powerful tool to customize the behavior of a Keras model during training, evaluation, or inference. Examples include callback_tensorboard() to visualize training progress and results with TensorBoard, or callback_model_checkpoint() to periodically save your model during training.

In this guide, you will learn what a Keras callback is, what it can do, and how you can build your own. We provide a few demos of simple callback applications to get you started.

Setup

library(tensorflow)
library(keras)
envir::import_from(dplyr, last)

tf_version()
[1] '2.11'

Keras callbacks overview

All callbacks subclass the keras$callbacks$Callback class and override a set of methods that are called at various stages of training, testing, and predicting. Callbacks are useful for inspecting the internal states and statistics of the model during training.

You can pass a list of callbacks (as the named argument callbacks) to the following Keras model methods:

  • fit()
  • evaluate()
  • predict()

An overview of callback methods

Global methods

on_(train|test|predict)_begin(logs = NULL)

Called at the beginning of fit/evaluate/predict.

on_(train|test|predict)_end(logs = NULL)

Called at the end of fit/evaluate/predict.

Batch-level methods for training/testing/predicting

on_(train|test|predict)_batch_begin(batch, logs = NULL)

Called right before processing a batch during training/testing/predicting.

on_(train|test|predict)_batch_end(batch, logs = NULL)

Called at the end of training/testing/predicting a batch. Within this method, logs is a named list containing the metrics results.

Epoch-level methods (training only)

on_epoch_begin(epoch, logs = NULL)

Called at the beginning of an epoch during training.

on_epoch_end(epoch, logs = NULL)

Called at the end of an epoch during training.
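As a minimal sketch of how these hooks fit together, the callback below overrides only two of them and leaves the rest at their inherited no-op defaults. The class name EpochCounter, the n_epochs attribute, and the toy data are made up for illustration; the %py_class% syntax is the same one used later in this guide.

```r
library(keras)

# Hypothetical callback: counts how many epochs have finished.
EpochCounter(keras$callbacks$Callback) %py_class% {
  on_train_begin <- function(logs = NULL) {
    self$n_epochs <- 0L
  }

  on_epoch_end <- function(epoch, logs = NULL) {
    self$n_epochs <- self$n_epochs + 1L
  }
}

# Toy regression data, made up for this sketch.
x <- matrix(runif(64 * 4), ncol = 4)
y <- rowSums(x)

model <- keras_model_sequential() %>%
  layer_dense(1, input_shape = 4) %>%
  compile(optimizer = "sgd", loss = "mean_squared_error")

# Keep a reference to the callback so its state can be read after fit().
counter <- EpochCounter()
model %>% fit(x, y, epochs = 3, verbose = 0, callbacks = list(counter))

counter$n_epochs
```

Keeping the callback in a variable, rather than constructing it inline in the callbacks list, is what makes its attributes available for inspection once training is done.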

A basic example

Let’s take a look at a concrete example. To get started, let’s define a simple Sequential Keras model:

get_model <- function() {
  model <- keras_model_sequential() %>%
    layer_dense(1, input_shape = 784) %>%
    compile(
      optimizer = optimizer_rmsprop(learning_rate=0.1),
      loss = "mean_squared_error",
      metrics = "mean_absolute_error"
    )
  model
}

Then, load the MNIST data for training and testing from the Keras datasets API:

mnist <- dataset_mnist()

flatten_and_rescale <- function(x) {
  x <- array_reshape(x, c(-1, 784))
  x <- x / 255
  x
}

mnist$train$x <- flatten_and_rescale(mnist$train$x)
mnist$test$x  <- flatten_and_rescale(mnist$test$x)

# limit to 500 samples
mnist$train$x <- mnist$train$x[1:500, ]
mnist$train$y <- mnist$train$y[1:500]
mnist$test$x  <- mnist$test$x[1:500, ]
mnist$test$y  <- mnist$test$y[1:500]

Now, define a simple custom callback that logs:

  • When fit/evaluate/predict starts & ends
  • When each epoch starts & ends
  • When each training batch starts & ends
  • When each evaluation (test) batch starts & ends
  • When each inference (prediction) batch starts & ends

show <- function(msg, logs) {
  cat(glue::glue(msg, .envir = parent.frame()),
      "got logs: ", sep = "; ")
  logs %>% 
    lapply(signif, digits = 3) %>% 
    dput(control = "niceNames")
}

CustomCallback(keras$callbacks$Callback) %py_class% {
  on_train_begin <- function(logs = NULL)
    show("Starting training", logs)

  on_train_end <- function(logs = NULL)
    show("Stop training", logs)

  on_epoch_begin <- function(epoch, logs = NULL)
    show("Start epoch {epoch} of training", logs)

  on_epoch_end <- function(epoch, logs = NULL)
    show("End epoch {epoch} of training", logs)

  on_test_begin <- function(logs = NULL)
    show("Start testing", logs)

  on_test_end <- function(logs = NULL)
    show("Stop testing", logs)

  on_predict_begin <- function(logs = NULL)
    show("Start predicting", logs)

  on_predict_end <- function(logs = NULL)
    show("Stop predicting", logs)

  on_train_batch_begin <- function(batch, logs = NULL)
    show("...Training: start of batch {batch}", logs)

  on_train_batch_end <- function(batch, logs = NULL)
    show("...Training: end of batch {batch}",  logs)

  on_test_batch_begin <- function(batch, logs = NULL)
    show("...Evaluating: start of batch {batch}", logs)

  on_test_batch_end <- function(batch, logs = NULL)
    show("...Evaluating: end of batch {batch}", logs)

  on_predict_batch_begin <- function(batch, logs = NULL)
    show("...Predicting: start of batch {batch}", logs)

  on_predict_batch_end <- function(batch, logs = NULL)
    show("...Predicting: end of batch {batch}", logs)
}

Let’s try it out:

model <- get_model()
model %>% fit(
  mnist$train$x,
  mnist$train$y,
  batch_size = 128,
  epochs = 2,
  verbose = 0,
  validation_split = 0.5,
  callbacks = list(CustomCallback())
)
Starting training; got logs: list()
Start epoch 0 of training; got logs: list()
...Training: start of batch 0; got logs: list()
...Training: end of batch 0; got logs: list(loss = 29.5, mean_absolute_error = 4.6)
...Training: start of batch 1; got logs: list()
...Training: end of batch 1; got logs: list(loss = 454, mean_absolute_error = 16)
Start testing; got logs: list()
...Evaluating: start of batch 0; got logs: list()
...Evaluating: end of batch 0; got logs: list(loss = 29, mean_absolute_error = 4.64)
...Evaluating: start of batch 1; got logs: list()
...Evaluating: end of batch 1; got logs: list(loss = 27.8, mean_absolute_error = 4.48)
Stop testing; got logs: list(loss = 27.8, mean_absolute_error = 4.48)
End epoch 0 of training; got logs: list(loss = 454, mean_absolute_error = 16, val_loss = 27.8, val_mean_absolute_error = 4.48)
Start epoch 1 of training; got logs: list()
...Training: start of batch 0; got logs: list()
...Training: end of batch 0; got logs: list(loss = 28.1, mean_absolute_error = 4.45)
...Training: start of batch 1; got logs: list()
...Training: end of batch 1; got logs: list(loss = 18.6, mean_absolute_error = 3.5)
Start testing; got logs: list()
...Evaluating: start of batch 0; got logs: list()
...Evaluating: end of batch 0; got logs: list(loss = 7.18, mean_absolute_error = 2.18)
...Evaluating: start of batch 1; got logs: list()
...Evaluating: end of batch 1; got logs: list(loss = 7.26, mean_absolute_error = 2.21)
Stop testing; got logs: list(loss = 7.26, mean_absolute_error = 2.21)
End epoch 1 of training; got logs: list(loss = 18.6, mean_absolute_error = 3.5, val_loss = 7.26, 
    val_mean_absolute_error = 2.21)
Stop training; got logs: list(loss = 18.6, mean_absolute_error = 3.5, val_loss = 7.26, 
    val_mean_absolute_error = 2.21)

res <- model %>%
  evaluate(
    mnist$test$x,
    mnist$test$y,
    batch_size = 128,
    verbose = 0,
    callbacks = list(CustomCallback())
  )
Start testing; got logs: list()
...Evaluating: start of batch 0; got logs: list()
...Evaluating: end of batch 0; got logs: list(loss = 7.24, mean_absolute_error = 2.19)
...Evaluating: start of batch 1; got logs: list()
...Evaluating: end of batch 1; got logs: list(loss = 6.74, mean_absolute_error = 2.13)
...Evaluating: start of batch 2; got logs: list()
...Evaluating: end of batch 2; got logs: list(loss = 6.65, mean_absolute_error = 2.13)
...Evaluating: start of batch 3; got logs: list()
...Evaluating: end of batch 3; got logs: list(loss = 6.96, mean_absolute_error = 2.18)
Stop testing; got logs: list(loss = 6.96, mean_absolute_error = 2.18)

res <- model %>%
  predict(mnist$test$x,
          batch_size = 128,
          callbacks = list(CustomCallback()))
Start predicting; got logs: list()
...Predicting: start of batch 0; got logs: list()
...Predicting: end of batch 0; got logs: list(outputs = c(3.61, 2.54, 1.69, 4.21, 2.99, 2.35, 3.61, 4.11, 
4.17, 6.08, 1.26, 2.57, 4.64, 2.6, 2.67, 2.44, 4.03, 3.74, 4.48, 
2.89, 5.22, 3.67, 2.66, 2.68, 2.33, 4.17, 2.8, 3.44, 3.28, 1.81, 
3.31, 1.85, 3.32, 2.57, 4.62, 2.51, 4.45, 2.2, 2.08, 3.13, 1.11, 
3.36, 4.35, 1.69, 2.34, 2.21, 2.44, 2.83, 6.24, 3.31, 2.98, 3.97, 
3.21, 1.26, 2.02, 3.27, 3.87, 1.83, 4.81, 2.06, 3.95, 4.19, 2.91, 
3.06, 5.01, 2.51, 2.88, 3.59, 4.77, 0.765, 3.01, 2.34, 2.28, 
5.03, 2.46, 4.43, 2.17, 2.68, 4.27, 7.27, 4.66, 4.27, 4.09, 3.63, 
4.37, 5.24, 5.34, 3.15, 4.09, 2.76, 2.21, 4.12, 2.74, 4.87, 3.69, 
6.26, 1.83, 4.68, 2.96, 5.42, 2.68, 1.9, 4.3, 3.65, 2.72, 3.88, 
2.42, 1.92, 4.22, 3.22, 3.22, 3.35, 2.48, 5.08, 3.65, 3.46, 3.16, 
3.9, 3.5, 4.35, 3.88, 3.29, 2.64, 4.32, 4.87, 3.53, 2.35, 4.2
))
...Predicting: start of batch 1; got logs: list()
...Predicting: end of batch 1; got logs: list(outputs = c(6.39, 3.42, 5.3, 5.04, 3.28, 5.02, 3.65, 3.08, 
2.49, 3.47, 3.69, 4.89, 2.66, 2.58, 2.98, 2.02, 4.87, 1.54, 5.03, 
4.42, 2.21, 2.55, 3.8, 4.52, 3.12, 2.29, 1.79, 4.28, 2.33, 2.89, 
2.79, 2.77, 4.67, 2.95, 4.2, 5.05, 3.12, 2.94, 3.56, 2.36, 2.14, 
3.6, 3.1, 5.51, 2.12, 1.79, 3.52, 1.73, 1.38, 3.98, 2.27, 3.98, 
2.13, 4.19, 2.75, 1.24, 3.31, 4.04, 4.4, 3.09, 1.94, 1.28, 3.66, 
1.65, 2.55, 3.62, 2.72, 2.06, 3.25, 4.92, 3.75, 4.94, 5.85, 3.06, 
1.72, 1.59, 1.71, 2.03, 5.81, 3.16, 3.14, 5.89, 2.77, 2.99, 5.42, 
2.41, 6.04, 3.92, 3.1, 3.89, 3.32, 2.73, 4.58, 3.65, 6.16, 5.01, 
0.931, 2.77, 4.88, 2.07, 1.93, 2.73, 2.87, 1.67, 4.2, 5.95, 2.61, 
3.27, 4, 3.67, 4.04, 2.12, 2.55, 3.84, 5.67, 3.32, 2.43, 1.53, 
2.04, 4.01, 5.12, 2.7, 5.95, 1.57, 5.86, 2.07, 3.13, 3.23))
...Predicting: start of batch 2; got logs: list()
...Predicting: end of batch 2; got logs: list(outputs = c(3.1, 3.72, 4.07, 2.88, 4.5, 2.39, 2.34, 4.2, 
3.33, 2.9, 4.67, 3.1, 2.41, 1.07, 3.17, 3.26, 2.09, 4.76, 5.23, 
4.59, 2.75, 5.72, 3.93, 1.73, 3.2, 4.68, 3.54, 2.77, 3.82, 3.5, 
3.53, 5.93, 2.72, 2.73, 4.49, 2.43, 5.38, 4.89, 2.42, 3.86, 2.21, 
0.52, 2.61, 3.11, 2.6, 5.24, 2.07, 2.12, 4.55, 0.958, 2.53, 3.82, 
5.34, 2.02, 2.01, 4.9, 3.56, 2.43, 2.3, 5.18, 3.5, 4.33, 2.63, 
2.71, 3.23, 3.53, 5.23, 2.02, 2.63, 3.35, 2.88, 3.34, 4.61, 2.9, 
2.3, 2.91, 2.42, 2.97, 1.79, 2.07, 4.76, 4.03, 4.27, 3.06, 3.14, 
2.28, 2.33, 2.38, 3.02, 3.17, 3.59, 5.76, 2.48, 3.7, 2.56, 3.74, 
2.6, 6.39, 1.96, 3.9, 4.02, 1.28, 3, 4.81, 3.21, 4.47, 3.88, 
1.73, 3.18, 2.07, 3.66, 2.66, 2.96, 1.87, 4.11, 4.09, 1.4, 4.65, 
4.24, 4.05, 3.9, 2, 1.9, 3.46, 0.831, 1.89, 3.66, 2.58))
...Predicting: start of batch 3; got logs: list()
...Predicting: end of batch 3; got logs: list(outputs = c(5.03, 1.49, 2.96, 2.33, 1.71, 5.39, 4.6, 4.23, 
3.8, 2.3, 2.95, 2.96, 4.26, 1.66, 3.16, 3.24, 3.01, 2.74, 2.6, 
3.49, 2.91, 5.34, 4.7, 3.11, 2.95, 2.72, 3.74, 3.39, 6.3, 5.65, 
3.43, 4.18, 2.73, 4.71, 3.26, 1.73, 2.44, 1.8, 4.36, 4.02, 1.03, 
4.07, 3.84, 1.89, 3.88, 4.1, 2.11, 3.48, 2.81, 3.84, 3.85, 4.28, 
3.83, 6, 3.82, 3, 4.54, 2.91, 3.56, 1.66, 3.45, 1.64, 4.28, 5.8, 
4.26, 3.32, 4.95, 5.13, 2.77, 1.45, 4.56, 2.99, 2.04, 2.75, 4.8, 
2.31, 4.18, 4.14, 5.25, 4.42, 4.5, 5.71, 3.7, 3.68, 5.03, 3.73, 
3.2, 3.22, 3.83, 1.2, 5.63, 4.94, 2.59, 2.89, 2.4, 3.15, 1.28, 
4.33, 2.15, 2.08, 3.52, 1.95, 4.88, 4.49, 4.5, 1.72, 1.58, 2.56, 
2.8, 2.91, 4.71, 5.91, 3.7, 3.56, 3.32, 2.14))
Stop predicting; got logs: list()

Usage of logs

The logs named list contains the loss value and all the metrics at the end of a batch or epoch. The example below prints the loss and the mean absolute error.

LossAndErrorPrintingCallback(keras$callbacks$Callback) %py_class% {
  on_train_batch_end <- function(batch, logs = NULL)
    cat(sprintf("Up to batch %i, the average loss is %7.2f.\n",
                batch,  logs$loss))

  on_test_batch_end <- function(batch, logs = NULL)
    cat(sprintf("Up to batch %i, the average loss is %7.2f.\n",
                batch, logs$loss))

  on_epoch_end <- function(epoch, logs = NULL)
    cat(sprintf(
      "The average loss for epoch %2i is %9.2f and mean absolute error is %7.2f.\n",
      epoch, logs$loss, logs$mean_absolute_error
    ))
}

model <- get_model()
model %>% fit(
  mnist$train$x,
  mnist$train$y,
  batch_size = 128,
  epochs = 2,
  verbose = 0,
  callbacks = list(LossAndErrorPrintingCallback())
)
Up to batch 0, the average loss is   27.64.
Up to batch 1, the average loss is  425.30.
Up to batch 2, the average loss is  291.55.
Up to batch 3, the average loss is  226.01.
The average loss for epoch  0 is    226.01 and mean absolute error is    9.61.
Up to batch 0, the average loss is    6.48.
Up to batch 1, the average loss is    6.42.
Up to batch 2, the average loss is    5.89.
Up to batch 3, the average loss is    5.88.
The average loss for epoch  1 is      5.88 and mean absolute error is    2.00.

res <- model %>% evaluate(
  mnist$test$x,
  mnist$test$y,
  batch_size = 128,
  verbose = 0,
  callbacks = list(LossAndErrorPrintingCallback())
)
Up to batch 0, the average loss is    5.37.
Up to batch 1, the average loss is    4.78.
Up to batch 2, the average loss is    4.71.
Up to batch 3, the average loss is    4.80.
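The messages above say "average loss" because, in the Keras versions we are aware of, the loss found in logs at the end of a training batch is a running average over the batches seen so far in the current epoch, not the loss of that single batch. The sketch below shows the relationship in plain R. The per-batch values are made up (they are not taken from the run above), and the calculation assumes that every batch has the same number of samples.

```r
# Hypothetical per-batch losses for one epoch (illustrative values only).
batch_loss <- c(4, 2, 3, 1)

# Running average after each batch, as a callback would see it in logs$loss
# under the equal-batch-size assumption.
running_avg <- cumsum(batch_loss) / seq_along(batch_loss)
running_avg
# 4.0 3.0 3.0 2.5

# Going the other way: recover the per-batch losses from the running averages.
k <- seq_along(running_avg)
previous_avg <- c(0, head(running_avg, -1))
recovered <- k * running_avg - (k - 1) * previous_avg
recovered
# 4 2 3 1
```

This is also why the value reported at the end of an epoch equals the value reported after its last batch.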

Usage of self$model attribute

In addition to receiving log information when one of their methods is called, callbacks have access to the model associated with the current round of training/evaluation/inference: self$model.

Here are a few of the things you can do with self$model in a callback:

  • Set self$model$stop_training <- TRUE to immediately interrupt training.
  • Mutate hyperparameters of the optimizer (available as self$model$optimizer), such as self$model$optimizer$learning_rate.
  • Save the model at periodic intervals.
  • Record the output of predict(model) on a few test samples at the end of each epoch, to use as a sanity check during training.
  • Extract visualizations of intermediate features at the end of each epoch, to monitor what the model is learning over time.
  • etc.
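
As a minimal sketch of the first item, the callback below sets self$model$stop_training once the epoch loss drops below a threshold. The class name StopAtLoss, its threshold argument, and the toy data are hypothetical and exist only for this illustration.

```r
library(keras)

# Hypothetical callback: stop once the epoch loss falls below `threshold`.
StopAtLoss(keras$callbacks$Callback) %py_class% {
  initialize <- function(threshold) {
    super$initialize()
    self$threshold <- threshold
  }

  on_epoch_end <- function(epoch, logs = NULL) {
    if (logs$loss < self$threshold) {
      cat(sprintf("Epoch %d: loss %.4f is below %.4f, stopping.\n",
                  epoch, logs$loss, self$threshold))
      self$model$stop_training <- TRUE
    }
  }
}

# Toy regression data, made up for this sketch.
x <- matrix(runif(256 * 4), ncol = 4)
y <- rowSums(x)

model <- keras_model_sequential() %>%
  layer_dense(1, input_shape = 4) %>%
  compile(optimizer = "sgd", loss = "mean_squared_error")

# With threshold = Inf the condition holds after the first epoch,
# so training should end well before the requested 10 epochs.
history <- model %>% fit(
  x, y, epochs = 10, verbose = 0,
  callbacks = list(StopAtLoss(threshold = Inf))
)

# Number of epochs that actually ran (entries for epochs that never ran,
# if present, are NA and are not counted).
sum(!is.na(history$metrics$loss))
```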

Let’s see this in action in a couple of examples.

Examples of Keras callback applications

Early stopping at minimum loss

This first example shows the creation of a Callback that stops training when the minimum of loss has been reached, by setting the attribute self$model$stop_training (boolean). Optionally, you can provide an argument patience to specify how many epochs we should wait before stopping after having reached a local minimum.

keras$callbacks$EarlyStopping provides a more complete and general implementation.

EarlyStoppingAtMinLoss(keras$callbacks$Callback) %py_class% {
  "Stop training when the loss is at its min, i.e. the loss stops decreasing.

  Arguments:
      patience: Number of epochs to wait after the minimum has been hit. After
        this many epochs without improvement, training stops.
  "

  initialize <- function(patience = 0) {
    # call keras$callbacks$Callback$__init__(), so it can setup `self`
    super$initialize()
    self$patience <- patience
    # best_weights to store the weights at which the minimum loss occurs.
    self$best_weights <- NULL
  }

  on_train_begin <- function(logs = NULL) {
    # The number of epochs waited since the loss last improved.
    self$wait <- 0
    # The epoch the training stops at.
    self$stopped_epoch <- 0
    # Initialize the best as infinity.
    self$best <- Inf
  }

  on_epoch_end <- function(epoch, logs = NULL) {
    current <- logs$loss
    if (current < self$best) {
      self$best <- current
      self$wait <- 0
      # Record the best weights if the current result is better (lower).
      self$best_weights <- self$model$get_weights()
    } else {
      self$wait %<>% `+`(1)
      if (self$wait >= self$patience) {
        self$stopped_epoch <- epoch
        self$model$stop_training <- TRUE
        cat("Restoring model weights from the end of the best epoch.\n")
        self$model$set_weights(self$best_weights)
      }
    }
  }

  on_train_end <- function(logs = NULL)
    if (self$stopped_epoch > 0)
      cat(sprintf("Epoch %05d: early stopping\n", self$stopped_epoch + 1))

}


model <- get_model()
model %>% fit(
  mnist$train$x,
  mnist$train$y,
  batch_size = 64,
  steps_per_epoch = 5,
  epochs = 30,
  verbose = 0,
  callbacks = list(LossAndErrorPrintingCallback(),
                   EarlyStoppingAtMinLoss())
)
Up to batch 0, the average loss is   27.95.
Up to batch 1, the average loss is  516.37.
Up to batch 2, the average loss is  352.82.
Up to batch 3, the average loss is  266.31.
Up to batch 4, the average loss is  214.07.
The average loss for epoch  0 is    214.07 and mean absolute error is    8.43.
Up to batch 0, the average loss is    7.93.
Up to batch 1, the average loss is    7.23.
Up to batch 2, the average loss is    7.00.
Up to batch 3, the average loss is    6.57.
Up to batch 4, the average loss is    6.06.
The average loss for epoch  1 is      6.06 and mean absolute error is    2.05.
Up to batch 0, the average loss is    5.44.
Up to batch 1, the average loss is    4.97.
Up to batch 2, the average loss is    4.84.
Up to batch 3, the average loss is    4.47.
Up to batch 4, the average loss is    4.41.
The average loss for epoch  2 is      4.41 and mean absolute error is    1.69.
Up to batch 0, the average loss is    5.61.
Up to batch 1, the average loss is    6.26.
Up to batch 2, the average loss is    6.78.
Up to batch 3, the average loss is    7.74.
Up to batch 4, the average loss is    8.84.
The average loss for epoch  3 is      8.84 and mean absolute error is    2.49.
Restoring model weights from the end of the best epoch.
Epoch 00004: early stopping
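
The stopping rule itself can be traced without Keras. The sketch below replays the same best/wait bookkeeping on a plain vector of epoch losses. The function name stop_epoch is made up for this illustration; the loss values are the epoch losses printed above.

```r
# Replay the patience rule on a vector of epoch losses (plain R, no Keras).
# Returns the 0-based epoch at which training would stop, or NA if the rule
# never triggers.
stop_epoch <- function(losses, patience = 0) {
  best <- Inf
  wait <- 0
  for (i in seq_along(losses)) {
    if (losses[i] < best) {
      # Improvement: remember the new best and reset the counter.
      best <- losses[i]
      wait <- 0
    } else {
      # No improvement: count the epoch and stop once patience is used up.
      wait <- wait + 1
      if (wait >= patience) return(i - 1L)
    }
  }
  NA_integer_
}

# Epoch losses copied from the output above.
losses <- c(214.07, 6.06, 4.41, 8.84)

stop_epoch(losses, patience = 0)  # 3, reported above as "Epoch 00004"
stop_epoch(losses, patience = 2)  # NA: only one epoch without improvement
```

Because wait is incremented before it is compared, patience = 0 and patience = 1 behave identically under this rule: both stop at the first epoch that fails to improve on the best loss.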

Learning rate scheduling

In this example, we show how a custom Callback can be used to dynamically change the learning rate of the optimizer during the course of training.

See keras$callbacks$LearningRateScheduler for a more general implementation (in RStudio, press F1 while the cursor is over LearningRateScheduler and a browser will open to its documentation page).

CustomLearningRateScheduler(keras$callbacks$Callback) %py_class% {
  "Learning rate scheduler which sets the learning rate according to schedule.

  Arguments:
      schedule: a function that takes an epoch index
          (integer, indexed from 0) and current learning rate
          as inputs and returns a new learning rate as output (float).
  "

  `__init__` <- function(schedule) {
    super()$`__init__`()
    self$schedule <- schedule
  }

  on_epoch_begin <- function(epoch, logs = NULL) {
    ## When in doubt about what types of objects are in scope (e.g., self$model)
    ## use a debugger to interact with the actual objects at the console!
    # browser()

    if (!"learning_rate" %in% names(self$model$optimizer))
      stop('Optimizer must have a "learning_rate" attribute.')

    # Get the current learning rate from the model's optimizer.
    # Use as.numeric() to convert the tf.Variable to an R numeric.
    lr <- as.numeric(self$model$optimizer$learning_rate)
    # Call the schedule function to get the scheduled learning rate.
    scheduled_lr <- self$schedule(epoch, lr)
    # Set the value back on the optimizer before this epoch starts.
    self$model$optimizer$learning_rate <- scheduled_lr
    cat(sprintf("\nEpoch %05d: Learning rate is %6.4f.\n", epoch, scheduled_lr))
  }
}


LR_SCHEDULE <- tibble::tribble(~ start_epoch, ~ learning_rate,
                                           0,           0.1  ,
                                           3,           0.05 ,
                                           6,           0.01 ,
                                           9,           0.005,
                                          12,           0.001)


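lr_schedule <- function(epoch, learning_rate) {
  # Helper function to retrieve the scheduled learning rate based on epoch.
  # Note: which.min() picks the first row whose start_epoch is >= epoch, so a
  # row's rate applies up to and including its start_epoch (epoch 0 gets 0.1,
  # epochs 1 to 3 get 0.05, and so on), which matches the output below.
  if (epoch <= last(LR_SCHEDULE$start_epoch))
    with(LR_SCHEDULE, learning_rate[which.min(epoch > start_epoch)])
  else
    learning_rate
}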


model <- get_model()
model %>% fit(
  mnist$train$x,
  mnist$train$y,
  batch_size = 64,
  steps_per_epoch = 5,
  epochs = 15,
  verbose = 0,
  callbacks = list(
    LossAndErrorPrintingCallback(),
    CustomLearningRateScheduler(lr_schedule)
  )
)

Epoch 00000: Learning rate is 0.1000.
Up to batch 0, the average loss is   24.61.
Up to batch 1, the average loss is  380.05.
Up to batch 2, the average loss is  261.02.
Up to batch 3, the average loss is  199.01.
Up to batch 4, the average loss is  160.59.
The average loss for epoch  0 is    160.59 and mean absolute error is    7.77.

Epoch 00001: Learning rate is 0.0500.
Up to batch 0, the average loss is    5.31.
Up to batch 1, the average loss is    5.03.
Up to batch 2, the average loss is    5.19.
Up to batch 3, the average loss is    5.24.
Up to batch 4, the average loss is    5.07.
The average loss for epoch  1 is      5.07 and mean absolute error is    1.88.

Epoch 00002: Learning rate is 0.0500.
Up to batch 0, the average loss is    4.79.
Up to batch 1, the average loss is    4.51.
Up to batch 2, the average loss is    4.30.
Up to batch 3, the average loss is    4.37.
Up to batch 4, the average loss is    4.44.
The average loss for epoch  2 is      4.44 and mean absolute error is    1.69.

Epoch 00003: Learning rate is 0.0500.
Up to batch 0, the average loss is    4.66.
Up to batch 1, the average loss is    4.12.
Up to batch 2, the average loss is    3.99.
Up to batch 3, the average loss is    4.16.
Up to batch 4, the average loss is    4.36.
The average loss for epoch  3 is      4.36 and mean absolute error is    1.70.

Epoch 00004: Learning rate is 0.0100.
Up to batch 0, the average loss is    4.32.
Up to batch 1, the average loss is    4.12.
Up to batch 2, the average loss is    3.60.
Up to batch 3, the average loss is    3.41.
Up to batch 4, the average loss is    3.27.
The average loss for epoch  4 is      3.27 and mean absolute error is    1.48.

Epoch 00005: Learning rate is 0.0100.
Up to batch 0, the average loss is    3.74.
Up to batch 1, the average loss is    3.19.
Up to batch 2, the average loss is    3.21.
Up to batch 3, the average loss is    3.49.
Up to batch 4, the average loss is    3.46.
The average loss for epoch  5 is      3.46 and mean absolute error is    1.53.

Epoch 00006: Learning rate is 0.0100.
Up to batch 0, the average loss is    2.32.
Up to batch 1, the average loss is    3.63.
Up to batch 2, the average loss is    3.38.
Up to batch 3, the average loss is    3.38.
Up to batch 4, the average loss is    3.39.
The average loss for epoch  6 is      3.39 and mean absolute error is    1.46.

Epoch 00007: Learning rate is 0.0050.
Up to batch 0, the average loss is    3.59.
Up to batch 1, the average loss is    3.19.
Up to batch 2, the average loss is    3.32.
Up to batch 3, the average loss is    3.30.
Up to batch 4, the average loss is    3.26.
The average loss for epoch  7 is      3.26 and mean absolute error is    1.44.

Epoch 00008: Learning rate is 0.0050.
Up to batch 0, the average loss is    3.11.
Up to batch 1, the average loss is    3.13.
Up to batch 2, the average loss is    3.14.
Up to batch 3, the average loss is    3.12.
Up to batch 4, the average loss is    3.19.
The average loss for epoch  8 is      3.19 and mean absolute error is    1.42.

Epoch 00009: Learning rate is 0.0050.
Up to batch 0, the average loss is    3.27.
Up to batch 1, the average loss is    3.13.
Up to batch 2, the average loss is    3.09.
Up to batch 3, the average loss is    3.02.
Up to batch 4, the average loss is    3.02.
The average loss for epoch  9 is      3.02 and mean absolute error is    1.41.

Epoch 00010: Learning rate is 0.0010.
Up to batch 0, the average loss is    2.80.
Up to batch 1, the average loss is    2.99.
Up to batch 2, the average loss is    3.15.
Up to batch 3, the average loss is    3.25.
Up to batch 4, the average loss is    3.20.
The average loss for epoch 10 is      3.20 and mean absolute error is    1.41.

Epoch 00011: Learning rate is 0.0010.
Up to batch 0, the average loss is    2.45.
Up to batch 1, the average loss is    2.91.
Up to batch 2, the average loss is    2.73.
Up to batch 3, the average loss is    2.89.
Up to batch 4, the average loss is    2.97.
The average loss for epoch 11 is      2.97 and mean absolute error is    1.40.

Epoch 00012: Learning rate is 0.0010.
Up to batch 0, the average loss is    2.29.
Up to batch 1, the average loss is    2.17.
Up to batch 2, the average loss is    2.80.
Up to batch 3, the average loss is    2.93.
Up to batch 4, the average loss is    2.93.
The average loss for epoch 12 is      2.93 and mean absolute error is    1.34.

Epoch 00013: Learning rate is 0.0010.
Up to batch 0, the average loss is    3.36.
Up to batch 1, the average loss is    3.03.
Up to batch 2, the average loss is    3.18.
Up to batch 3, the average loss is    3.09.
Up to batch 4, the average loss is    2.87.
The average loss for epoch 13 is      2.87 and mean absolute error is    1.35.

Epoch 00014: Learning rate is 0.0010.
Up to batch 0, the average loss is    4.15.
Up to batch 1, the average loss is    3.32.
Up to batch 2, the average loss is    2.85.
Up to batch 3, the average loss is    2.68.
Up to batch 4, the average loss is    2.83.
The average loss for epoch 14 is      2.83 and mean absolute error is    1.32.
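
In the output above, each rate is applied up to and including the epoch listed in its start_epoch column. If you would rather have each rate take effect exactly at its start_epoch, one possible lookup uses base R's findInterval(). This is a sketch of an alternative, not a description of the lr_schedule() helper that produced the output above; the names schedule and step_lr are made up for this illustration.

```r
# Same schedule values as above, as a base R data frame.
schedule <- data.frame(
  start_epoch   = c(0, 3, 6, 9, 12),
  learning_rate = c(0.1, 0.05, 0.01, 0.005, 0.001)
)

# Hypothetical alternative lookup: use the rate of the last row whose
# start_epoch is <= epoch. `current_lr` is unused, but it is kept so the
# signature matches the (epoch, learning rate) call made by the scheduler
# callback. Assumes epoch >= 0.
step_lr <- function(epoch, current_lr = NULL) {
  schedule$learning_rate[findInterval(epoch, schedule$start_epoch)]
}

# Rates for epochs 0 to 14: each rate now holds for three epochs.
rates <- sapply(0:14, step_lr)
rates
```

A function with this signature could be passed to the custom scheduler in place of lr_schedule.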

Built-in Keras callbacks

Be sure to check out the existing Keras callbacks by reading the API docs. Applications include logging to CSV, saving the model, visualizing metrics in TensorBoard, and a lot more!
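
As a brief sketch of how built-in callbacks are combined, the example below passes callback_csv_logger() and callback_early_stopping() to fit() in one list. The toy data, the model, and the temporary file path are made up for this illustration.

```r
library(keras)

# Toy regression data, made up for this sketch.
x <- matrix(runif(200 * 4), ncol = 4)
y <- rowSums(x)

model <- keras_model_sequential() %>%
  layer_dense(1, input_shape = 4) %>%
  compile(optimizer = "sgd", loss = "mean_squared_error")

# Per-epoch metrics are written to this temporary CSV file.
log_file <- tempfile(fileext = ".csv")

history <- model %>% fit(
  x, y,
  epochs = 5,
  validation_split = 0.2,
  verbose = 0,
  callbacks = list(
    # Write the metrics of each epoch to a CSV file.
    callback_csv_logger(log_file),
    # Stop if val_loss has not improved for 2 consecutive epochs.
    callback_early_stopping(monitor = "val_loss", patience = 2)
  )
)

# One row per completed epoch.
logged <- read.csv(log_file)
head(logged)
```

Custom and built-in callbacks can be mixed freely in the same list.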

Environment Details

tensorflow::tf_config()
TensorFlow v2.11.0 (~/.virtualenvs/r-tensorflow-website/lib/python3.10/site-packages/tensorflow)
Python v3.10 (~/.virtualenvs/r-tensorflow-website/bin/python)
sessionInfo()
R version 4.2.1 (2022-06-23)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 20.04.5 LTS

Matrix products: default
BLAS:   /home/tomasz/opt/R-4.2.1/lib/R/lib/libRblas.so
LAPACK: /usr/lib/x86_64-linux-gnu/libmkl_intel_lp64.so

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
 [3] LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C            
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] keras_2.9.0.9000      tensorflow_2.9.0.9000

loaded via a namespace (and not attached):
 [1] Rcpp_1.0.9           pillar_1.8.1         compiler_4.2.1      
 [4] envir_0.2.2          base64enc_0.1-3      tools_4.2.1         
 [7] zeallot_0.1.0        digest_0.6.31        jsonlite_1.8.4      
[10] evaluate_0.18        lifecycle_1.0.3      tibble_3.1.8        
[13] lattice_0.20-45      pkgconfig_2.0.3      png_0.1-8           
[16] rlang_1.0.6          Matrix_1.5-3         DBI_1.1.3           
[19] cli_3.4.1            yaml_2.3.6           xfun_0.35           
[22] fastmap_1.1.0        dplyr_1.0.10         stringr_1.5.0       
[25] knitr_1.41           generics_0.1.3       vctrs_0.5.1         
[28] htmlwidgets_1.5.4    rprojroot_2.0.3      tidyselect_1.2.0    
[31] grid_4.2.1           here_1.0.1           reticulate_1.26-9000
[34] glue_1.6.2           R6_2.5.1             fansi_1.0.3         
[37] rmarkdown_2.18       magrittr_2.0.3       whisker_0.4.1       
[40] htmltools_0.5.4      tfruns_1.5.1         assertthat_0.2.1    
[43] utf8_1.2.2           stringi_1.7.8       
system2(reticulate::py_exe(), c("-m pip freeze"), stdout = TRUE) |> writeLines()
absl-py==1.3.0
asttokens==2.2.1
astunparse==1.6.3
backcall==0.2.0
cachetools==5.2.0
certifi==2022.12.7
charset-normalizer==2.1.1
decorator==5.1.1
dill==0.3.6
etils==0.9.0
executing==1.2.0
flatbuffers==22.12.6
gast==0.4.0
google-auth==2.15.0
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
googleapis-common-protos==1.57.0
grpcio==1.51.1
h5py==3.7.0
idna==3.4
importlib-resources==5.10.1
ipython==8.7.0
jedi==0.18.2
kaggle==1.5.12
keras==2.11.0
keras-tuner==1.1.3
kt-legacy==1.0.4
libclang==14.0.6
Markdown==3.4.1
MarkupSafe==2.1.1
matplotlib-inline==0.1.6
numpy==1.23.5
oauthlib==3.2.2
opt-einsum==3.3.0
packaging==22.0
pandas==1.5.2
parso==0.8.3
pexpect==4.8.0
pickleshare==0.7.5
Pillow==9.3.0
promise==2.3
prompt-toolkit==3.0.36
protobuf==3.19.6
ptyprocess==0.7.0
pure-eval==0.2.2
pyasn1==0.4.8
pyasn1-modules==0.2.8
pydot==1.4.2
Pygments==2.13.0
pyparsing==3.0.9
python-dateutil==2.8.2
python-slugify==7.0.0
pytz==2022.6
PyYAML==6.0
requests==2.28.1
requests-oauthlib==1.3.1
rsa==4.9
scipy==1.9.3
six==1.16.0
stack-data==0.6.2
tensorboard==2.11.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.11.0
tensorflow-datasets==4.7.0
tensorflow-estimator==2.11.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.28.0
tensorflow-metadata==1.12.0
termcolor==2.1.1
text-unidecode==1.3
toml==0.10.2
tqdm==4.64.1
traitlets==5.7.1
typing_extensions==4.4.0
urllib3==1.26.13
wcwidth==0.2.5
Werkzeug==2.2.2
wrapt==1.14.1
zipp==3.11.0
TF Devices:
-  PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU') 
-  PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU') 
CPU cores: 12 
Date rendered: 2022-12-16 
Page render time: 10 seconds