Tutorial: Save and Restore Models

    Model progress can be saved after as well as during training. This means a model can resume where it left off and avoid long training times. Saving also means you can share your model and others can recreate your work. When publishing research models and techniques, most machine learning practitioners share:

    • code to create the model, and
    • the trained weights, or parameters, for the model

    Sharing this data helps others understand how the model works and try it themselves with new data.

    Options

    There are many ways to save TensorFlow models, depending on the API you’re using. This guide uses Keras, a high-level API to build and train models in TensorFlow. For other approaches, see the TensorFlow Save and Restore guide or Saving in eager.

    Setup

    We’ll use the MNIST dataset to train our model to demonstrate saving weights. To speed up these demonstration runs, only use the first 1000 examples:
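    Below is a minimal sketch of the data preparation. It assumes the keras R package and its dataset_mnist loader. Flattening each image into a 784-value row and scaling pixels to [0, 1] are also assumptions, chosen to be consistent with the 784-input model summarized below. The variable names match those used later in this tutorial.

```r
library(keras)

mnist <- dataset_mnist()

# Keep only the first 1000 training and test examples to speed things up.
train_images <- mnist$train$x[1:1000, , ]
train_labels <- mnist$train$y[1:1000]
test_images  <- mnist$test$x[1:1000, , ]
test_labels  <- mnist$test$y[1:1000]

# Flatten each 28x28 image into a 784-long row and scale pixels to [0, 1].
train_images <- array_reshape(train_images, c(1000, 28 * 28)) / 255
test_images  <- array_reshape(test_images, c(1000, 28 * 28)) / 255
```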

    Define a model

    Let’s build a simple model we’ll use to demonstrate saving and loading weights.
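    A sketch of a create_model function that reproduces the summary below, assuming the keras R package. Only the layer sizes are fixed by that summary. The ReLU and softmax activations, the 0.2 dropout rate, the Adam optimizer and the sparse categorical cross-entropy loss (which suits integer labels) are all assumptions.

```r
library(keras)

# Returns a short sequential model: 784 -> 512 (ReLU) -> dropout -> 10 (softmax).
create_model <- function() {
  model <- keras_model_sequential() %>%
    layer_dense(units = 512, activation = "relu", input_shape = 784) %>%
    layer_dropout(rate = 0.2) %>%
    layer_dense(units = 10, activation = "softmax")
  # compile() modifies the model in place.
  model %>% compile(
    optimizer = "adam",
    loss = "sparse_categorical_crossentropy",
    metrics = "accuracy"
  )
  model
}

model <- create_model()
summary(model)
```

    The parameter counts follow from the layer sizes: 784 × 512 + 512 = 401,920 and 512 × 10 + 10 = 5,130.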

    ## Model: "sequential"
    ## ___________________________________________________________________________
    ## Layer (type)                     Output Shape                  Param #     
    ## ===========================================================================
    ## dense (Dense)                    (None, 512)                   401920      
    ## ___________________________________________________________________________
    ## dropout (Dropout)                (None, 512)                   0           
    ## ___________________________________________________________________________
    ## dense_1 (Dense)                  (None, 10)                    5130        
    ## ===========================================================================
    ## Total params: 407,050
    ## Trainable params: 407,050
    ## Non-trainable params: 0
    ## ___________________________________________________________________________

    Save the entire model

    Call save_model_tf or save_model_hdf5 to save a model’s architecture, weights, and training configuration in a single folder or file. This allows you to export a model so it can be used without access to the original code*. Since the optimizer state is also recovered, you can resume training from exactly where you left off.

    Saving a fully functional model is very useful. You can load it in TensorFlow.js (HDF5, SavedModel) and then train and run it in web browsers, or convert it to run on mobile devices using TensorFlow Lite (HDF5, SavedModel).

    *Custom objects (e.g. subclassed models or layers) require special attention when saving and loading. See the “Saving custom objects” section below.

    SavedModel format

    The SavedModel format is a way to serialize models. Models saved in this format can be restored using load_model_tf and are compatible with TensorFlow Serving. The SavedModel guide goes into detail about how to serve and inspect a SavedModel. The section below illustrates the steps to save and restore the model.

    model <- create_model()
    
    model %>% fit(train_images, train_labels, epochs = 5, verbose = 2)
    ## Train on 1000 samples
    ## Epoch 1/5
    ## 1000/1000 - 0s - loss: 1.1809 - accuracy: 0.6680
    ## Epoch 2/5
    ## 1000/1000 - 0s - loss: 0.4156 - accuracy: 0.8860
    ## Epoch 3/5
    ## 1000/1000 - 0s - loss: 0.2836 - accuracy: 0.9250
    ## Epoch 4/5
    ## 1000/1000 - 0s - loss: 0.2241 - accuracy: 0.9370
    ## Epoch 5/5
    ## 1000/1000 - 0s - loss: 0.1473 - accuracy: 0.9680

    Save the trained model with save_model_tf. The SavedModel format is a directory containing a protobuf binary and a TensorFlow checkpoint. Inspect the saved model directory:

    model %>% save_model_tf("model")
    list.files("model")
    ## [1] "assets"         "saved_model.pb" "variables"

    Reload a fresh Keras model from the saved model with new_model <- load_model_tf("model"), then print it with summary(new_model):

    ## Model: "sequential_1"
    ## ___________________________________________________________________________
    ## Layer (type)                     Output Shape                  Param #     
    ## ===========================================================================
    ## dense_2 (Dense)                  (None, 512)                   401920      
    ## ___________________________________________________________________________
    ## dropout_1 (Dropout)              (None, 512)                   0           
    ## ___________________________________________________________________________
    ## dense_3 (Dense)                  (None, 10)                    5130        
    ## ===========================================================================
    ## Total params: 407,050
    ## Trainable params: 407,050
    ## Non-trainable params: 0
    ## ___________________________________________________________________________
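    A summary only confirms that the architecture came back. The sketch below checks that the weights did too, assuming the keras R package. It builds a small stand-in model on random data so the block runs on its own. The names toy_model, x, y and the temporary path are hypothetical and do not touch the tutorial’s model or its model directory. The block saves the stand-in model with save_model_tf, reloads it with load_model_tf, and compares predictions.

```r
library(keras)

# Hypothetical stand-in data: 64 rows of 8 random features, 3 classes.
x <- matrix(runif(64 * 8), ncol = 8)
y <- sample(0:2, 64, replace = TRUE)

toy_model <- keras_model_sequential() %>%
  layer_dense(units = 16, activation = "relu", input_shape = 8) %>%
  layer_dense(units = 3, activation = "softmax")
toy_model %>% compile(optimizer = "adam", loss = "sparse_categorical_crossentropy")
toy_model %>% fit(x, y, epochs = 1, verbose = 0)

# Save to a temporary SavedModel directory and reload it.
saved_dir <- file.path(tempdir(), "toy_savedmodel")
toy_model %>% save_model_tf(saved_dir)
reloaded <- load_model_tf(saved_dir)

# The reloaded model should reproduce the original predictions.
same_preds <- isTRUE(all.equal(predict(toy_model, x), predict(reloaded, x),
                               tolerance = 1e-6))
```

    With the tutorial’s own objects, the equivalent check would compare predict(model, test_images) with predict(new_model, test_images).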

    HDF5 format

    Keras provides a basic saving format using the HDF5 standard.

    model <- create_model()
    
    model %>% fit(train_images, train_labels, epochs = 5, verbose = 2)
    ## Train on 1000 samples
    ## Epoch 1/5
    ## 1000/1000 - 0s - loss: 1.1386 - accuracy: 0.6780
    ## Epoch 2/5
    ## 1000/1000 - 0s - loss: 0.4326 - accuracy: 0.8770
    ## Epoch 3/5
    ## 1000/1000 - 0s - loss: 0.2874 - accuracy: 0.9310
    ## Epoch 4/5
    ## 1000/1000 - 0s - loss: 0.2164 - accuracy: 0.9460
    ## Epoch 5/5
    ## 1000/1000 - 0s - loss: 0.1536 - accuracy: 0.9690

    Save the trained model to a single HDF5 file, then recreate the model from that file:

    model %>% save_model_hdf5("my_model.h5")
    new_model <- load_model_hdf5("my_model.h5")
    summary(new_model)
    ## Model: "sequential_2"
    ## ___________________________________________________________________________
    ## Layer (type)                     Output Shape                  Param #     
    ## ===========================================================================
    ## dense_4 (Dense)                  (None, 512)                   401920      
    ## ___________________________________________________________________________
    ## dropout_2 (Dropout)              (None, 512)                   0           
    ## ___________________________________________________________________________
    ## dense_5 (Dense)                  (None, 10)                    5130        
    ## ===========================================================================
    ## Total params: 407,050
    ## Trainable params: 407,050
    ## Non-trainable params: 0
    ## ___________________________________________________________________________

    This technique saves everything:

    • The weight values
    • The model’s configuration (architecture)
    • The optimizer configuration

    Keras saves models by inspecting their architecture. Currently, it cannot save TensorFlow optimizers (from tf$train). When using those, you need to re-compile the model after loading, and the state of the optimizer is lost.
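    The load-then-recompile pattern can be sketched as follows, assuming the keras R package. A small stand-in model on random data (hypothetical names, so the block runs on its own) is saved with save_model_hdf5. It is then reloaded with load_model_hdf5 and compile = FALSE, which restores the architecture and weights but skips the saved training configuration. The model is then compiled again. A regular Keras Adam optimizer is used only to keep the block runnable; the point is the pattern.

```r
library(keras)

# Hypothetical stand-in data and model, so the block runs on its own.
x <- matrix(runif(64 * 8), ncol = 8)
y <- sample(0:2, 64, replace = TRUE)

toy_model <- keras_model_sequential() %>%
  layer_dense(units = 16, activation = "relu", input_shape = 8) %>%
  layer_dense(units = 3, activation = "softmax")
toy_model %>% compile(optimizer = "adam", loss = "sparse_categorical_crossentropy")
toy_model %>% fit(x, y, epochs = 1, verbose = 0)

h5_path <- file.path(tempdir(), "toy_model.h5")
toy_model %>% save_model_hdf5(h5_path)

# Restore architecture and weights only; the saved training config is skipped.
restored <- load_model_hdf5(h5_path, compile = FALSE)
same_preds <- isTRUE(all.equal(predict(toy_model, x), predict(restored, x),
                               tolerance = 1e-6))

# Compile again before further training. The new optimizer starts from
# scratch, so any accumulated optimizer state is not carried over.
restored %>% compile(optimizer = optimizer_adam(),
                     loss = "sparse_categorical_crossentropy")
restored %>% fit(x, y, epochs = 1, verbose = 0)
```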

    Saving custom objects

    If you are using the SavedModel format, you can skip this section. The key difference between HDF5 and SavedModel is that HDF5 uses object configs to save the model architecture, while SavedModel saves the execution graph.

    Thus, SavedModels are able to save custom objects like subclassed models and custom layers without requiring the original code.

    To save custom objects to HDF5, you must do the following:

    1. Define a get_config method in your object, and optionally a from_config classmethod.
      • get_config() returns a JSON-serializable dictionary of parameters needed to recreate the object.
      • from_config(config) uses the returned config from get_config to create a new object. By default, this function will use the config as initialization arguments.
    2. Pass the object to the custom_objects argument when loading the model. The argument must be a named list mapping the string class name to the class definition. For example: load_model_hdf5(path, custom_objects = list("CustomLayer" = CustomLayer))

    See the Writing layers and models from scratch tutorial for examples of custom_objects and get_config.

    Save checkpoints during training

    It is useful to automatically save checkpoints during and at the end of training. This way you can use a trained model without having to retrain it, or pick up training where you left off if the training process was interrupted.

    callback_model_checkpoint is a callback that performs this task.

    The callback takes several arguments to configure checkpointing. By default, save_weights_only is FALSE, which means the complete model is saved, including its architecture and configuration. You can then restore the model as outlined in the previous section.

    Here, let’s focus on saving and restoring only the weights. We set save_weights_only to TRUE, so we will need the model definition on restore.
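    A minimal sketch of such a weights-only callback, assuming the keras R package. The path checkpoints/cp.ckpt is taken from the training log later in this section. The name cp_callback is illustrative.

```r
library(keras)

# Path prefix for the weight files written by the callback.
checkpoint_path <- "checkpoints/cp.ckpt"

# Save only the weights (not the full model) at the end of every epoch.
cp_callback <- callback_model_checkpoint(
  filepath = checkpoint_path,
  save_weights_only = TRUE,
  verbose = 0
)

# To use it, pass it to fit(), e.g. callbacks = list(cp_callback).
```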

    Checkpoint callback usage

    Train the model, passing the checkpoint callback to fit through its callbacks argument and using the test set as validation data:

    ## Train on 1000 samples, validate on 1000 samples
    ## Epoch 1/10
    ## 1000/1000 - 0s - loss: 1.1775 - accuracy: 0.6750 - val_loss: 0.6874 - val_accuracy: 0.7980
    ## Epoch 2/10
    ## 1000/1000 - 0s - loss: 0.4144 - accuracy: 0.8810 - val_loss: 0.5366 - val_accuracy: 0.8320
    ## Epoch 3/10
    ## 1000/1000 - 0s - loss: 0.2811 - accuracy: 0.9280 - val_loss: 0.4517 - val_accuracy: 0.8610
    ## Epoch 4/10
    ## 1000/1000 - 0s - loss: 0.2205 - accuracy: 0.9430 - val_loss: 0.4692 - val_accuracy: 0.8500
    ## Epoch 5/10
    ## 1000/1000 - 0s - loss: 0.1520 - accuracy: 0.9690 - val_loss: 0.4084 - val_accuracy: 0.8660
    ## Epoch 6/10
    ## 1000/1000 - 0s - loss: 0.1147 - accuracy: 0.9780 - val_loss: 0.3946 - val_accuracy: 0.8680
    ## Epoch 7/10
    ## 1000/1000 - 0s - loss: 0.0831 - accuracy: 0.9870 - val_loss: 0.4008 - val_accuracy: 0.8710
    ## Epoch 8/10
    ## 1000/1000 - 0s - loss: 0.0607 - accuracy: 0.9970 - val_loss: 0.4056 - val_accuracy: 0.8640
    ## Epoch 9/10
    ## 1000/1000 - 0s - loss: 0.0510 - accuracy: 0.9970 - val_loss: 0.4031 - val_accuracy: 0.8720
    ## Epoch 10/10
    ## 1000/1000 - 0s - loss: 0.0465 - accuracy: 0.9960 - val_loss: 0.3923 - val_accuracy: 0.8710

    Inspect the files that were created:

    list.files(dirname(checkpoint_path))
    ## [1] "checkpoint"                  "cp.ckpt.data-00000-of-00001"
    ## [3] "cp.ckpt.index"

    When restoring a model from weights only, you must have a model with the same architecture as the original model. Since the architecture is the same, we can share weights even though it’s a different instance of the model.

    Rebuild a fresh, untrained model and evaluate it on the test set. An untrained model performs at chance level (~10% accuracy):

    fresh_model <- create_model()
    fresh_model %>% evaluate(test_images, test_labels, verbose = 0)
    ## $loss
    ## [1] 2.321936
    ## 
    ## $accuracy
    ## [1] 0.126

    Then load the weights from the latest checkpoint (epoch 10), and re-evaluate:

    fresh_model %>% load_model_weights_tf(filepath = checkpoint_path)
    fresh_model %>% evaluate(test_images, test_labels, verbose = 0)
    ## $loss
    ## [1] 0.3923183
    ## 
    ## $accuracy
    ## [1] 0.871

    Checkpoint callback options

    Alternatively, you can save only the best model by setting save_best_only = TRUE. By default, “best” is defined as the lowest validation loss. See the documentation for callback_model_checkpoint for further information.
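    A sketch of a best-only configuration, assuming the keras R package. With save_best_only = TRUE, the checkpoint is overwritten only when the monitored value improves. Setting verbose = 1 prints the per-epoch “improved” and “did not improve” messages seen in the log below. Because monitor = "val_loss" (the default, written out here) needs validation metrics, pass validation_data to fit along with callbacks = list(cp_callback).

```r
library(keras)

checkpoint_path <- "checkpoints/cp.ckpt"

# Overwrite the checkpoint only when val_loss reaches a new minimum.
cp_callback <- callback_model_checkpoint(
  filepath = checkpoint_path,
  monitor = "val_loss",
  save_best_only = TRUE,
  save_weights_only = TRUE,
  verbose = 1
)
```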

    ## Train on 1000 samples, validate on 1000 samples
    ## Epoch 1/10
    ## 
    ## Epoch 00001: val_loss improved from inf to 0.72178, saving model to checkpoints/cp.ckpt
    ## 1000/1000 - 0s - loss: 1.1691 - accuracy: 0.6620 - val_loss: 0.7218 - val_accuracy: 0.7760
    ## Epoch 2/10
    ## 
    ## Epoch 00002: val_loss improved from 0.72178 to 0.56689, saving model to checkpoints/cp.ckpt
    ## 1000/1000 - 0s - loss: 0.4227 - accuracy: 0.8850 - val_loss: 0.5669 - val_accuracy: 0.8110
    ## Epoch 3/10
    ## 
    ## Epoch 00003: val_loss improved from 0.56689 to 0.51581, saving model to checkpoints/cp.ckpt
    ## 1000/1000 - 0s - loss: 0.3018 - accuracy: 0.9160 - val_loss: 0.5158 - val_accuracy: 0.8380
    ## Epoch 4/10
    ## 
    ## Epoch 00004: val_loss improved from 0.51581 to 0.44739, saving model to checkpoints/cp.ckpt
    ## 1000/1000 - 0s - loss: 0.2120 - accuracy: 0.9480 - val_loss: 0.4474 - val_accuracy: 0.8540
    ## Epoch 5/10
    ## 
    ## Epoch 00005: val_loss did not improve from 0.44739
    ## 1000/1000 - 0s - loss: 0.1519 - accuracy: 0.9700 - val_loss: 0.4602 - val_accuracy: 0.8510
    ## Epoch 6/10
    ## 
    ## Epoch 00006: val_loss improved from 0.44739 to 0.42596, saving model to checkpoints/cp.ckpt
    ## 1000/1000 - 0s - loss: 0.1257 - accuracy: 0.9750 - val_loss: 0.4260 - val_accuracy: 0.8630
    ## Epoch 7/10
    ## 
    ## Epoch 00007: val_loss improved from 0.42596 to 0.40990, saving model to checkpoints/cp.ckpt
    ## 1000/1000 - 0s - loss: 0.0866 - accuracy: 0.9850 - val_loss: 0.4099 - val_accuracy: 0.8610
    ## Epoch 8/10
    ## 
    ## Epoch 00008: val_loss did not improve from 0.40990
    ## 1000/1000 - 0s - loss: 0.0688 - accuracy: 0.9930 - val_loss: 0.4210 - val_accuracy: 0.8560
    ## Epoch 9/10
    ## 
    ## Epoch 00009: val_loss did not improve from 0.40990
    ## 1000/1000 - 0s - loss: 0.0517 - accuracy: 0.9970 - val_loss: 0.4326 - val_accuracy: 0.8640
    ## Epoch 10/10
    ## 
    ## Epoch 00010: val_loss did not improve from 0.40990
    ## 1000/1000 - 0s - loss: 0.0386 - accuracy: 1.0000 - val_loss: 0.4521 - val_accuracy: 0.8510
    list.files(dirname(checkpoint_path))
    ## [1] "checkpoint"                  "cp.ckpt.data-00000-of-00001"
    ## [3] "cp.ckpt.index"

    What are these files?

    The above code stores the weights in a collection of checkpoint-formatted files that contain only the trained weights, in a binary format. Checkpoints contain:

    • One or more shards that contain your model’s weights.
    • An index file that indicates which weights are stored in which shard.

    If you are only training a model on a single machine, you’ll have one shard with the suffix .data-00000-of-00001.
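    The naming pattern can be illustrated in base R. checkpoint_files is a hypothetical helper, not a TensorFlow API. It lists the data shards and index file expected for a given checkpoint prefix, following the “.data-XXXXX-of-YYYYY” suffix seen in the listing above, where shards are numbered from zero.

```r
# Hypothetical helper: expected data-shard and index file names for a
# weights-only checkpoint with the given prefix.
checkpoint_files <- function(prefix, num_shards = 1L) {
  num_shards <- as.integer(num_shards)
  shard_ids <- seq_len(num_shards) - 1L   # shard indices start at 0
  c(sprintf("%s.data-%05d-of-%05d", prefix, shard_ids, num_shards),
    paste0(prefix, ".index"))
}

checkpoint_files("cp.ckpt")
## [1] "cp.ckpt.data-00000-of-00001" "cp.ckpt.index"
```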

    Manually save the weights

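    You saw how to load the weights into a model. Manually saving them is just as simple with the save_model_weights_tf function. Saving the model’s current weights, loading them into a fresh instance from create_model with load_model_weights_tf, and evaluating it on the test set gives: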

    ## $loss
    ## [1] 0.4520541
    ## 
    ## $accuracy
    ## [1] 0.851
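    The whole save-and-restore cycle for weights can be checked with a minimal, self-contained sketch, assuming the keras R package. This includes the requirement that the receiving model have the same architecture. build_toy_model is a hypothetical stand-in for create_model, small enough to run without the MNIST data. A second instance starts from its own random weights until the saved ones are loaded into it.

```r
library(keras)

# Hypothetical stand-in for create_model(): weights can only be restored
# into a model with matching layers and shapes.
build_toy_model <- function() {
  keras_model_sequential() %>%
    layer_dense(units = 16, activation = "relu", input_shape = 8) %>%
    layer_dense(units = 3, activation = "softmax")
}

x <- matrix(runif(20 * 8), ncol = 8)

original <- build_toy_model()
weights_path <- file.path(tempdir(), "toy_ckpt", "cp.ckpt")
dir.create(dirname(weights_path), recursive = TRUE, showWarnings = FALSE)

# Manually write the weights in TensorFlow checkpoint format.
original %>% save_model_weights_tf(weights_path)

# A second instance of the same architecture, then restore the saved weights.
restored <- build_toy_model()
restored %>% load_model_weights_tf(weights_path)

same_preds <- isTRUE(all.equal(predict(original, x), predict(restored, x),
                               tolerance = 1e-6))
```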