Distributed training with Keras


    The tf$distribute$Strategy API provides an abstraction for distributing your training across multiple processing units. The goal is to allow users to enable distributed training using existing models and training code, with minimal changes.

    This tutorial uses the tf$distribute$MirroredStrategy, which does in-graph replication with synchronous training on many GPUs on one machine. Essentially, it copies all of the model’s variables to each processor. Then, it uses all-reduce to combine the gradients from all processors and applies the combined value to all copies of the model.

    MirroredStrategy is one of several distribution strategies available in TensorFlow core. You can read about more strategies in the distribution strategy guide.

    Keras API

    This example uses the Keras API to build the model and training loop. For custom training loops, see the Custom training loops tutorial.

    library(tfds)  # used to load the MNIST dataset

    Download the dataset

    Download the MNIST dataset and load it using tfds. This returns tf.data datasets, which you can manipulate with the tfdatasets package.
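    Here is a minimal sketch of this step. It assumes that the tfds R package and its Python dependency (tensorflow-datasets) are installed. Calling `tfds_load()` without a `split` argument returns a named list with one dataset per split:

```r
library(tensorflow)
library(tfds)        # R interface to TensorFlow Datasets
library(tfdatasets)

# Download MNIST (only on first use) and load all of its splits;
# the result is a named list of tf.data datasets, one per split
mnist <- tfds_load("mnist")

mnist_train <- mnist$train
mnist_test <- mnist$test
```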

    Define distribution strategy

    Create a MirroredStrategy object. It handles the distribution and provides a context manager (strategy$scope()) inside which you build your model.
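    Here is a minimal sketch of creating the strategy. When no devices are passed, MirroredStrategy uses every GPU visible to TensorFlow, or the CPU if no GPU is found. `num_replicas_in_sync` reports how many model copies train in lockstep:

```r
library(tensorflow)

# With no devices listed, MirroredStrategy uses every GPU TensorFlow can see
# (or the CPU when there is no GPU)
strategy <- tf$distribute$MirroredStrategy()

# Number of model replicas that train in sync
cat("Number of devices:", strategy$num_replicas_in_sync, "\n")
```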

    Set up the input pipeline

    When training a model with multiple GPUs, you can use the extra computing power effectively by increasing the batch size. In general, use the largest batch size that fits the GPU memory, and tune the learning rate accordingly.
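    The batch-size arithmetic can be sketched in plain R. The helper names below are hypothetical. The linear learning-rate scaling is one common heuristic, not something this tutorial prescribes. In practice, `num_replicas` would come from `strategy$num_replicas_in_sync`:

```r
# Global batch size: every replica processes its own per-replica batch
# in each synchronous step
global_batch_size <- function(batch_size_per_replica, num_replicas) {
  batch_size_per_replica * num_replicas
}

# Linear scaling heuristic (an assumption, not a rule): grow the learning
# rate in proportion to the batch size relative to a reference setup
scaled_learning_rate <- function(base_lr, base_batch_size, batch_size) {
  base_lr * batch_size / base_batch_size
}

# Example: 64 examples per replica on 4 GPUs, starting from lr = 1e-3 at 64
batch_size <- global_batch_size(64, 4)             # 256
lr <- scaled_learning_rate(1e-3, 64, batch_size)   # 0.004
```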

    Pixel values, which range from 0 to 255, have to be normalized to the 0-1 range. The training data is then shuffled, and both the training and test datasets are batched. An in-memory cache of the training data is also kept to improve performance.
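    Here is one way to write this pipeline with tfdatasets, as a sketch. The buffer size and per-replica batch size are assumed values, and the function name `normalize` is illustrative. Only the training data is shuffled, because evaluation does not depend on example order:

```r
library(tensorflow)
library(tfds)
library(tfdatasets)

mnist <- tfds_load("mnist")
strategy <- tf$distribute$MirroredStrategy()

BUFFER_SIZE <- 10000L            # shuffle buffer size (assumed value)
BATCH_SIZE_PER_REPLICA <- 64L    # per-replica batch size (assumed value)
BATCH_SIZE <- BATCH_SIZE_PER_REPLICA * strategy$num_replicas_in_sync

# Cast the uint8 image to float32, rescale it from 0-255 to 0-1, and return
# an unnamed list so that Keras receives (image, label) pairs
normalize <- function(x) {
  list(tf$cast(x$image, tf$float32) / 255, x$label)
}

train_dataset <- mnist$train %>%
  dataset_map(normalize) %>%
  dataset_cache() %>%              # keep the normalized training data in memory
  dataset_shuffle(BUFFER_SIZE) %>%
  dataset_batch(BATCH_SIZE)

test_dataset <- mnist$test %>%
  dataset_map(normalize) %>%
  dataset_batch(BATCH_SIZE)
```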

    Create the model

    Create and compile the Keras model in the context of strategy$scope.
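    Here is a sketch of this step, assuming a small convolutional network. The layer sizes are illustrative, not prescribed by this tutorial. The model's variables are created inside `with(strategy$scope(), ...)`, so they are mirrored on every replica:

```r
library(tensorflow)
library(keras)

strategy <- tf$distribute$MirroredStrategy()

# Build and compile inside the scope so every variable is mirrored
with(strategy$scope(), {
  model <- keras_model_sequential() %>%
    layer_conv_2d(filters = 32, kernel_size = 3, activation = "relu",
                  input_shape = c(28, 28, 1)) %>%
    layer_max_pooling_2d() %>%
    layer_flatten() %>%
    layer_dense(units = 64, activation = "relu") %>%
    layer_dense(units = 10)          # logits for the 10 digit classes

  model %>% compile(
    loss = loss_sparse_categorical_crossentropy(from_logits = TRUE),
    optimizer = optimizer_adam(),
    metrics = "accuracy"
  )
})
```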

    Define the callbacks

    The callbacks used here are:

    • TensorBoard: This callback writes a log for TensorBoard, which allows you to visualize the graphs.
    • Model Checkpoint: This callback saves the model weights after every epoch.
    • Learning Rate Scheduler: Using this callback, you can schedule the learning rate to change at the beginning of every epoch.
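    The checkpoint callback expects a `checkpoint_prefix`. Here is a minimal sketch of defining it, assuming a hypothetical `./training_checkpoints` directory. Keras substitutes the epoch number for `{epoch}` when it writes each checkpoint:

```r
# Hypothetical directory for the weight checkpoints
checkpoint_dir <- "./training_checkpoints"

# Keras replaces {epoch} with the epoch number, so each epoch gets its own files
checkpoint_prefix <- file.path(checkpoint_dir, "ckpt_{epoch}")
```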

    For illustrative purposes, add a print callback to display the learning rate.

    callbacks <- list(
        # Write logs that TensorBoard can display
        callback_tensorboard(log_dir = '/tmp/logs'),
        # Save only the weights at the end of every epoch
        callback_model_checkpoint(filepath = checkpoint_prefix, save_weights_only = TRUE)
    )
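    The list above covers only TensorBoard and checkpointing. Here is a sketch of the learning-rate scheduler and the print callback. It assumes an illustrative step decay, where the rates and epoch boundaries are assumptions. It also assumes a hypothetical R6 class `PrintLR` built on the keras package's `KerasCallback`:

```r
library(keras)

# Illustrative step decay (assumed values). Keras passes a 0-based epoch
# index: 1e-3 for epochs 1-3, 1e-4 for epochs 4-7, 1e-5 afterwards
decay <- function(epoch, lr) {
  if (epoch < 3) 1e-3
  else if (epoch < 7) 1e-4
  else 1e-5
}

# Hypothetical callback that prints the optimizer's learning rate at the end
# of every epoch; KerasCallback provides self$model during training
PrintLR <- R6::R6Class("PrintLR",
  inherit = KerasCallback,
  public = list(
    on_epoch_end = function(epoch, logs = list()) {
      lr <- k_get_value(self$model$optimizer$learning_rate)
      cat(sprintf("\nLearning rate for epoch %d is %g\n", epoch + 1L, lr))
    }
  )
)

extra_callbacks <- list(
  callback_learning_rate_scheduler(decay),
  PrintLR$new()
)
```

    Append them before training with `callbacks <- c(callbacks, extra_callbacks)`.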

    Train and evaluate

    Now, train the model in the usual way, calling fit on the model and passing in the dataset created at the beginning of the tutorial. This step is the same whether you are distributing the training or not.

    model %>% fit(train_dataset, epochs = 12, callbacks = callbacks)

    The checkpoint callback saves the weights to the checkpoint directory at the end of every epoch.
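    To check that checkpoints are being written, list the checkpoint directory. Here is a minimal sketch, assuming `checkpoint_dir` points at a hypothetical `./training_checkpoints` directory:

```r
# Assumed to match the directory used for callback_model_checkpoint()
checkpoint_dir <- "./training_checkpoints"

# Files written by the checkpoint callback; character(0) if nothing was saved yet
checkpoint_files <- list.files(checkpoint_dir)
print(checkpoint_files)
```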


    To see how the model performs, load the latest checkpoint and call evaluate on the test data.

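    # Restore the weights from the most recent checkpoint and evaluate on the test set
    model %>% load_model_weights_tf(tf$train$latest_checkpoint(checkpoint_dir))
    model %>% evaluate(test_dataset)
    # Open TensorBoard on the logs written during training
    tensorboard(log_dir = "/tmp/logs")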

    Export to SavedModel

    Export the graph and the variables to the platform-agnostic SavedModel format. After your model is saved, you can load it with or without the scope.

    Load the model without strategy$scope.

    Load the model with strategy$scope.
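    The export and both loading paths can be sketched end to end. To keep the snippet self-contained, a tiny dense model stands in for the trained MNIST model, and the export path under `tempdir()` is an assumption. `save_model_tf()` and `load_model_tf()` are the keras package's SavedModel functions:

```r
library(tensorflow)
library(keras)

strategy <- tf$distribute$MirroredStrategy()
path <- file.path(tempdir(), "saved_model")   # hypothetical export location

# Stand-in for the trained model: a tiny model built under the strategy
with(strategy$scope(), {
  model <- keras_model_sequential() %>%
    layer_dense(units = 10, input_shape = c(4))
  model %>% compile(loss = "mse", optimizer = "adam")
})

# Export the graph and variables in the SavedModel format
save_model_tf(model, path)

# Load without the strategy scope: an ordinary, non-distributed model
unreplicated_model <- load_model_tf(path)

# Load inside the scope: the variables are mirrored across replicas again
with(strategy$scope(), {
  replicated_model <- load_model_tf(path)
})

# Both copies hold the same weights, so they should predict the same values
x <- matrix(runif(8), nrow = 2)
p1 <- predict(unreplicated_model, x)
p2 <- predict(replicated_model, x)
```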

    After loading, call evaluate as before on the appropriate dataset.

    Examples and Tutorials

    Here are some examples of using distribution strategies with Keras fit/compile:

    1. Transformer example trained using tf.distribute.MirroredStrategy
    2. NCF example trained using tf.distribute.MirroredStrategy