Serialization and saving
Introduction
A Keras model consists of multiple components:
- The architecture, or configuration, which specifies what layers the model contains, and how they’re connected.
- A set of weights values (the “state of the model”).
- An optimizer (defined by compiling the model).
- A set of losses and metrics (defined by compiling the model or calling add_loss() or add_metric()).
The Keras API makes it possible to save all of these pieces to disk at once, or to only selectively save some of them:
- Saving everything into a single archive in the TensorFlow SavedModel format (or in the older Keras H5 format). This is the standard practice.
- Saving the architecture / configuration only, typically as a JSON file.
- Saving the weights values only. This is generally used when training the model.
Let’s take a look at each of these options. When would you use one or the other, and how do they work?
How to save and load a model
If you only have 10 seconds to read this guide, here’s what you need to know.
Saving a Keras model:
model <- ...  # Get model (Sequential, Functional Model, or Model subclass)
save_model_tf(model, "path/to/location")
Loading the model back:
library(keras)
model <- load_model_tf("path/to/location")
Now, let’s look at the details.
Setup
library(tensorflow)
library(keras)
Whole-model saving & loading
You can save an entire model to a single artifact. It will include:
- The model’s architecture/config
- The model’s weight values (which were learned during training)
- The model’s compilation information (if compile() was called)
- The optimizer and its state, if any (this enables you to restart training where you left off)
APIs
- model$save() or save_model_tf()
- load_model_tf()
There are two formats you can use to save an entire model to disk: the TensorFlow SavedModel format, and the older Keras H5 format. The recommended format is SavedModel. It is the default when you use model$save().
You can switch to the H5 format by:
- Passing save_format = 'h5' to model$save(), or calling save_model_hdf5() instead of save_model_tf().
- Passing a filename that ends in .h5 or .keras to model$save().
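The sketch below saves one small model in both formats to make the difference concrete. It assumes a working keras/TensorFlow installation similar to the one listed under Environment Details; the model and file names are made up for illustration, and everything is written to a temporary directory.

```r
library(keras)

# A tiny throwaway model; any built model would do.
model <- keras_model_sequential(input_shape = 4) %>%
  layer_dense(units = 1)

out_dir <- tempfile("formats_")
dir.create(out_dir)

h5_file_1 <- file.path(out_dir, "via_wrapper.h5")
h5_file_2 <- file.path(out_dir, "via_method.h5")
tf_dir    <- file.path(out_dir, "via_savedmodel")

# H5 via the dedicated R wrapper.
save_model_hdf5(model, h5_file_1)
# H5 via the Python save() method with an explicit format.
model$save(h5_file_2, save_format = "h5")
# SavedModel, the default format, written as a directory.
save_model_tf(model, tf_dir)

# H5 gives a single file; SavedModel gives a directory.
c(h5_wrapper = !dir.exists(h5_file_1) && file.exists(h5_file_1),
  h5_method  = !dir.exists(h5_file_2) && file.exists(h5_file_2),
  savedmodel = dir.exists(tf_dir))
```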
SavedModel format
SavedModel is the more comprehensive save format that saves the model architecture, weights, and the traced TensorFlow subgraphs of the call functions. This enables Keras to restore both built-in layers and custom objects.
Example:
get_model <- function() {
  # Create a simple model.
  inputs <- layer_input(shape = shape(32))
  outputs <- layer_dense(inputs, 1)
  model <- keras_model(inputs, outputs)
  model %>% compile(optimizer = "adam", loss = "mean_squared_error")
  model
}
model <- get_model()
# Train the model.
test_input <- array(runif(128*32), dim = c(128, 32))
test_target <- array(runif(128), dim = c(128, 1))
model %>% fit(test_input, test_target)
# Calling `save_model_tf(model, 'my_model')` creates a SavedModel folder `my_model`.
save_model_tf(model, "my_model")
# It can be used to reconstruct the model identically.
reconstructed_model <- load_model_tf("my_model")
# Let's check:
all.equal(
  predict(model, test_input),
  predict(reconstructed_model, test_input)
)
[1] TRUE
# The reconstructed model is already compiled and has retained the optimizer
# state, so training can resume:
reconstructed_model %>% fit(test_input, test_target)
What the SavedModel contains
Calling save_model_tf(model, 'my_model') creates a folder named my_model, containing the following:
ls my_model
assets
checkpoint
fingerprint.pb
keras_metadata.pb
saved_model.pb
variables
The model architecture and training configuration (including the optimizer, losses, and metrics) are stored in saved_model.pb. The weights are saved in the variables/ directory.
For detailed information on the SavedModel format, see the SavedModel guide (The SavedModel format on disk).
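To check the layout from R rather than a shell, you can list the directory after saving. A minimal sketch, assuming the same TensorFlow generation as this guide (the exact set of top-level files, such as fingerprint.pb, may vary between releases); the directory name is hypothetical:

```r
library(keras)

model <- keras_model_sequential(input_shape = 8) %>%
  layer_dense(units = 2)

sm_dir <- file.path(tempdir(), "my_model_listing")
save_model_tf(model, sm_dir)

# Top-level entries: expect saved_model.pb, keras_metadata.pb, variables, ...
top_level <- list.files(sm_dir)
print(top_level)

# The weights are stored under variables/.
weight_files <- list.files(file.path(sm_dir, "variables"))
print(weight_files)
```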
How SavedModel handles custom objects
When saving the model and its layers, the SavedModel format stores the class name, call function, losses, and weights (and the config, if implemented). The call function defines the computation graph of the model/layer.
In the absence of the model/layer config, the call function is used to create a model that behaves like the original model and can be trained, evaluated, and used for inference.
Nevertheless, it is always good practice to define the get_config and from_config methods when writing a custom model or layer class. This allows you to easily update the computation later if needed. See the section about Custom objects for more information.
The default from_config definition in R simply calls the initialize method with the config list, using do.call(). That way, you don’t need to implement a from_config method unless the names in the list returned by get_config() don’t match the initialize arguments.
Example:
custom_model <- new_model_class(
  "custom_model",
  initialize = function(hidden_units) {
    super()$`__init__`()
    self$hidden_units <- hidden_units
    self$dense_layers <- lapply(hidden_units, function(x) layer_dense(units = x))
  },
  call = function(inputs) {
    x <- inputs
    for (layer in self$dense_layers) {
      x <- layer(x)
    }
    x
  },
  get_config = function() {
    list(hidden_units = self$hidden_units)
  }
)
model <- custom_model(c(16, 16, 10))
# Build the model by calling it
input_arr <- tf$random$uniform(shape(1, 5))
outputs <- model(input_arr)
save_model_tf(model, "my_model")
# Option 1: Load with the custom_objects argument.
loaded_1 <- load_model_tf(
  "my_model", custom_objects = list("custom_model" = custom_model)
)
# Option 2: Load without the custom_model class.
# Delete the custom-defined model class to ensure that the loader does not have
# access to it.
rm(custom_model); gc();
used (Mb) gc trigger (Mb) max used (Mb)
Ncells 2065304 110.3 4222005 225.5 2548341 136.1
Vcells 3633832 27.8 8388608 64.0 6282830 48.0
loaded_2 <- load_model_tf("my_model")
all.equal(predict(loaded_1, input_arr), as.array(outputs))
[1] TRUE
all.equal(predict(loaded_2, input_arr), as.array(outputs))
[1] TRUE
The first loaded model is loaded using the config and custom_model
class. The second model is loaded by dynamically creating the model class that acts like the original model.
Configuring the SavedModel
New in TensorFlow 2.4
The argument save_traces has been added to model$save, which allows you to toggle SavedModel function tracing. Functions are saved to allow Keras to re-load custom objects without the original class definitions, so when save_traces = FALSE, all custom objects must have defined get_config/from_config methods. When loading, the custom objects must be passed to the custom_objects argument. save_traces = FALSE reduces the disk space used by the SavedModel and the saving time.
Keras H5 format
Keras also supports saving a single HDF5 file containing the model’s architecture, weights values, and compile() information. It is a light-weight alternative to SavedModel.
Example:
model <- get_model()
# Train the model.
test_input <- array(runif(128*32), dim = c(128, 32))
test_target <- array(runif(128), dim = c(128, 1))
model %>% fit(test_input, test_target)
# Calling `save_model_hdf5(model, 'my_h5_model.h5')` creates an h5 file `my_h5_model.h5`.
save_model_hdf5(model, "my_h5_model.h5")
# It can be used to reconstruct the model identically.
reconstructed_model <- load_model_hdf5("my_h5_model.h5")
# Let's check:
all.equal(
  predict(model, test_input),
  predict(reconstructed_model, test_input)
)
[1] TRUE
# The reconstructed model is already compiled and has retained the optimizer
# state, so training can resume:
reconstructed_model %>% fit(test_input, test_target)
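If you need a model as an ordinary R value (for example, to keep it in an .rds file with other results), the keras R package provides serialize_model() and unserialize_model(). As far as we know they go through the HDF5 format internally, so the H5 limitations listed under Format Limitations apply; check ?serialize_model for your installed version. A minimal sketch with a hypothetical model:

```r
library(keras)

model <- keras_model_sequential(input_shape = 10) %>%
  layer_dense(units = 4, activation = "relu") %>%
  layer_dense(units = 1)
model %>% compile(optimizer = "adam", loss = "mean_squared_error")

# Convert the model into a raw vector, an ordinary R object.
raw_model <- serialize_model(model, include_optimizer = TRUE)

# That raw vector can be stored with saveRDS() like any other R value.
rds_file <- tempfile(fileext = ".rds")
saveRDS(raw_model, rds_file)

# Later (possibly in a new R session): rebuild the Keras model from it.
restored <- unserialize_model(readRDS(rds_file))

x <- matrix(runif(5 * 10), nrow = 5)
all.equal(predict(model, x), predict(restored, x))
```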
Format Limitations
Keras SavedModel format limitations:
The tracing done by SavedModel to produce the graphs of the layer call functions allows SavedModel to be more portable than H5, but it comes with drawbacks:
- It can be slower and bulkier than H5.
- It cannot serialize the ops generated from the mask argument (i.e. if a layer is called with layer(..., mask = mask_value), the mask argument is not saved to SavedModel).
- It does not save the overridden train_step() in subclassed models.
Custom objects that use masks or have a custom training loop can still be saved and loaded from SavedModel, except they must override get_config()/from_config(), and the classes must be passed to the custom_objects argument when loading.
H5 limitations:
- External losses & metrics added via model$add_loss() & model$add_metric() are not saved (unlike SavedModel). If you have such losses & metrics on your model and you want to resume training, you need to add these losses back yourself after loading the model. Note that this does not apply to losses/metrics created inside layers via self$add_loss() & self$add_metric(). As long as the layer gets loaded, these losses & metrics are kept, since they are part of the call method of the layer.
- The computation graph of custom objects such as custom layers is not included in the saved file. At loading time, Keras will need access to the classes/functions that define these objects in order to reconstruct the model. See Custom objects.
- Does not support preprocessing layers.
Saving the architecture
The model’s configuration (or architecture) specifies what layers the model contains, and how these layers are connected*. If you have the configuration of a model, then the model can be created with a freshly initialized state for the weights and no compilation information.
*Note: this only applies to models defined using the Functional or Sequential APIs, not to subclassed models.
Configuration of a Sequential model or Functional API model
These types of models are explicit graphs of layers: their configuration is always available in a structured form.
APIs
- get_config() and from_config()
- model_to_json() and model_from_json()
get_config() and from_config()
Calling config <- model$get_config() will return a Python dict (converted to a named R list) containing the configuration of the model. The same model can then be reconstructed via Sequential$from_config(config) (for a Sequential model) or Model$from_config(config) (for a Functional API model). The R functions get_config() and from_config(), used in the examples below, wrap these methods.
The same workflow also works for any serializable layer.
Layer example:
layer <- layer_dense(units = 3, activation = "relu")
layer_config <- get_config(layer)
new_layer <- from_config(config = layer_config)
Sequential model example:
model <- keras_model_sequential(list(
  layer_input(shape = 32),
  layer_dense(units = 1)
))
config <- get_config(model)
new_model <- from_config(config)
Functional model example:
inputs <- layer_input(shape = 32)
outputs <- layer_dense(inputs, 1)
model <- keras_model(inputs, outputs)
config <- get_config(model)
new_model <- from_config(config)
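Because the config is returned as a named R list, you can inspect it before rebuilding. The sketch below, which assumes only built-in layers and uses hypothetical layer names, checks that from_config() reproduces the architecture (same weight shapes) while the weights themselves are freshly initialized, since they are not part of the config.

```r
library(keras)

model <- keras_model_sequential(input_shape = 32, name = "config_demo") %>%
  layer_dense(units = 4, name = "hidden") %>%
  layer_dense(units = 1, name = "head")

config <- get_config(model)
# Top-level entries, e.g. the model name and the list of layer configs.
print(names(config))

# Rebuild the architecture from the config alone.
rebuilt <- from_config(config)

# Same weight shapes...
same_shapes <- identical(
  lapply(get_weights(model), dim),
  lapply(get_weights(rebuilt), dim)
)
# ...but different values, because weights are not part of the config.
same_values <- isTRUE(all.equal(get_weights(model), get_weights(rebuilt)))
c(same_shapes = same_shapes, same_values = same_values)
```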
model_to_json() and model_from_json()
This is similar to get_config / from_config, except it turns the model into a JSON string, which can then be loaded without the original model class. It is also specific to models; it isn’t meant for layers.
Example:
model <- keras_model_sequential(list(
  layer_input(shape = 32),
  layer_dense(units = 1)
))
json_config <- model_to_json(model)
new_model <- model_from_json(json_config)
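Since the JSON string holds only the architecture, a common pattern is to store it next to a weights file and recombine the two later. A minimal sketch of that pattern, assuming a model made only of built-in layers; the file paths are temporary and hypothetical:

```r
library(keras)

model <- keras_model_sequential(input_shape = 32) %>%
  layer_dense(units = 8, activation = "relu") %>%
  layer_dense(units = 1)

arch_file    <- tempfile(fileext = ".json")
weights_file <- tempfile(fileext = ".h5")

# Architecture as a JSON string, written to a plain text file.
writeLines(model_to_json(model), arch_file)
# Weights saved separately.
save_model_weights_hdf5(model, weights_file)

# Later: rebuild the architecture, then restore the weights into it.
json <- paste(readLines(arch_file), collapse = "\n")
restored <- model_from_json(json)
load_model_weights_hdf5(restored, weights_file)

x <- matrix(runif(4 * 32), nrow = 4)
all.equal(predict(model, x), predict(restored, x))
```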
Custom objects
Models and layers
The architecture of subclassed models and layers is defined in the methods initialize and call. They are considered R bytecode, which cannot be serialized into a JSON-compatible config; you could try serializing the bytecode (e.g. via saveRDS), but it’s completely unsafe and means your model cannot be loaded on a different system.
In order to save/load a model with custom-defined layers, or a subclassed model, you should override the get_config and optionally from_config methods. Additionally, you should register the custom object so that Keras is aware of it.
Custom functions
Custom-defined functions (e.g. activation, loss, or initialization functions) do not need a get_config method. The function name is sufficient for loading as long as it is registered as a custom object.
Loading the TensorFlow graph only
It’s possible to load the TensorFlow graph generated by Keras. If you do so, you won’t need to provide any custom_objects. You can do so like this:
save_model_tf(model, "my_model")
tensorflow_graph <- tf$saved_model$load("my_model")
x <- as_tensor(array(runif(4*32), dim = c(4, 32)), "float32")
predicted <- tensorflow_graph(x)$numpy()
Note that this method has several drawbacks:
- For traceability reasons, you should always have access to the custom objects that were used. You wouldn’t want to put in production a model that you cannot re-create.
- The object returned by tf$saved_model$load isn’t a Keras model, so it’s not as easy to use. For example, you won’t have access to predict() or fit().
Even if its use is discouraged, it can help you if you’re in a tight spot, for example, if you lost the code of your custom objects or have issues loading the model with load_model_tf().
You can find out more in the page about tf$saved_model$load.
Defining the config methods
Specifications:
- get_config should return a JSON-serializable dictionary (a named list in R) in order to be compatible with the Keras architecture- and model-saving APIs.
- from_config(config) (a classmethod) should return a new layer or model object that is created from the config. The default implementation returns do.call(cls, config).
Example:
custom_layer <- new_layer_class(
  "custom_layer",
  initialize = function(a) {
    self$var <- tf$Variable(a, name = "var_a")
  },
  call = function(inputs, training = FALSE) {
    if (training) {
      inputs * self$var
    } else {
      inputs
    }
  },
  get_config = function() {
    list("a" = as.array(self$var))
  }
)
layer <- custom_layer(a = 5)
layer$var$assign(2)
<tf.Variable 'UnreadVariable' shape=() dtype=float32, numpy=2.0>
serialized_layer <- keras$layers$serialize(layer)
new_layer <- keras$layers$deserialize(
  serialized_layer, custom_objects = list("custom_layer" = custom_layer)
)
Registering the custom object
Keras keeps a note of which class generated the config. From the example above, tf$keras$layers$serialize generates a serialized form of the custom layer:
list(class_name = "custom_layer", config = list(a = 2))
Keras keeps a master list of all built-in layer, model, optimizer, and metric classes, which is used to find the correct class on which to call from_config. If the class can’t be found, then an error is raised (ValueError: Unknown layer). There are a few ways to register custom classes to this list:
1. Setting the custom_objects argument in the loading function (see the example in the section above, “Defining the config methods”).
2. tf$keras$utils$custom_object_scope or tf$keras$utils$CustomObjectScope
3. tf$keras$utils$register_keras_serializable
Custom layer and function example
custom_layer <- new_layer_class(
  "custom_layer",
  initialize = function(units = 32, ...) {
    super()$`__init__`(...)
    self$units <- units
  },
  build = function(input_shape) {
    self$w <- self$add_weight(
      shape = shape(tail(input_shape, 1), self$units),
      initializer = "random_normal",
      trainable = TRUE
    )
    self$b <- self$add_weight(
      shape = shape(self$units),
      initializer = "random_normal",
      trainable = TRUE
    )
  },
  call = function(inputs) {
    tf$matmul(inputs, self$w) + self$b
  },
  get_config = function() {
    config <- super()$get_config()
    config$units <- self$units
    config
  }
)
custom_activation <- function(x) {
  tf$nn$tanh(x)^2
}
# Make a model with the custom_layer and custom_activation
inputs <- layer_input(shape = shape(32))
x <- custom_layer(inputs, 32)
outputs <- layer_activation(x, custom_activation)
model <- keras_model(inputs, outputs)
# Retrieve the config
config <- get_config(model)
# At loading time, register the custom objects with a `custom_object_scope`:
custom_objects <- list(
  "custom_layer" = custom_layer,
  "python_function" = custom_activation
)
with(tf$keras$utils$custom_object_scope(custom_objects), {
  new_model <- keras$Model$from_config(config)
})
In-memory model cloning
You can also do in-memory cloning of a model via tf$keras$models$clone_model() (clone_model() in R). This is equivalent to getting the config then recreating the model from its config (so it does not preserve compilation information or layer weights values).
Example:
with(tf$keras$utils$custom_object_scope(custom_objects), {
  new_model <- clone_model(model)
})
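The claim that cloning does not carry over weights is easy to check. A small sketch, assuming a model built only from built-in layers so that no custom_object_scope is needed: clone it, compare the weights, then copy them across explicitly with set_weights().

```r
library(keras)

original <- keras_model_sequential(input_shape = 6) %>%
  layer_dense(units = 3, activation = "relu") %>%
  layer_dense(units = 1)

# Same architecture, freshly initialized weights, no compile() information.
cloned <- clone_model(original)
weights_match_before <- isTRUE(all.equal(get_weights(original), get_weights(cloned)))

# Copy the learned state across explicitly if you need an exact duplicate.
set_weights(cloned, get_weights(original))
weights_match_after <- isTRUE(all.equal(get_weights(original), get_weights(cloned)))

c(before = weights_match_before, after = weights_match_after)
```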
Saving & loading only the model’s weights values
You can choose to only save & load a model’s weights. This can be useful if:
- You only need the model for inference: in this case you won’t need to restart training, so you don’t need the compilation information or optimizer state.
- You are doing transfer learning: in this case you will be training a new model reusing the state of a prior model, so you don’t need the compilation information of the prior model.
APIs for in-memory weight transfer
Weights can be copied between different objects by using get_weights and set_weights:
- get_weights(): Returns a list of arrays.
- set_weights(): Sets the model weights to the values in the weights argument.
Examples below.
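To make the return value of get_weights() concrete: for a single dense layer it is a list holding the kernel (inputs x units) followed by the bias (units), as plain R arrays that can be edited and written back. A small sketch with arbitrary sizes:

```r
library(keras)

model <- keras_model_sequential(input_shape = 4) %>%
  layer_dense(units = 3)

w <- get_weights(model)
# A kernel of shape (4, 3) followed by a bias of length 3.
str(w)

# Edit the arrays in R, then write them back with set_weights().
w[[1]][] <- 0  # zero out the kernel
w[[2]][] <- 1  # set every bias to 1
set_weights(model, w)

# With a zero kernel, every output equals the bias.
preds <- predict(model, matrix(runif(2 * 4), nrow = 2))
preds
```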
Transferring weights from one layer to another, in memory
create_layer <- function() {
  layer <- layer_dense(units = 64, activation = "relu", name = "dense_2")
  layer$build(shape(NULL, 784))
  layer
}
layer_1 <- create_layer()
layer_2 <- create_layer()
# Copy weights from layer 1 to layer 2
set_weights(layer_2, get_weights(layer_1))
Transferring weights from one model to another model with a compatible architecture, in memory
Create a simple functional model
inputs <- layer_input(shape = 784, name = "digits")
outputs <- inputs %>%
  layer_dense(64, activation = "relu", name = "dense_1") %>%
  layer_dense(64, activation = "relu", name = "dense_2") %>%
  layer_dense(10, name = "predictions")
functional_model <- keras_model(
  inputs = inputs,
  outputs = outputs,
  name = "3_layer_mlp"
)
Define a subclassed model with the same architecture
subclassed_model <- new_model_class(
  "subclassed_model",
  initialize = function(output_dim, name = NULL) {
    super()$`__init__`(name = name)
    self$output_dim <- output_dim
    self$dense_1 <- layer_dense(units = 64, activation = "relu", name = "dense_1")
    self$dense_2 <- layer_dense(units = 64, activation = "relu", name = "dense_2")
    self$dense_3 <- layer_dense(units = output_dim, name = "predictions")
  },
  call = function(inputs) {
    inputs %>%
      self$dense_1() %>%
      self$dense_2() %>%
      self$dense_3()
  },
  get_config = function() {
    list(
      output_dim = self$output_dim,
      name = self$name
    )
  }
)
model <- subclassed_model(output_dim = 10)
# Call the subclassed model once to create the weights.
model(tf$ones(shape(1, 784)))
tf.Tensor(
[[ 0.02261962 0.02064185 -0.01209853 -1.1393347 -0.18542328 -0.6146086
0.43669882 0.21682692 -1.1092497 1.8132143 ]], shape=(1, 10), dtype=float32)
# Copy weights from functional_model to subclassed_model.
set_weights(model, get_weights(functional_model))
length(functional_model$weights) == length(model$weights)
[1] TRUE
all.equal(get_weights(functional_model), get_weights(model))
[1] TRUE
The case of stateless layers
Because stateless layers do not change the order or number of weights, models can have compatible architectures even if there are extra/missing stateless layers.
inputs <- layer_input(shape = shape(784), name = "digits")
outputs <- inputs %>%
  layer_dense(64, activation = "relu", name = "dense_1") %>%
  layer_dense(64, activation = "relu", name = "dense_2") %>%
  layer_dense(10, name = "predictions")
functional_model <- keras_model(
  inputs = inputs,
  outputs = outputs,
  name = "3_layer_mlp"
)
inputs <- layer_input(shape = shape(784), name = "digits")
outputs <- inputs %>%
  layer_dense(64, activation = "relu", name = "dense_1") %>%
  layer_dense(64, activation = "relu", name = "dense_2") %>%
  # Add a dropout layer, which does not contain any weights.
  layer_dropout(0.5) %>%
  layer_dense(10, name = "predictions")
functional_model_with_dropout <- keras_model(
  inputs = inputs,
  outputs = outputs,
  name = "3_layer_mlp"
)
set_weights(functional_model_with_dropout, get_weights(functional_model))
APIs for saving weights to disk & loading them back
Weights can be saved to disk by calling model$save_weights() (or the R wrappers below) in the following formats:
- TensorFlow Checkpoint: save_model_weights_tf()
- HDF5: save_model_weights_hdf5()
Each format has its pros and cons which are detailed below.
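The two formats also produce different files. A sketch with hypothetical names in a temporary directory: the TensorFlow Checkpoint path acts as a prefix that expands into an index file plus data shard(s) (and a small checkpoint bookkeeping file), while HDF5 writes a single file.

```r
library(keras)

model <- keras_model_sequential(input_shape = 20) %>%
  layer_dense(units = 5)

out_dir <- tempfile("weights_")
dir.create(out_dir)

# TensorFlow Checkpoint: "ckpt" is a prefix, not a single file.
save_model_weights_tf(model, file.path(out_dir, "ckpt"))
# HDF5: one self-contained file.
save_model_weights_hdf5(model, file.path(out_dir, "weights.h5"))

# Expect ckpt.index, ckpt.data-00000-of-00001, checkpoint and weights.h5.
files <- list.files(out_dir)
print(files)
```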
TF Checkpoint format
Example:
# Runnable example
sequential_model <- keras_model_sequential(input_shape = shape(784)) %>%
  layer_dense(64, activation = "relu", name = "dense_1") %>%
  layer_dense(64, activation = "relu", name = "dense_2") %>%
  layer_dense(10, name = "predictions")
save_model_weights_tf(sequential_model, "ckpt")
load_model_weights_tf(sequential_model, "ckpt")
Format details
The TensorFlow Checkpoint format saves and restores the weights using object attribute names. For instance, consider the layer_dense layer. The layer contains two weights: dense$kernel and dense$bias. When the layer is saved to the tf format, the resulting checkpoint contains the keys "kernel" and "bias" and their corresponding weight values. For more information see “Loading mechanics” in the TF Checkpoint guide.
Note that the attribute/graph edge is named after the name used in the parent object, not the name of the variable. Consider the custom_layer in the example below. The variable custom_layer$var is saved with "var" as part of the key, not "var_a".
custom_layer <- new_layer_class(
  "custom_layer",
  initialize = function(a) {
    self$var <- tf$Variable(a, name = "var_a")
  }
)
layer <- custom_layer(a = 5)
layer_ckpt <- tf$train$Checkpoint(layer = layer)$save("custom_layer")
ckpt_reader <- tf$train$load_checkpoint(layer_ckpt)
ckpt_reader$get_variable_to_dtype_map()
$`save_counter/.ATTRIBUTES/VARIABLE_VALUE`
tf.int64
$`layer/var/.ATTRIBUTES/VARIABLE_VALUE`
tf.float32
$`_CHECKPOINTABLE_OBJECT_GRAPH`
tf.string
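The same kind of inspection works for weights saved from a Keras model. A sketch, assuming a small Sequential model saved with save_model_weights_tf(): tf$train$list_variables() returns (name, shape) pairs, and the dense layer's weights show up under keys containing its attribute names, kernel and bias. The surrounding key prefixes come from TensorFlow's object graph and may differ between versions, so the check only looks for the attribute names.

```r
library(keras)
library(tensorflow)

model <- keras_model_sequential(input_shape = 3) %>%
  layer_dense(units = 2)

ckpt_dir <- tempfile("keys_")
dir.create(ckpt_dir)
prefix <- file.path(ckpt_dir, "ckpt")
save_model_weights_tf(model, prefix)

# Each entry is a (name, shape) pair; keep only the names.
vars <- tf$train$list_variables(prefix)
keys <- vapply(vars, function(v) v[[1]], character(1))
print(keys)

# The dense layer's variables are keyed by attribute name.
c(kernel = any(grepl("/kernel/", keys)), bias = any(grepl("/bias/", keys)))
```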
Transfer learning example
Essentially, as long as two models have the same architecture, they are able to share the same checkpoint.
Example:
inputs <- layer_input(shape = shape(784), name = "digits")
outputs <- inputs %>%
  layer_dense(64, activation = "relu", name = "dense_1") %>%
  layer_dense(64, activation = "relu", name = "dense_2") %>%
  layer_dense(10, name = "predictions")
functional_model <- keras_model(
  inputs = inputs,
  outputs = outputs,
  name = "3_layer_mlp"
)
# Extract a portion of the functional model defined above.
# The following lines produce a new model that excludes the final output
# layer of the functional model.
pretrained <- keras_model(
  inputs = functional_model$inputs,
  outputs = functional_model$layers[[4]]$input,
  name = "pretrained_model"
)
# Randomly assign "trained" weights.
for (w in pretrained$weights) {
  w$assign(tf$random$normal(w$shape))
}
save_model_weights_tf(pretrained, "pretrained_ckpt")
summary(pretrained)
Model: "pretrained_model"
____________________________________________________________________________
Layer (type) Output Shape Param #
============================================================================
digits (InputLayer) [(None, 784)] 0
dense_1 (Dense) (None, 64) 50240
dense_2 (Dense) (None, 64) 4160
============================================================================
Total params: 54,400
Trainable params: 54,400
Non-trainable params: 0
____________________________________________________________________________
# Assume this is a separate program where only 'pretrained_ckpt' exists.
# Create a new functional model with a different output dimension.
inputs <- layer_input(shape = shape(784), name = "digits")
outputs <- inputs %>%
  layer_dense(64, activation = "relu", name = "dense_1") %>%
  layer_dense(64, activation = "relu", name = "dense_2") %>%
  layer_dense(5, name = "predictions")
model <- keras_model(inputs = inputs, outputs = outputs, name = "new_model")
# Load the weights from pretrained_ckpt into model.
load_model_weights_tf(model, "pretrained_ckpt")
# Check that all of the pretrained weights have been loaded.
all.equal(get_weights(pretrained), head(get_weights(model), 4))
[1] TRUE
summary(model)
Model: "new_model"
____________________________________________________________________________
Layer (type) Output Shape Param #
============================================================================
digits (InputLayer) [(None, 784)] 0
dense_1 (Dense) (None, 64) 50240
dense_2 (Dense) (None, 64) 4160
predictions (Dense) (None, 5) 325
============================================================================
Total params: 54,725
Trainable params: 54,725
Non-trainable params: 0
____________________________________________________________________________
# Example 2: Sequential model
# Recreate the pretrained model, and load the saved weights.
inputs <- layer_input(shape = shape(784), name = "digits")
outputs <- inputs %>%
  layer_dense(64, activation = "relu", name = "dense_1") %>%
  layer_dense(64, activation = "relu", name = "dense_2")
pretrained_model <- keras_model(
  inputs = inputs,
  outputs = outputs,
  name = "pretrained"
)
# Sequential example:
model <- keras_model_sequential() %>%
  pretrained_model() %>%
  layer_dense(5, name = "predictions")
summary(model)
Model: "sequential_3"
____________________________________________________________________________
Layer (type) Output Shape Param #
============================================================================
pretrained (Functional) (None, 64) 54400
predictions (Dense) (None, 5) 325
============================================================================
Total params: 54,725
Trainable params: 54,725
Non-trainable params: 0
____________________________________________________________________________
load_model_weights_tf(pretrained_model, "pretrained_ckpt")
Warning! Calling model$load_weights('pretrained_ckpt') won’t throw an error, but will not work as expected. If you inspect the weights, you’ll see that none of the weights will have loaded. pretrained_model$load_weights() is the correct method to call.
It is generally recommended to stick to the same API for building models. If you switch between Sequential and Functional, or Functional and subclassed, etc., then always rebuild the pre-trained model and load the pre-trained weights to that model.
The next question is, how can weights be saved and loaded to different models if the model architectures are quite different? The solution is to use tf$train$Checkpoint to save and restore the exact layers/variables.
Example:
# Create a subclassed model that essentially uses functional_model's first
# and last layers.
# First, save the weights of functional_model's first and last dense layers.
first_dense <- functional_model$layers[[2]]
last_dense <- functional_model$layers[[4]]
ckpt_path <- tf$train$Checkpoint(
  dense = first_dense,
  kernel = last_dense$kernel,
  bias = last_dense$bias
)$save("ckpt")
# Define the subclassed model.
contrived_model <- new_model_class(
  "contrived_model",
  initialize = function() {
    super()$`__init__`()
    self$first_dense <- layer_dense(units = 64)
    self$kernel <- self$add_weight("kernel", shape = shape(64, 10))
    self$bias <- self$add_weight("bias", shape = shape(10))
  },
  call = function(inputs) {
    x <- self$first_dense(inputs)
    tf$matmul(x, self$kernel) + self$bias
  }
)
model <- contrived_model()
# Call model on inputs to create the variables of the dense layer.
invisible(model(tf$ones(shape(1, 784))))
# Create a Checkpoint with the same structure as before, and load the weights.
tf$train$Checkpoint(
  dense = model$first_dense,
  kernel = model$kernel,
  bias = model$bias
)$restore(ckpt_path)$assert_consumed()
<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus object at 0x7f83bee4fe80>
HDF5 format
The HDF5 format contains weights grouped by layer names. The weights are lists ordered by concatenating the list of trainable weights to the list of non-trainable weights (same as layer$weights).
Thus, a model can use an HDF5 checkpoint if it has the same layers and trainable statuses as saved in the checkpoint.
Example:
# Runnable example
sequential_model <- keras_model_sequential(input_shape = 784) %>%
  layer_dense(64, activation = "relu", name = "dense_1") %>%
  layer_dense(64, activation = "relu", name = "dense_2") %>%
  layer_dense(10, name = "predictions")
save_model_weights_hdf5(sequential_model, "weights.h5")
load_model_weights_hdf5(sequential_model, "weights.h5")
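Because HDF5 weights are grouped by layer name, load_model_weights_hdf5() can also match layers by name instead of by position, through its by_name argument. A sketch with hypothetical models: only layers whose names appear in the file (dense_1 and dense_2) receive the saved weights, the differently named head keeps its fresh initialization, and saved layers with no counterpart in the target are skipped.

```r
library(keras)

# Shared body: two dense layers with fixed names.
make_body <- function(inputs) {
  inputs %>%
    layer_dense(64, activation = "relu", name = "dense_1") %>%
    layer_dense(64, activation = "relu", name = "dense_2")
}

# Source model with a 10-unit head, saved to HDF5.
src_inputs <- layer_input(shape = 784)
source_model <- keras_model(
  src_inputs,
  make_body(src_inputs) %>% layer_dense(10, name = "predictions")
)
h5_file <- tempfile(fileext = ".h5")
save_model_weights_hdf5(source_model, h5_file)

# Target model: same body layer names, differently named 5-unit head.
tgt_inputs <- layer_input(shape = 784)
target_model <- keras_model(
  tgt_inputs,
  make_body(tgt_inputs) %>% layer_dense(5, name = "new_head")
)

# Match saved weights to layers by name instead of by position.
load_model_weights_hdf5(target_model, h5_file, by_name = TRUE)

copied <- vapply(c("dense_1", "dense_2"), function(nm) {
  isTRUE(all.equal(
    get_weights(get_layer(source_model, nm)),
    get_weights(get_layer(target_model, nm))
  ))
}, logical(1))
copied
```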
Note that changing layer$trainable may result in a different layer$weights ordering when the model contains nested layers.
nested_dense_layer <- new_layer_class(
  "nested_dense_layer",
  initialize = function(units, name = NULL) {
    super()$`__init__`(name = name)
    self$dense_1 <- layer_dense(units = units, name = "dense_1")
    self$dense_2 <- layer_dense(units = units, name = "dense_2")
  },
  call = function(inputs) {
    inputs %>%
      self$dense_1() %>%
      self$dense_2()
  }
)
nested_model <- keras_model_sequential(input_shape = 784) %>%
  nested_dense_layer(units = 10, name = "nested")
variable_names <- lapply(nested_model$weights, function(x) x$name)
str(variable_names)
List of 4
$ : chr "nested/dense_1/kernel:0"
$ : chr "nested/dense_1/bias:0"
$ : chr "nested/dense_2/kernel:0"
$ : chr "nested/dense_2/bias:0"
print("\nChanging trainable status of one of the nested layers...")
[1] "\nChanging trainable status of one of the nested layers..."
layer <- nested_model %>%
  get_layer("nested")
layer$dense_1$trainable <- FALSE
variable_names_2 <- lapply(nested_model$weights, function(x) x$name)
str(variable_names_2)
List of 4
$ : chr "nested/dense_2/kernel:0"
$ : chr "nested/dense_2/bias:0"
$ : chr "nested/dense_1/kernel:0"
$ : chr "nested/dense_1/bias:0"
Transfer learning example
When loading pretrained weights from HDF5, it is recommended to load the weights into the original checkpointed model, and then extract the desired weights/layers into a new model.
Example:
create_functional_model <- function() {
  inputs <- layer_input(shape = shape(784), name = "digits")
  outputs <- inputs %>%
    layer_dense(64, activation = "relu", name = "dense_1") %>%
    layer_dense(64, activation = "relu", name = "dense_2") %>%
    layer_dense(10, name = "predictions")
  keras_model(inputs = inputs, outputs = outputs, name = "3_layer_mlp")
}
functional_model <- create_functional_model()
save_model_weights_hdf5(functional_model, "pretrained_weights.h5")
# In a separate program:
pretrained_model <- create_functional_model()
load_model_weights_hdf5(pretrained_model, "pretrained_weights.h5")
# Create a new model by extracting layers from the original model:
extracted_layers <- pretrained_model$layers[1:3]
extracted_layers <- c(extracted_layers, layer_dense(units = 5, name = "dense_3"))
model <- keras_model_sequential(extracted_layers)
model
Model: "sequential_6"
____________________________________________________________________________
Layer (type) Output Shape Param #
============================================================================
dense_1 (Dense) (None, 64) 50240
dense_2 (Dense) (None, 64) 4160
dense_3 (Dense) (None, 5) 325
============================================================================
Total params: 54,725
Trainable params: 54,725
Non-trainable params: 0
____________________________________________________________________________
Environment Details
tensorflow::tf_config()
TensorFlow v2.11.0 (~/.virtualenvs/r-tensorflow-website/lib/python3.10/site-packages/tensorflow)
Python v3.10 (~/.virtualenvs/r-tensorflow-website/bin/python)
sessionInfo()
R version 4.2.1 (2022-06-23)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 20.04.5 LTS
Matrix products: default
BLAS: /home/tomasz/opt/R-4.2.1/lib/R/lib/libRblas.so
LAPACK: /usr/lib/x86_64-linux-gnu/libmkl_intel_lp64.so
locale:
[1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
[3] LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8
[5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8
[7] LC_PAPER=en_US.UTF-8 LC_NAME=C
[9] LC_ADDRESS=C LC_TELEPHONE=C
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] keras_2.9.0.9000 tensorflow_2.9.0.9000
loaded via a namespace (and not attached):
[1] Rcpp_1.0.9 pillar_1.8.1 compiler_4.2.1
[4] base64enc_0.1-3 tools_4.2.1 zeallot_0.1.0
[7] digest_0.6.31 jsonlite_1.8.4 evaluate_0.18
[10] lifecycle_1.0.3 tibble_3.1.8 lattice_0.20-45
[13] pkgconfig_2.0.3 png_0.1-8 rlang_1.0.6
[16] Matrix_1.5-3 cli_3.4.1 yaml_2.3.6
[19] xfun_0.35 fastmap_1.1.0 stringr_1.5.0
[22] knitr_1.41 generics_0.1.3 vctrs_0.5.1
[25] htmlwidgets_1.5.4 rprojroot_2.0.3 grid_4.2.1
[28] reticulate_1.26-9000 glue_1.6.2 here_1.0.1
[31] R6_2.5.1 fansi_1.0.3 rmarkdown_2.18
[34] magrittr_2.0.3 whisker_0.4.1 htmltools_0.5.4
[37] tfruns_1.5.1 utf8_1.2.2 stringi_1.7.8
system2(reticulate::py_exe(), c("-m pip freeze"), stdout = TRUE) |> writeLines()
absl-py==1.3.0
asttokens==2.2.1
astunparse==1.6.3
backcall==0.2.0
cachetools==5.2.0
certifi==2022.12.7
charset-normalizer==2.1.1
decorator==5.1.1
dill==0.3.6
etils==0.9.0
executing==1.2.0
flatbuffers==22.12.6
gast==0.4.0
google-auth==2.15.0
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
googleapis-common-protos==1.57.0
grpcio==1.51.1
h5py==3.7.0
idna==3.4
importlib-resources==5.10.1
ipython==8.7.0
jedi==0.18.2
kaggle==1.5.12
keras==2.11.0
keras-tuner==1.1.3
kt-legacy==1.0.4
libclang==14.0.6
Markdown==3.4.1
MarkupSafe==2.1.1
matplotlib-inline==0.1.6
numpy==1.23.5
oauthlib==3.2.2
opt-einsum==3.3.0
packaging==22.0
pandas==1.5.2
parso==0.8.3
pexpect==4.8.0
pickleshare==0.7.5
Pillow==9.3.0
promise==2.3
prompt-toolkit==3.0.36
protobuf==3.19.6
ptyprocess==0.7.0
pure-eval==0.2.2
pyasn1==0.4.8
pyasn1-modules==0.2.8
pydot==1.4.2
Pygments==2.13.0
pyparsing==3.0.9
python-dateutil==2.8.2
python-slugify==7.0.0
pytz==2022.6
PyYAML==6.0
requests==2.28.1
requests-oauthlib==1.3.1
rsa==4.9
scipy==1.9.3
six==1.16.0
stack-data==0.6.2
tensorboard==2.11.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.11.0
tensorflow-datasets==4.7.0
tensorflow-estimator==2.11.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.28.0
tensorflow-metadata==1.12.0
termcolor==2.1.1
text-unidecode==1.3
toml==0.10.2
tqdm==4.64.1
traitlets==5.7.1
typing_extensions==4.4.0
urllib3==1.26.13
wcwidth==0.2.5
Werkzeug==2.2.2
wrapt==1.14.1
zipp==3.11.0
TF Devices:
- PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')
- PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')
CPU cores: 12
Date rendered: 2022-12-16
Page render time: 10 seconds