Distributed training with Keras


    The tf$distribute$Strategy API provides an abstraction for distributing your training across multiple processing units. The goal is to allow users to enable distributed training using existing models and training code, with minimal changes.

    This tutorial uses the tf$distribute$MirroredStrategy, which does in-graph replication with synchronous training on many GPUs on one machine. Essentially, it copies all of the model’s variables to each processor. Then, it uses all-reduce to combine the gradients from all processors and applies the combined value to all copies of the model.
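
    The replicate, all-reduce, and update cycle described above can be sketched in plain R. This is a toy simulation under stated assumptions, not what MirroredStrategy runs internally: the replicas are list elements, the model is a hypothetical linear regression, and the all-reduce is assumed to be a simple average of the per-replica gradients. The names mse_grad, replica_weights and combined_grad are illustrative.

```r
# Toy simulation of one synchronous data-parallel training step in base R.
# Not TensorFlow code: replicas are list elements and "all-reduce" is a plain average.
set.seed(1)
num_replicas <- 2
x <- matrix(rnorm(8 * 3), nrow = 8)   # global batch: 8 examples, 3 features
y <- rnorm(8)

# Gradient of 0.5 * mean((x %*% w - y)^2) with respect to w.
mse_grad <- function(w, x, y) {
  as.vector(crossprod(x, x %*% w - y)) / nrow(x)
}

# Every replica starts with an identical copy of the variables.
replica_weights <- replicate(num_replicas, rep(0, 3), simplify = FALSE)

# Split the global batch into equally sized shards, one per replica.
shards <- split(seq_len(nrow(x)),
                rep(seq_len(num_replicas), each = nrow(x) / num_replicas))

# Each replica computes a gradient on its own shard.
local_grads <- lapply(seq_len(num_replicas), function(i) {
  idx <- shards[[i]]
  mse_grad(replica_weights[[i]], x[idx, , drop = FALSE], y[idx])
})

# "All-reduce": combine the per-replica gradients (assumed here to be an average).
combined_grad <- Reduce(`+`, local_grads) / num_replicas

# Apply the same combined gradient to every copy so the replicas stay in sync.
learning_rate <- 0.1
replica_weights <- lapply(replica_weights,
                          function(w) w - learning_rate * combined_grad)
```

    Because the shards are equally sized, the averaged gradient equals the gradient computed on the full batch, and every copy of the weights is still identical after the update.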

    MirroredStrategy is one of several distribution strategies available in TensorFlow core.

    Keras API

    This example uses the Keras API to build the model and the training loop.

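    library(tensorflow)
    library(keras)
    library(tfdatasets)
    # used to load the MNIST dataset
    library(tfds)
    # used for string interpolation in the print callback
    library(glue)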

    Download the dataset

    Download the MNIST dataset and load it using tfds. This returns a dataset in tfdatasets format.

    # if you haven't done it yet:
    # tfds::install_tfds()
    mnist <- tfds_load("mnist")
    info <- summary(mnist)

    Define distribution strategy

    Create a MirroredStrategy object. It handles distribution and provides a context manager (tf$distribute$MirroredStrategy$scope) inside which you build your model.

    strategy <- tf$distribute$MirroredStrategy()

    Set up the input pipeline

    When training a model with multiple GPUs, you can use the extra computing power effectively by increasing the batch size. In general, use the largest batch size that fits the GPU memory, and tune the learning rate accordingly.

    num_train_examples <- as.integer(info$splits[[2]]$statistics$numExamples)
    num_test_examples <- as.integer(info$splits[[1]]$statistics$numExamples)
    BUFFER_SIZE <- 10000
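    # Per-replica batch size; 64 is an illustrative value, tune it to fit GPU memory
    BATCH_SIZE_PER_REPLICA <- 64
    # The global batch is split evenly across the replicas
    BATCH_SIZE <- BATCH_SIZE_PER_REPLICA * strategy$num_replicas_in_sync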
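
    As a rough illustration of how the global batch size grows with the number of replicas, the arithmetic can be sketched in plain R. The replica count of 4 is hypothetical (in the tutorial it comes from strategy$num_replicas_in_sync), the per-replica batch size of 64 is an example value, 60000 is the size of the MNIST training split, and the linear learning-rate scaling is a common heuristic rather than something this tutorial prescribes.

```r
# Hypothetical setup: 4 GPUs on one machine.
num_replicas <- 4
batch_size_per_replica <- 64

# Each training step consumes one global batch, split evenly across replicas.
global_batch_size <- batch_size_per_replica * num_replicas

# MNIST has 60000 training examples; the last batch of an epoch may be partial.
num_train_examples <- 60000
steps_per_epoch <- ceiling(num_train_examples / global_batch_size)

# Heuristic starting point (an assumption): scale the learning rate with the replica count.
base_learning_rate <- 1e-3
scaled_learning_rate <- base_learning_rate * num_replicas

cat("global batch size:", global_batch_size, "\n")
cat("steps per epoch:  ", steps_per_epoch, "\n")
cat("scaled LR:        ", scaled_learning_rate, "\n")
```

    With these numbers the global batch holds 256 examples and an epoch takes 235 steps. Treat the scaled learning rate only as a starting point for tuning.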

    Pixel values, which range from 0 to 255, must be normalized to the 0-1 range. The training data is also shuffled, and both the training and test datasets are batched. Note that the training data is cached in memory to improve performance.

    # Scale pixel values to [0, 1]; shared by the training and test pipelines
    normalize <- function(record) {
      record$image <- tf$cast(record$image, tf$float32) / 255
      record
    }

    train_dataset <- mnist$train %>%
      dataset_map(normalize) %>%
      dataset_cache() %>%
      dataset_shuffle(BUFFER_SIZE) %>%
      dataset_batch(BATCH_SIZE) %>%
      # Drop the field names so each record becomes an (image, label) tuple for fit()
      dataset_map(unname)

    test_dataset <- mnist$test %>%
      dataset_map(normalize) %>%
      dataset_batch(BATCH_SIZE) %>%
      dataset_map(unname)

    Create the model

    Create and compile the Keras model in the context of strategy$scope.

    with (strategy$scope(), {
      model <- keras_model_sequential() %>%
        layer_conv_2d(
          filters = 32,
          kernel_size = 3,
          activation = 'relu',
          input_shape = c(28, 28, 1)
        ) %>%
        layer_max_pooling_2d() %>%
        layer_flatten() %>%
        layer_dense(units = 64, activation = 'relu') %>%
        layer_dense(units = 10, activation = 'softmax')

      model %>% compile(
        loss = 'sparse_categorical_crossentropy',
        optimizer = 'adam',
        metrics = 'accuracy')
    })
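
    Since MirroredStrategy keeps a full copy of the model's variables on every replica, it can help to know how many variables that is. The following plain-R arithmetic is a sketch that assumes the layer defaults (valid padding and stride 1 for the convolution, a 2x2 pooling window) and the 28x28x1 input used above; the variable names are illustrative.

```r
# Parameter count of the model above, derived by hand.
# Convolution: 3x3 kernel, 1 input channel, plus a bias, for each of 32 filters.
conv_params <- (3 * 3 * 1 + 1) * 32

# Output shapes: a valid 3x3 convolution shrinks 28 to 26; 2x2 pooling halves it to 13.
conv_out <- 28 - 3 + 1
pool_out <- conv_out %/% 2
flat_units <- pool_out * pool_out * 32

# Dense layers: one weight per input unit plus a bias, for each output unit.
dense1_params <- (flat_units + 1) * 64
dense2_params <- (64 + 1) * 10

total_params <- conv_params + dense1_params + dense2_params
cat("variables mirrored on each replica:", total_params, "\n")
```

    Under these assumptions the total should agree with the parameter count reported by summary(model).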

    Define the callbacks

    The callbacks used here are:

    • TensorBoard: This callback writes a log for TensorBoard which allows you to visualize the graphs.
    • Model Checkpoint: This callback saves the model after every epoch.
    • Learning Rate Scheduler: Using this callback, you can schedule the learning rate to change after every epoch/batch.

    For illustrative purposes, add a print callback to display the learning rate.

    # Define the checkpoint directory to store the checkpoints
    checkpoint_dir <- './training_checkpoints'
    # Name of the checkpoint files
    checkpoint_prefix <- file.path(checkpoint_dir, "ckpt_{epoch}")
    # Function for decaying the learning rate.
    # You can define any decay function you need.
    decay <- function(epoch, lr) {
      if (epoch < 3) {
        1e-3
      } else if (epoch < 7) {
        1e-4
      } else {
        1e-5
      }
    }
    # Callback for printing the LR at the end of each epoch.
    PrintLR <- R6::R6Class("PrintLR",
      inherit = KerasCallback,
      public = list(
        losses = NULL,
        on_epoch_end = function(epoch, logs = list()) {
          tf$print(glue('\nLearning rate for epoch {epoch} is {as.numeric(model$optimizer$lr)}\n'))
        }
      )
    )
    print_lr <- PrintLR$new()
    callbacks <- list(
        callback_tensorboard(log_dir = '/tmp/logs'),
        callback_model_checkpoint(filepath = checkpoint_prefix, save_weights_only = TRUE),
        callback_learning_rate_scheduler(decay),
        print_lr
    )
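
    To check what the schedule does before training, the decay function can be evaluated on its own. This sketch redefines decay so that it runs standalone, and it assumes that Keras passes zero-based epoch indices to the scheduler, so 0:11 covers a 12-epoch run.

```r
# Same piecewise-constant schedule as above; the current lr argument is unused.
decay <- function(epoch, lr) {
  if (epoch < 3) {
    1e-3
  } else if (epoch < 7) {
    1e-4
  } else {
    1e-5
  }
}

# Evaluate the schedule for every epoch of a 12-epoch run.
epochs <- 0:11
schedule <- sapply(epochs, decay, lr = NA)

# One row per epoch with the learning rate that would be applied.
print(data.frame(epoch = epochs, learning_rate = schedule))
```

    The first three epochs use 1e-3, the next four use 1e-4, and the remaining five use 1e-5.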

    Train and evaluate

    Now, train the model in the usual way, calling fit on the model and passing in the dataset created at the beginning of the tutorial. This step is the same whether you are distributing the training or not.

    model %>% fit(train_dataset, epochs = 12, callbacks = callbacks)

    To confirm that the checkpoints are being saved, list the contents of the checkpoint directory.

    list.files(checkpoint_dir)

    To see how the model performs, load the latest checkpoint and call evaluate on the test data.

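    # Restore the weights from the most recent checkpoint
    model %>% load_model_weights_tf(tf$train$latest_checkpoint(checkpoint_dir))
    model %>% evaluate(test_dataset)
    # Launch TensorBoard to inspect the training logs
    tensorboard(log_dir = "/tmp/logs")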

    Export to SavedModel

    Export the graph and the variables to the platform-agnostic SavedModel format. After your model is saved, you can load it with or without the scope.

    path <- 'saved_model/'
    model %>% save_model_tf(path)

    Load the model without strategy$scope.

    unreplicated_model <- load_model_tf(path)
    unreplicated_model %>% compile(
        loss = 'sparse_categorical_crossentropy',
        optimizer = 'adam',
        metrics = 'accuracy')
    unreplicated_model %>% evaluate(test_dataset)

    Load the model with strategy$scope.

    with (strategy$scope(), {
      replicated_model <- load_model_tf(path)
      replicated_model %>% compile(
        loss = 'sparse_categorical_crossentropy',
        optimizer = 'adam',
        metrics = 'accuracy')
      replicated_model %>% evaluate(test_dataset)
    })