Making custom layer and model objects

Complete guide to writing custom layer and model classes.

Setup

library(magrittr)
library(tensorflow)
library(tfdatasets)
library(keras)

tf_version()
[1] '2.13'

The Layer class: a combination of state (weights) and some computation

One of the central abstractions in Keras is the Layer class. A layer encapsulates both a state (the layer’s “weights”) and a transformation from inputs to outputs (a “call”, the layer’s forward pass).

Here’s a densely-connected layer. It has a state: the variables w and b.

Linear(keras$layers$Layer) %py_class% {
  initialize <- function(units = 32, input_dim = 32) {
    super$initialize()
    w_init <- tf$random_normal_initializer()
    self$w <- tf$Variable(
      initial_value = w_init(
        shape = shape(input_dim, units),
        dtype = "float32"
      ),
      trainable = TRUE
    )
    b_init <- tf$zeros_initializer()
    self$b <- tf$Variable(
      initial_value = b_init(shape = shape(units), dtype = "float32"),
      trainable = TRUE
    )
  }

  call <- function(inputs) {
    tf$matmul(inputs, self$w) + self$b
  }
}

You would use a layer by calling it on some tensor input(s), much like a regular function.

x <- tf$ones(shape(2, 2))
linear_layer <- Linear(4, 2)
y <- linear_layer(x)
print(y)
tf.Tensor(
[[-0.1341999  -0.03197089  0.10784246  0.06014311]
 [-0.1341999  -0.03197089  0.10784246  0.06014311]], shape=(2, 4), dtype=float32)

Linear behaves similarly to a layer present in the Python interface to keras (e.g., keras$layers$Dense).

However, one additional step is needed to make it behave like the built-in layers in the keras R package (e.g., layer_dense()).

Keras layers in R are designed to compose nicely with the pipe operator (%>%), so that the layer instance is conveniently created on demand when an existing model or tensor is piped in. In order to make a custom layer similarly compose nicely with the pipe, you can call create_layer_wrapper() on the layer class constructor.

layer_linear <- create_layer_wrapper(Linear)

Now layer_linear is a layer constructor that composes nicely with %>%, just like the built-in layers:

model <- keras_model_sequential() %>%
  layer_linear(4, 2)

model(k_ones(c(2, 2)))
tf.Tensor(
[[ 0.06188128 -0.00513574 -0.02163585  0.03163266]
 [ 0.06188128 -0.00513574 -0.02163585  0.03163266]], shape=(2, 4), dtype=float32)
model
Model: "sequential"
____________________________________________________________________________
 Layer (type)                     Output Shape                  Param #     
============================================================================
 linear_1 (Linear)                (2, 4)                        12          
============================================================================
Total params: 12 (48.00 Byte)
Trainable params: 12 (48.00 Byte)
Non-trainable params: 0 (0.00 Byte)
____________________________________________________________________________

Because the pattern above is so common, there is a convenience function that combines the steps of subclassing keras$layers$Layer and calling create_layer_wrapper on the output: the new_layer_class() function. The layer_linear() defined below is identical to the layer_linear() defined above.

layer_linear <- new_layer_class(
  "Linear",
  initialize =  function(units = 32, input_dim = 32) {
    super$initialize()
    w_init <- tf$random_normal_initializer()
    self$w <- tf$Variable(initial_value = w_init(shape = shape(input_dim, units),
                                                 dtype = "float32"),
                          trainable = TRUE)
    b_init <- tf$zeros_initializer()
    self$b <- tf$Variable(initial_value = b_init(shape = shape(units),
                                                 dtype = "float32"),
                          trainable = TRUE)
  },

  call = function(inputs) {
    tf$matmul(inputs, self$w) + self$b
  }
)

For the remainder of this vignette we’ll be using the %py_class% constructor. However, in your own code feel free to use create_layer_wrapper() or new_layer_class() if you prefer.

Note that the weights w and b are automatically tracked by the layer upon being set as layer attributes:

stopifnot(all.equal(
  linear_layer$weights,
  list(linear_layer$w, linear_layer$b)
))

You also have access to a shortcut for adding weights to a layer: the add_weight() method:

Linear(keras$layers$Layer) %py_class% {
  initialize <- function(units = 32, input_dim = 32) {
    super$initialize()
    w_init <- tf$random_normal_initializer()
    self$w <- self$add_weight(
      shape = shape(input_dim, units),
      initializer = "random_normal",
      trainable = TRUE
    )
    self$b <- self$add_weight(
      shape = shape(units),
      initializer = "zeros",
      trainable = TRUE
    )
  }

  call <- function(inputs) {
    tf$matmul(inputs, self$w) + self$b
  }
}

x <- tf$ones(shape(2, 2))
linear_layer <- Linear(4, 2)
y <- linear_layer(x)
print(y)
tf.Tensor(
[[-0.09172767  0.050132   -0.10547464  0.05247786]
 [-0.09172767  0.050132   -0.10547464  0.05247786]], shape=(2, 4), dtype=float32)

Layers can have non-trainable weights

Besides trainable weights, you can also add non-trainable weights to a layer. Such weights are not meant to be taken into account during backpropagation when you are training the layer.

Here’s how to add and use a non-trainable weight:

ComputeSum(keras$layers$Layer) %py_class% {
  initialize <- function(input_dim) {
    super$initialize()
    self$total <- tf$Variable(
      initial_value = tf$zeros(shape(input_dim)),
      trainable = FALSE
    )
  }

  call <- function(inputs) {
    self$total$assign_add(tf$reduce_sum(inputs, axis = 0L))
    self$total
  }
}

x <- tf$ones(shape(2, 2))
my_sum <- ComputeSum(2)
y <- my_sum(x)
print(as.numeric(y))
[1] 2 2
y <- my_sum(x)
print(as.numeric(y))
[1] 4 4

It’s part of layer$weights, but it gets categorized as a non-trainable weight:

cat("weights:", length(my_sum$weights), "\n")
weights: 1 
cat("non-trainable weights:", length(my_sum$non_trainable_weights), "\n")
non-trainable weights: 1 
# It's not included in the trainable weights:
cat("trainable_weights:", my_sum$trainable_weights, "\n")
trainable_weights:  

Best practice: deferring weight creation until the shape of the inputs is known

Our Linear layer above took an input_dim argument that was used to compute the shape of the weights w and b in initialize():

Linear(keras$layers$Layer) %py_class% {
  initialize <- function(units = 32, input_dim = 32) {
    super$initialize()
    self$w <- self$add_weight(
      shape = shape(input_dim, units),
      initializer = "random_normal",
      trainable = TRUE
    )
    self$b <- self$add_weight(
      shape = shape(units),
      initializer = "zeros",
      trainable = TRUE
    )
  }

  call <- function(inputs) {
    tf$matmul(inputs, self$w) + self$b
  }
}

In many cases, you may not know in advance the size of your inputs, and you would like to lazily create weights when that value becomes known, some time after instantiating the layer.

In the Keras API, we recommend creating layer weights in the build(input_shape) method of your layer, like this:

Linear(keras$layers$Layer) %py_class% {
  initialize <- function(units = 32) {
    super$initialize()
    self$units <- units
  }

  build <- function(input_shape) {
    self$w <- self$add_weight(
      shape = shape(tail(input_shape, 1), self$units),
      initializer = "random_normal",
      trainable = TRUE
    )
    self$b <- self$add_weight(
      shape = shape(self$units),
      initializer = "random_normal",
      trainable = TRUE
    )
  }

  call <- function(inputs) {
    tf$matmul(inputs, self$w) + self$b
  }
}

The build() method of your layer will automatically run the first time your layer instance is called. You now have a layer that can handle an arbitrary number of input features:

# At instantiation, we don't know on what inputs this is going to get called
linear_layer <- Linear(32)

# The layer's weights are created dynamically the first time the layer is called
y <- linear_layer(x)
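
For example, a quick check (the 2 × 2 input x from above has two features, so the kernel gets shape (2, 32)):

linear_layer$w$shape # (2, 32): input features x units
linear_layer$b$shape # (32)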

Implementing build() separately as shown above nicely separates creating weights only once from using weights in every call. However, for some advanced custom layers, it can become impractical to separate the state creation and computation. Layer implementers are allowed to defer weight creation to the first call(), but need to take care that later calls use the same weights. In addition, since call() is likely to be executed for the first time inside a tf_function(), any variable creation that takes place in call() should be wrapped in a tf$init_scope().
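
For illustration, here is a minimal sketch of that pattern (a hypothetical LazyLinear, assuming the last input dimension is known at the first call): the kernel is created only once, inside call(), and its creation is wrapped in tf$init_scope().

LazyLinear(keras$layers$Layer) %py_class% {
  initialize <- function(units = 32) {
    super$initialize()
    self$units <- units
  }

  call <- function(inputs) {
    if (!reticulate::py_has_attr(self, "w")) {
      # Create the weight only on the first call; `tf$init_scope()` lifts
      # variable creation out of any enclosing `tf_function()` trace.
      with(tf$init_scope(), {
        self$w <- self$add_weight(
          shape = shape(tail(dim(inputs), 1), self$units),
          initializer = "random_normal",
          trainable = TRUE
        )
      })
    }
    tf$matmul(inputs, self$w)
  }
}

lazy <- LazyLinear(16)
y <- lazy(tf$ones(shape(2, 8))) # the (8, 16) kernel is created here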

Layers are recursively composable

If you assign a Layer instance as an attribute of another Layer, the outer layer will start tracking the weights created by the inner layer.

We recommend creating such sublayers in the initialize() method and leaving it to the first call() to trigger building their weights.

# Let's assume we are reusing the Linear class
# with a `build` method that we defined above.
MLPBlock(keras$layers$Layer) %py_class% {
  initialize <- function() {
    super$initialize()
    self$linear_1 <- Linear(32)
    self$linear_2 <- Linear(32)
    self$linear_3 <- Linear(1)
  }

  call <- function(inputs) {
    x <- self$linear_1(inputs)
    x <- tf$nn$relu(x)
    x <- self$linear_2(x)
    x <- tf$nn$relu(x)
    self$linear_3(x)
  }
}

mlp <- MLPBlock()
y <- mlp(tf$ones(shape = shape(3, 64))) # The first call to the `mlp` will create the weights
cat("weights:", length(mlp$weights), "\n")
weights: 6 
cat("trainable weights:", length(mlp$trainable_weights), "\n")
trainable weights: 6 

The add_loss() method

When writing the call() method of a layer, you can create loss tensors that you will want to use later, when writing your training loop. This is doable by calling self$add_loss(value):

# A layer that creates an activity regularization loss
ActivityRegularizationLayer(keras$layers$Layer) %py_class% {
  initialize <- function(rate = 1e-2) {
    super$initialize()
    self$rate <- rate
  }

  call <- function(inputs) {
    self$add_loss(self$rate * tf$reduce_sum(inputs))
    inputs
  }
}

These losses (including those created by any inner layer) can be retrieved via layer$losses. This property is reset at the start of every call() to the top-level layer, so that layer$losses always contains the loss values created during the last forward pass.

OuterLayer(keras$layers$Layer) %py_class% {
  initialize <- function() {
    super$initialize()
    self$activity_reg <- ActivityRegularizationLayer(1e-2)
  }
  call <- function(inputs) {
    self$activity_reg(inputs)
  }
}

layer <- OuterLayer()
stopifnot(length(layer$losses) == 0) # No losses yet since the layer has never been called

layer(tf$zeros(shape(1, 1))) |> invisible()
stopifnot(length(layer$losses) == 1) # We created one loss value

# `layer$losses` gets reset at the start of each call()
layer(tf$zeros(shape(1, 1))) |> invisible()
stopifnot(length(layer$losses) == 1) # This is the loss created during the call above

In addition, the losses property also contains regularization losses created for the weights of any inner layer:

OuterLayerWithKernelRegularizer(keras$layers$Layer) %py_class% {
  initialize <- function() {
    super$initialize()
    self$dense <- layer_dense(units = 32, kernel_regularizer = regularizer_l2(1e-3))
  }
  call <- function(inputs) {
    self$dense(inputs)
  }
}

layer <- OuterLayerWithKernelRegularizer()
Warning in keras$regularizers$l2(l = l): partial argument match of 'l' to
'l2'
layer(tf$zeros(shape(1, 1))) |> invisible()

# This is `1e-3 * sum(layer$dense$kernel ** 2)`,
# created by the `kernel_regularizer` above.
print(layer$losses)
[[1]]
tf.Tensor(0.0014562383, shape=(), dtype=float32)

These losses are meant to be taken into account when writing training loops, like this:

# Instantiate an optimizer.
optimizer <- optimizer_sgd(learning_rate = 1e-3)
loss_fn <- loss_sparse_categorical_crossentropy(from_logits = TRUE)

# Iterate over the batches of a dataset.
# (This sketch assumes a `train_dataset` that yields (x, y) batches.)
dataset_iterator <- reticulate::as_iterator(train_dataset)
while (!is.null(batch <- reticulate::iter_next(dataset_iterator))) {
  c(x_batch_train, y_batch_train) %<-% batch
  with(tf$GradientTape() %as% tape, {
    logits <- layer(x_batch_train) # Logits for this minibatch
    # Loss value for this minibatch
    loss_value <- loss_fn(y_batch_train, logits)
    # Add extra losses created during this forward pass:
    loss_value <- loss_value + Reduce(`+`, layer$losses)
  })
  grads <- tape$gradient(loss_value, layer$trainable_weights)
  optimizer$apply_gradients(
    purrr::transpose(list(grads, layer$trainable_weights)))
}

For a detailed guide about writing training loops, see the guide to writing a training loop from scratch.

These losses also work seamlessly with fit() (they get automatically summed and added to the main loss, if any):

input <- layer_input(shape(3))
output <- input %>% layer_activity_regularization()
# output <- ActivityRegularizationLayer()(input)
model <- keras_model(input, output)

# If there is a loss passed in `compile`, the regularization
# losses get added to it
model %>% compile(optimizer = "adam", loss = "mse")
model %>% fit(k_random_uniform(c(2, 3)),
  k_random_uniform(c(2, 3)),
  epochs = 1, verbose = FALSE
)

# It's also possible not to pass any loss in `compile`,
# since the model already has a loss to minimize, via the `add_loss`
# call during the forward pass!
model %>% compile(optimizer = "adam")
model %>% fit(k_random_uniform(c(2, 3)),
  k_random_uniform(c(2, 3)),
  epochs = 1, verbose = FALSE
)

The add_metric() method

Similarly to add_loss(), layers also have an add_metric() method for tracking the moving average of a quantity during training.

Consider the following layer: a “logistic endpoint” layer. It takes predictions and targets as inputs, computes a loss that it tracks via add_loss(), and computes an accuracy scalar that it tracks via add_metric().

LogisticEndpoint(keras$layers$Layer) %py_class% {
  initialize <- function(name = NULL) {
    super$initialize(name = name)
    self$loss_fn <- loss_binary_crossentropy(from_logits = TRUE)
    self$accuracy_fn <- metric_binary_accuracy()
  }

  call <- function(targets, logits, sample_weights = NULL) {
    # Compute the training-time loss value and add it
    # to the layer using `self$add_loss()`.
    loss <- self$loss_fn(targets, logits, sample_weights)
    self$add_loss(loss)

    # Log accuracy as a metric and add it
    # to the layer using `self$add_metric()`.
    acc <- self$accuracy_fn(targets, logits, sample_weights)
    self$add_metric(acc, name = "accuracy")

    # Return the inference-time prediction tensor (for `predict()`).
    tf$nn$softmax(logits)
  }
}

Metrics tracked in this way are accessible via layer$metrics:

layer <- LogisticEndpoint()

targets <- tf$ones(shape(2, 2))
logits <- tf$ones(shape(2, 2))
y <- layer(targets, logits)

cat("layer$metrics: ")
layer$metrics: 
str(layer$metrics)
List of 1
 $ :BinaryAccuracy(name=binary_accuracy,dtype=float32,threshold=0.5)
cat("current accuracy value:", as.numeric(layer$metrics[[1]]$result()), "\n")
current accuracy value: 1 

Just like for add_loss(), these metrics are tracked by fit():

inputs <- layer_input(shape(3), name = "inputs")
targets <- layer_input(shape(10), name = "targets")
logits <- inputs %>% layer_dense(10)
predictions <- LogisticEndpoint(name = "predictions")(targets, logits)

model <- keras_model(inputs = list(inputs, targets), outputs = predictions)
model %>% compile(optimizer = "adam")

data <- list(
  inputs = k_random_uniform(c(3, 3)),
  targets = k_random_uniform(c(3, 10))
)

model %>% fit(data, epochs = 1, verbose = FALSE)

You can optionally enable serialization on your layers

If you need your custom layers to be serializable as part of a Functional model, you can optionally implement a get_config() method:

Linear(keras$layers$Layer) %py_class% {
  initialize <- function(units = 32) {
    super$initialize()
    self$units <- units
  }

  build <- function(input_shape) {
    self$w <- self$add_weight(
      shape = shape(tail(input_shape, 1), self$units),
      initializer = "random_normal",
      trainable = TRUE
    )
    self$b <- self$add_weight(
      shape = shape(self$units),
      initializer = "random_normal",
      trainable = TRUE
    )
  }

  call <- function(inputs) {
    tf$matmul(inputs, self$w) + self$b
  }

  get_config <- function() {
    list(units = self$units)
  }
}


# Now you can recreate the layer from its config:
layer <- Linear(64)
config <- layer$get_config()
print(config)
$units
[1] 64
new_layer <- Linear$from_config(config)

Note that the initialize() method of the base Layer class takes some additional named arguments, in particular a name and a dtype. It’s good practice to pass these arguments to the parent class in initialize() and to include them in the layer config:

Linear(keras$layers$Layer) %py_class% {
  initialize <- function(units = 32, ...) {
    super$initialize(...)
    self$units <- units
  }

  build <- function(input_shape) {
    self$w <- self$add_weight(
      shape = shape(tail(input_shape, 1), self$units),
      initializer = "random_normal",
      trainable = TRUE
    )
    self$b <- self$add_weight(
      shape = shape(self$units),
      initializer = "random_normal",
      trainable = TRUE
    )
  }

  call <- function(inputs) {
    tf$matmul(inputs, self$w) + self$b
  }

  get_config <- function() {
    config <- super$get_config()
    config$units <- self$units
    config
  }
}


layer <- Linear(64)
config <- layer$get_config()
str(config)
List of 4
 $ name     : chr "linear_9"
 $ trainable: logi TRUE
 $ dtype    : chr "float32"
 $ units    : num 64
new_layer <- Linear$from_config(config)

If you need more flexibility when deserializing the layer from its config, you can also override the from_config() class method. This is the base implementation of from_config():

from_config <- function(cls, config) do.call(cls, config)

To learn more about serialization and saving, see the complete guide to saving and serializing models.

Privileged training argument in the call() method

Some layers, in particular the BatchNormalization layer and the Dropout layer, have different behaviors during training and inference. For such layers, it is standard practice to expose a training (boolean) argument in the call() method.

By exposing this argument in call(), you enable the built-in training and evaluation loops (e.g., fit()) to use the layer correctly during training and inference. Note that the default of NULL means the training argument will be inferred by Keras from the calling context (e.g., it will be TRUE when called from fit() and FALSE when called from predict()).

CustomDropout(keras$layers$Layer) %py_class% {
  initialize <- function(rate, ...) {
    super$initialize(...)
    self$rate <- rate
  }
  call <- function(inputs, training = NULL) {
    if (isTRUE(training)) {
      return(tf$nn$dropout(inputs, rate = self$rate))
    }
    inputs
  }
}
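
For example, you could check the behavior by calling an instance with and without training = TRUE:

dropout <- CustomDropout(rate = 0.5)
x <- tf$ones(shape(2, 4))

dropout(x)                  # inference: inputs are returned unchanged
dropout(x, training = TRUE) # training: ~half the entries are zeroed, the rest scaled up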

Privileged mask argument in the call() method

The other privileged argument supported by call() is the mask argument.

You will find it in all Keras RNN layers. A mask is a boolean tensor (one boolean value per timestep in the input) used to skip certain input timesteps when processing timeseries data.

Keras will automatically pass the correct mask argument to call() for layers that support it, whenever a mask is generated by a prior layer. Mask-generating layers are the Embedding layer configured with mask_zero = TRUE, and the Masking layer.
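
As an illustration, here is a minimal sketch of a mask-consuming layer (a hypothetical TemporalMean that averages features over the unmasked timesteps only):

TemporalMean(keras$layers$Layer) %py_class% {
  call <- function(inputs, mask = NULL) {
    if (is.null(mask))
      return(tf$reduce_mean(inputs, axis = 1L))
    # `mask` has shape (batch, timesteps); broadcast it over the feature axis.
    mask <- tf$cast(mask, inputs$dtype)
    mask <- tf$expand_dims(mask, axis = -1L)
    tf$reduce_sum(inputs * mask, axis = 1L) / tf$reduce_sum(mask, axis = 1L)
  }
}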

To learn more about masking and how to write masking-enabled layers, please check out the guide “understanding padding and masking”.

The Model class

In general, you will use the Layer class to define inner computation blocks, and will use the Model class to define the outer model – the object you will train.

For instance, in a ResNet50 model, you would have several ResNet blocks subclassing Layer, and a single Model encompassing the entire ResNet50 network.

The Model class has the same API as Layer, with the following differences:

  • It has support for built-in training, evaluation, and prediction methods (fit(), evaluate(), predict()).
  • It exposes the list of its inner layers, via the model$layers property.
  • It exposes saving and serialization APIs (save_model_tf(), save_model_weights_tf(), …)

Effectively, the Layer class corresponds to what we refer to in the literature as a “layer” (as in “convolution layer” or “recurrent layer”) or as a “block” (as in “ResNet block” or “Inception block”).

Meanwhile, the Model class corresponds to what is referred to in the literature as a “model” (as in “deep learning model”) or as a “network” (as in “deep neural network”).

So if you’re wondering, “should I use the Layer class or the Model class?”, ask yourself: will I need to call fit() on it? Will I need to call save() on it? If so, go with Model. If not (either because your class is just a block in a bigger system, or because you are writing training & saving code yourself), use Layer.

For instance, building on the ResNet50 example mentioned above, we could define a Model that we could train with fit() and save with save_model_tf() (ResNetBlock below stands in for a Layer subclass defined elsewhere):

ResNet(keras$Model) %py_class% {
  initialize <- function(num_classes = 1000) {
    super$initialize()
    self$block_1 <- ResNetBlock()
    self$block_2 <- ResNetBlock()
    self$global_pool <- layer_global_average_pooling_2d()
    self$classifier <- layer_dense(units = num_classes)
  }

  call <- function(inputs) {
    x <- self$block_1(inputs)
    x <- self$block_2(x)
    x <- self$global_pool(x)
    self$classifier(x)
  }
}


resnet <- ResNet()
dataset <- ...
resnet %>% fit(dataset, epochs = 10)
resnet %>% save_model_tf(filepath)

Putting it all together: an end-to-end example

Here’s what you’ve learned so far:

  • A Layer encapsulates a state (created in initialize() or build()), and some computation (defined in call()).
  • Layers can be recursively nested to create new, bigger computation blocks.
  • Layers can create and track losses (typically regularization losses) as well as metrics, via add_loss() and add_metric().
  • The outer container, the thing you want to train, is a Model. A Model is just like a Layer, but with added training and serialization utilities.

Let’s put all of these things together into an end-to-end example: we’re going to implement a Variational AutoEncoder (VAE). We’ll train it on MNIST digits.

Our VAE will be a subclass of Model, built as a nested composition of layers that subclass Layer. It will feature a regularization loss (KL divergence).

Sampling(keras$layers$Layer) %py_class% {
  call <- function(inputs) {
    c(z_mean, z_log_var) %<-% inputs
    batch <- tf$shape(z_mean)[1]
    dim <- tf$shape(z_mean)[2]
    epsilon <- k_random_normal(shape = c(batch, dim))
    z_mean + exp(0.5 * z_log_var) * epsilon
  }
}


Encoder(keras$layers$Layer) %py_class% {
  "Maps MNIST digits to a triplet (z_mean, z_log_var, z)."

  initialize <- function(latent_dim = 32, intermediate_dim = 64, name = "encoder", ...) {
    super$initialize(name = name, ...)
    self$dense_proj <- layer_dense(units = intermediate_dim, activation = "relu")
    self$dense_mean <- layer_dense(units = latent_dim)
    self$dense_log_var <- layer_dense(units = latent_dim)
    self$sampling <- Sampling()
  }

  call <- function(inputs) {
    x <- self$dense_proj(inputs)
    z_mean <- self$dense_mean(x)
    z_log_var <- self$dense_log_var(x)
    z <- self$sampling(c(z_mean, z_log_var))
    list(z_mean, z_log_var, z)
  }
}


Decoder(keras$layers$Layer) %py_class% {
  "Converts z, the encoded digit vector, back into a readable digit."

  initialize <- function(original_dim, intermediate_dim = 64, name = "decoder", ...) {
    super$initialize(name = name, ...)
    self$dense_proj <- layer_dense(units = intermediate_dim, activation = "relu")
    self$dense_output <- layer_dense(units = original_dim, activation = "sigmoid")
  }

  call <- function(inputs) {
    x <- self$dense_proj(inputs)
    self$dense_output(x)
  }
}


VariationalAutoEncoder(keras$Model) %py_class% {
  "Combines the encoder and decoder into an end-to-end model for training."

  initialize <- function(original_dim, intermediate_dim = 64, latent_dim = 32,
                         name = "autoencoder", ...) {
    super$initialize(name = name, ...)
    self$original_dim <- original_dim
    self$encoder <- Encoder(
      latent_dim = latent_dim,
      intermediate_dim = intermediate_dim
    )
    self$decoder <- Decoder(original_dim, intermediate_dim = intermediate_dim)
  }

  call <- function(inputs) {
    c(z_mean, z_log_var, z) %<-% self$encoder(inputs)
    reconstructed <- self$decoder(z)
    # Add KL divergence regularization loss.
    kl_loss <- -0.5 * tf$reduce_mean(z_log_var - tf$square(z_mean) - tf$exp(z_log_var) + 1)
    self$add_loss(kl_loss)
    reconstructed
  }
}

Let’s write a simple training loop on MNIST:

library(tfautograph)
library(tfdatasets)


original_dim <- 784
vae <- VariationalAutoEncoder(original_dim, 64, 32)

optimizer <- optimizer_adam(learning_rate = 1e-3)
mse_loss_fn <- loss_mean_squared_error()

loss_metric <- metric_mean()

x_train <- dataset_mnist()$train$x %>%
  array_reshape(c(60000, 784)) %>%
  `/`(255)

train_dataset <- tensor_slices_dataset(x_train) %>%
  dataset_shuffle(buffer_size = 1024) %>%
  dataset_batch(64)

epochs <- 2

# Iterate over epochs.
for (epoch in seq(epochs)) {
  cat(sprintf("Start of epoch %d\n", epoch))

  # Iterate over the batches of the dataset.
  # autograph lets you use tfdatasets in `for` and `while`
  autograph({
    step <- 0
    for (x_batch_train in train_dataset) {
      with(tf$GradientTape() %as% tape, {
        ## Note: we're four opaque contexts deep here (for, autograph, for,
        ## with). When in doubt about the objects or methods that are available
        ## (e.g., what is `tape` here?), remember you can always drop into a
        ## debugger right here:
        # browser()

        reconstructed <- vae(x_batch_train)
        # Compute reconstruction loss
        loss <- mse_loss_fn(x_batch_train, reconstructed)

        loss %<>% add(vae$losses[[1]]) # Add KLD regularization loss
      })
      grads <- tape$gradient(loss, vae$trainable_weights)
      optimizer$apply_gradients(
        purrr::transpose(list(grads, vae$trainable_weights)))

      loss_metric(loss)

      step %<>% add(1)
      if (step %% 100 == 0) {
        cat(sprintf("step %d: mean loss = %.4f\n", step, loss_metric$result()))
      }
    }
  })
}
Start of epoch 1
step 100: mean loss = 0.1266
step 200: mean loss = 0.0997
step 300: mean loss = 0.0894
step 400: mean loss = 0.0844
step 500: mean loss = 0.0810
step 600: mean loss = 0.0789
step 700: mean loss = 0.0772
step 800: mean loss = 0.0760
step 900: mean loss = 0.0750
Start of epoch 2
step 100: mean loss = 0.0741
step 200: mean loss = 0.0736
step 300: mean loss = 0.0731
step 400: mean loss = 0.0727
step 500: mean loss = 0.0724
step 600: mean loss = 0.0720
step 700: mean loss = 0.0718
step 800: mean loss = 0.0715
step 900: mean loss = 0.0712

Note that since the VAE is subclassing Model, it features built-in training loops. So you could also have trained it like this:

vae <- VariationalAutoEncoder(784, 64, 32)

optimizer <- optimizer_adam(learning_rate = 1e-3)

vae %>% compile(optimizer, loss = loss_mean_squared_error())
vae %>% fit(x_train, x_train, epochs = 2, batch_size = 64)
Epoch 1/2
938/938 - 3s - loss: 0.0745 - 3s/epoch - 3ms/step
Epoch 2/2
938/938 - 2s - loss: 0.0676 - 2s/epoch - 2ms/step

Beyond object-oriented development: the Functional API

If you prefer a less object-oriented way of programming, you can also build models using the Functional API. Importantly, choosing one style or another does not prevent you from leveraging components written in the other style: you can always mix-and-match.

For instance, the Functional API example below reuses the same Sampling layer we defined in the example above:

original_dim <- 784
intermediate_dim <- 64
latent_dim <- 32

# Define encoder model.
original_inputs <- layer_input(shape = original_dim, name = "encoder_input")
x <- layer_dense(units = intermediate_dim, activation = "relu")(original_inputs)
z_mean <- layer_dense(units = latent_dim, name = "z_mean")(x)
z_log_var <- layer_dense(units = latent_dim, name = "z_log_var")(x)
z <- Sampling()(list(z_mean, z_log_var))
encoder <- keras_model(inputs = original_inputs, outputs = z, name = "encoder")

# Define decoder model.
latent_inputs <- layer_input(shape = latent_dim, name = "z_sampling")
x <- layer_dense(units = intermediate_dim, activation = "relu")(latent_inputs)
outputs <- layer_dense(units = original_dim, activation = "sigmoid")(x)
decoder <- keras_model(inputs = latent_inputs, outputs = outputs, name = "decoder")

# Define VAE model.
outputs <- decoder(z)
vae <- keras_model(inputs = original_inputs, outputs = outputs, name = "vae")

# Add KL divergence regularization loss.
kl_loss <- -0.5 * tf$reduce_mean(z_log_var - tf$square(z_mean) - tf$exp(z_log_var) + 1)
vae$add_loss(kl_loss)

# Train.
optimizer <- keras$optimizers$Adam(learning_rate = 1e-3)
vae %>% compile(optimizer, loss = loss_mean_squared_error())
vae %>% fit(x_train, x_train, epochs = 3, batch_size = 64)
Epoch 1/3
938/938 - 3s - loss: 0.0749 - 3s/epoch - 3ms/step
Epoch 2/3
938/938 - 2s - loss: 0.0676 - 2s/epoch - 2ms/step
Epoch 3/3
938/938 - 2s - loss: 0.0676 - 2s/epoch - 2ms/step

For more information, make sure to read the Functional API guide.

Defining custom layers and models in an R package

Unfortunately, you can’t create references to Python objects at the top level of an R package.

Here is why: when you build an R package, all the R files in the R/ directory get sourced in an R environment (the package namespace), and then that environment is saved as part of the package bundle. Loading the package means restoring the saved R environment. This means that the R code only gets sourced once, at build time. If you create references to external objects (e.g., Python objects) at package build time, they will be NULL pointers when the package is loaded, because the external objects they pointed to at build time no longer exist at load time.

The solution is to delay creating references to Python objects until run time. Fortunately, %py_class%, Layer(), and create_layer_wrapper(R6Class(...)) are all lazy about initializing the Python reference, so they are safe to define and export in an R package.
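
For example, here is a minimal sketch (with hypothetical names) of what is and isn’t safe at the top level of a package’s R/ files:

# In a package R/ file:

# Not safe: this evaluates at build time and stores a reference to a Python
# object in the package namespace; it becomes a NULL pointer at load time.
# DenseClass <- keras::keras$layers$Dense

# Safe: new_layer_class() (like %py_class% and create_layer_wrapper()) only
# initializes the underlying Python class the first time the layer is used.
layer_scale <- new_layer_class(
  "Scale",
  initialize = function(factor = 2, ...) {
    super$initialize(...)
    self$factor <- factor
  },
  call = function(inputs) {
    inputs * self$factor
  }
)

# Also safe: resolve Python objects at call time, inside a function body.
get_dense_class <- function() keras::keras$layers$Dense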

If you’re writing an R package that uses keras and reticulate, this article might be helpful to read over.

Summary

In this guide you learned about creating custom layers and models in keras.

  • The constructors available: new_layer_class(), %py_class%, create_layer_wrapper(), R6Class(), Layer().
  • Which methods you might want to define in your layer or model: initialize(), build(), call(), and get_config().
  • What convenience methods are available when you subclass keras$layers$Layer: add_weight(), add_loss(), and add_metric().

Environment Details

tensorflow::tf_config()
TensorFlow v2.13.0 (~/.virtualenvs/r-tensorflow-website/lib/python3.10/site-packages/tensorflow)
Python v3.10 (~/.virtualenvs/r-tensorflow-website/bin/python)
sessionInfo()
R version 4.3.1 (2023-06-16)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 22.04.3 LTS

Matrix products: default
BLAS:   /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3 
LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.20.so;  LAPACK version 3.10.0

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
 [3] LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C            
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       

time zone: America/New_York
tzcode source: system (glibc)

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] tfautograph_0.3.2.9000 keras_2.13.0.9000      tfdatasets_2.9.0.9000 
[4] tensorflow_2.13.0.9000 magrittr_2.0.3        

loaded via a namespace (and not attached):
 [1] vctrs_0.6.3            cli_3.6.1              knitr_1.43            
 [4] zeallot_0.1.0          rlang_1.1.1            xfun_0.40             
 [7] purrr_1.0.2            png_0.1-8              generics_0.1.3        
[10] jsonlite_1.8.7         glue_1.6.2             backports_1.4.1       
[13] htmltools_0.5.6        fansi_1.0.4            rmarkdown_2.24        
[16] grid_4.3.1             tfruns_1.5.1           evaluate_0.21         
[19] tibble_3.2.1           base64enc_0.1-3        fastmap_1.1.1         
[22] yaml_2.3.7             lifecycle_1.0.3        whisker_0.4.1         
[25] compiler_4.3.1         htmlwidgets_1.6.2      Rcpp_1.0.11           
[28] pkgconfig_2.0.3        rstudioapi_0.15.0      lattice_0.21-8        
[31] digest_0.6.33          R6_2.5.1               tidyselect_1.2.0      
[34] reticulate_1.31.0.9000 utf8_1.2.3             pillar_1.9.0          
[37] Matrix_1.5-4.1         tools_4.3.1           
system2(reticulate::py_exe(), c("-m pip freeze"), stdout = TRUE) |> writeLines()
absl-py==1.4.0
array-record==0.4.1
asttokens==2.2.1
astunparse==1.6.3
backcall==0.2.0
bleach==6.0.0
cachetools==5.3.1
certifi==2023.7.22
charset-normalizer==3.2.0
click==8.1.7
decorator==5.1.1
dm-tree==0.1.8
etils==1.4.1
executing==1.2.0
flatbuffers==23.5.26
gast==0.4.0
google-auth==2.22.0
google-auth-oauthlib==1.0.0
google-pasta==0.2.0
googleapis-common-protos==1.60.0
grpcio==1.57.0
h5py==3.9.0
idna==3.4
importlib-resources==6.0.1
ipython==8.14.0
jedi==0.19.0
kaggle==1.5.16
keras==2.13.1
keras-tuner==1.3.5
kt-legacy==1.0.5
libclang==16.0.6
Markdown==3.4.4
MarkupSafe==2.1.3
matplotlib-inline==0.1.6
numpy==1.24.3
nvidia-cublas-cu11==11.11.3.6
nvidia-cudnn-cu11==8.6.0.163
oauthlib==3.2.2
opt-einsum==3.3.0
packaging==23.1
pandas==2.0.3
parso==0.8.3
pexpect==4.8.0
pickleshare==0.7.5
Pillow==10.0.0
promise==2.3
prompt-toolkit==3.0.39
protobuf==3.20.3
psutil==5.9.5
ptyprocess==0.7.0
pure-eval==0.2.2
pyasn1==0.5.0
pyasn1-modules==0.3.0
pydot==1.4.2
Pygments==2.16.1
pyparsing==3.1.1
python-dateutil==2.8.2
python-slugify==8.0.1
pytz==2023.3
requests==2.31.0
requests-oauthlib==1.3.1
rsa==4.9
scipy==1.11.2
six==1.16.0
stack-data==0.6.2
tensorboard==2.13.0
tensorboard-data-server==0.7.1
tensorflow==2.13.0
tensorflow-datasets==4.9.2
tensorflow-estimator==2.13.0
tensorflow-hub==0.14.0
tensorflow-io-gcs-filesystem==0.33.0
tensorflow-metadata==1.14.0
termcolor==2.3.0
text-unidecode==1.3
toml==0.10.2
tqdm==4.66.1
traitlets==5.9.0
typing_extensions==4.5.0
tzdata==2023.3
urllib3==1.26.16
wcwidth==0.2.6
webencodings==0.5.1
Werkzeug==2.3.7
wrapt==1.15.0
zipp==3.16.2
TF Devices:
-  PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU') 
-  PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU') 
CPU cores: 12 
Date rendered: 2023-08-28 
Page render time: 2 minutes and 5 seconds