R/freeze.R

freeze_weights

Freeze and unfreeze weights

Description

Freeze weights in a model or layer so that they are no longer trainable, or unfreeze them so that they become trainable again.

Usage

 
freeze_weights(object, from = NULL, to = NULL, which = NULL) 
 
unfreeze_weights(object, from = NULL, to = NULL, which = NULL) 

Arguments

Argument  Description
object    Keras model or layer object
from      Layer instance, layer name, or layer index within model
to        Layer instance, layer name, or layer index within model
which     Layer names, integer positions, layers, a logical vector (of length(object$layers)), or a function returning a logical vector.
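The forms accepted by which can be sketched on a small model. This is a minimal illustration, assuming a working keras/TensorFlow installation; the three-layer model, the layer names dense_a, dense_b, and dense_c, and the helper trainable_flags() are hypothetical and not part of the package.

```r
library(keras)

# Hypothetical model; layer names are chosen only for illustration.
model <- keras_model_sequential() %>%
  layer_dense(units = 8, activation = "relu", input_shape = c(4), name = "dense_a") %>%
  layer_dense(units = 8, activation = "relu", name = "dense_b") %>%
  layer_dense(units = 1, activation = "sigmoid", name = "dense_c")

# Hypothetical helper: report the trainable flag of each layer, in order.
trainable_flags <- function(model) {
  vapply(model$layers, function(layer) layer$trainable, logical(1))
}

# Select by layer name.
freeze_weights(model, which = "dense_b")
print(trainable_flags(model))

# Select with a logical vector, one entry per layer in model$layers.
freeze_weights(model, which = c(TRUE, FALSE, TRUE))
print(trainable_flags(model))

# Select with a predicate that is called once per layer.
freeze_weights(model, which = function(layer) layer$name != "dense_b")
print(trainable_flags(model))
```

In each call the selected layers should report as frozen and the remaining layers as trainable, in line with the loop shown at the end of the Examples section.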

Note

The from and to layer arguments are both inclusive. When applied to a model, freezing or unfreezing is a global operation over all layers in the model: layers outside the specified range are set to the opposite state (for example, a call to freeze_weights() with a range leaves every layer outside that range unfrozen). Models must be compiled again after weights are frozen or unfrozen.
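The range behavior described in this note can be checked on a small model. The following is a minimal sketch, assuming a working keras/TensorFlow installation; the three-layer model and the layer names dense_a, dense_b, and dense_c are hypothetical and chosen only for illustration.

```r
library(keras)

# Hypothetical three-layer model; the layer names are illustrative only.
model <- keras_model_sequential() %>%
  layer_dense(units = 8, activation = "relu", input_shape = c(4), name = "dense_a") %>%
  layer_dense(units = 8, activation = "relu", name = "dense_b") %>%
  layer_dense(units = 1, activation = "sigmoid", name = "dense_c")

# Freeze dense_a through dense_b (both ends of the range are inclusive).
freeze_weights(model, from = "dense_a", to = "dense_b")

# Collect the trainable flag of every layer, keyed by layer name.
trainable <- vapply(model$layers, function(layer) layer$trainable, logical(1))
names(trainable) <- vapply(model$layers, function(layer) layer$name, character(1))
print(trainable)

# Recompile so that the changed flags are picked up for training.
model %>% compile(loss = "binary_crossentropy", optimizer = "rmsprop")
```

Because dense_c lies outside the frozen range, it should report as trainable, while dense_a and dense_b report as frozen.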

Examples

library(keras)
 
conv_base <- application_vgg16( 
  weights = "imagenet", 
  include_top = FALSE, 
  input_shape = c(150, 150, 3) 
) 
 
# freeze its weights
freeze_weights(conv_base) 
 
conv_base 
 
# create a composite model that includes the base + more layers 
model <- keras_model_sequential() %>% 
  conv_base() %>% 
  layer_flatten() %>% 
  layer_dense(units = 256, activation = "relu") %>% 
  layer_dense(units = 1, activation = "sigmoid") 
 
# compile 
model %>% compile( 
  loss = "binary_crossentropy", 
  optimizer = optimizer_rmsprop(learning_rate = 2e-5),
  metrics = c("accuracy") 
) 
 
model 
print(model, expand_nested = TRUE) 
 
 
 
# unfreeze weights from "block5_conv1" onward (earlier layers stay frozen)
unfreeze_weights(conv_base, from = "block5_conv1") 
 
# compile again since we froze or unfroze weights 
model %>% compile( 
  loss = "binary_crossentropy", 
  optimizer = optimizer_rmsprop(learning_rate = 2e-5),
  metrics = c("accuracy") 
) 
 
conv_base 
print(model, expand_nested = TRUE) 
 
# freeze only the last 5 layers 
freeze_weights(conv_base, from = -5) 
conv_base 
# equivalently: unfreezing everything up to the 6th-from-last layer freezes the last 5
unfreeze_weights(conv_base, to = -6) 
conv_base 
 
# freeze only layers of a certain type, e.g., BatchNorm layers
batch_norm_layer_class_name <- class(layer_batch_normalization())[1] 
is_batch_norm_layer <- function(x) inherits(x, batch_norm_layer_class_name) 
 
model <- application_efficientnet_b0() 
freeze_weights(model, which = is_batch_norm_layer) 
model 
# equivalent to: 
for (layer in model$layers) {
  if (is_batch_norm_layer(layer)) {
    layer$trainable <- FALSE
  } else {
    layer$trainable <- TRUE
  }
}