Advanced

This “Hello, World!” uses the Keras subclassing API and a custom training loop.

TensorFlow 2 quickstart for experts

The Keras functional and subclassing APIs provide a define-by-run interface for customization and advanced research. Build your model, then write the forward and backward pass. Create custom layers, activations, and training loops.

Import TensorFlow into your program. If you haven’t installed TensorFlow yet, go to the installation guide.

library(tensorflow)
library(tfdatasets)
library(keras)

Load and prepare the MNIST dataset.

# Load MNIST and unpack its (train, test) pairs of (images, labels) into
# four separate variables.
c(
  c(x_train, y_train),
  c(x_test, y_test)
) %<-% keras::dataset_mnist()
Loaded Tensorflow version 2.9.1
# Divide both image arrays by 255 (presumably mapping 0-255 pixel
# intensities into [0, 1] -- confirm against the dataset documentation).
x_train <- x_train / 255
x_test <- x_test / 255

Use TensorFlow Datasets to batch and shuffle the dataset:

# Training pipeline: one (image, label) pair per element, shuffled through a
# 10000-element buffer, then grouped into batches of 32.
train_ds <- tensor_slices_dataset(list(x_train, y_train)) %>%
  dataset_shuffle(buffer_size = 10000) %>%
  dataset_batch(batch_size = 32)

# Evaluation pipeline: same slicing and batch size as the training data, but
# without shuffling.
test_ds <- tensor_slices_dataset(list(x_test, y_test)) %>%
  dataset_batch(batch_size = 32)

Build a model using the Keras model subclassing API:

# Small convolutional classifier built with the Keras subclassing API:
# conv2d(32, 3x3, relu) -> flatten -> dense(128, relu) -> dense(10).
# The last layer has no activation, so the model outputs raw scores (logits).
my_model <- new_model_class(
  classname = "MyModel",

  # Create the layers once per instance. `...` is forwarded to the base
  # constructor so that standard arguments (e.g. `name`) passed to
  # `my_model()` are honoured instead of being silently dropped.
  initialize = function(...) {
    super$initialize(...)
    self$conv1 <- layer_conv_2d(
      filters = 32,
      kernel_size = 3,
      activation = "relu"
    )
    self$flatten <- layer_flatten()
    self$d1 <- layer_dense(units = 128, activation = "relu")
    self$d2 <- layer_dense(units = 10)
  },

  # Forward pass for a batch of images.
  # NOTE(review): assumes `inputs` is (batch, height, width), so the axis
  # inserted at position 3 is the channel axis expected by the conv layer.
  call = function(inputs) {
    inputs %>%
      tf$expand_dims(axis = 3L) %>%
      self$conv1() %>%
      self$flatten() %>%
      self$d1() %>%
      self$d2()
  }
)

# Create an instance of the model
model <- my_model()

Choose an optimizer and loss function for training:

# Adam with its default settings applies the weight updates.
optimizer <- optimizer_adam()

# The model's last layer has no activation and therefore emits logits, so the
# loss is told to expect logits rather than probabilities.
loss_object <- loss_sparse_categorical_crossentropy(from_logits = TRUE)

Select metrics to measure the loss and the accuracy of the model. These metrics accumulate values over the batches of an epoch; the accumulated result is printed at the end of each epoch.

# Running means of the per-batch loss values.
train_loss <- metric_mean(name = "train_loss")
test_loss <- metric_mean(name = "test_loss")

# Running accuracy of predictions against integer class labels.
train_accuracy <- metric_sparse_categorical_accuracy(name = "train_accuracy")
test_accuracy <- metric_sparse_categorical_accuracy(name = "test_accuracy")

Use tf$GradientTape() to train the model:

# Perform one optimisation step on a single batch: forward pass under a
# gradient tape, backward pass, weight update, then metric bookkeeping.
train_step <- function(images, labels) {
  with(tf$GradientTape() %as% grad_tape, {
    # training = TRUE only matters for layers that behave differently in
    # training and inference (e.g. Dropout).
    logits <- model(images, training = TRUE)
    batch_loss <- loss_object(labels, logits)
  })

  # Read the variables only after the forward pass has run.
  trainable <- model$trainable_variables
  grads <- grad_tape$gradient(batch_loss, trainable)
  optimizer$apply_gradients(zip_lists(grads, trainable))

  train_loss(batch_loss)
  train_accuracy(labels, logits)
}

# One full pass over `train_ds`, wrapped in tf_function(): every batch is
# unpacked into its (images, labels) pair and handed to train_step().
train <- tf_function(function(train_ds) {
  for (minibatch in train_ds) {
    c(batch_images, batch_labels) %<-% minibatch
    train_step(batch_images, batch_labels)
  }
})

Test the model:

# Evaluate a single batch: forward pass only, then update the test metrics.
test_step <- function(images, labels) {
  # training = FALSE only matters for layers that behave differently in
  # training and inference (e.g. Dropout).
  logits <- model(images, training = FALSE)
  test_loss(loss_object(labels, logits))
  test_accuracy(labels, logits)
}

# One full pass over `test_ds`, wrapped in tf_function(): every batch is
# unpacked into its (images, labels) pair and handed to test_step().
test <- tf_function(function(test_ds) {
  for (minibatch in test_ds) {
    c(batch_images, batch_labels) %<-% minibatch
    test_step(batch_images, batch_labels)
  }
})
# Clear the accumulated state of all four train/test metrics so the next
# epoch starts from scratch.
reset_metrics <- function() {
  metrics <- list(train_loss, train_accuracy, test_loss, test_accuracy)
  for (metric in metrics) {
    # reset_state() is the current method name; reset_states() is the
    # deprecated spelling and is absent from newer Keras releases.
    metric$reset_state()
  }
}
# Number of full passes over the training data.
EPOCHS <- 1

for (epoch in seq_len(EPOCHS)) {
  # Start every epoch with empty metric accumulators.
  reset_metrics()

  train(train_ds)
  test(test_ds)

  # Read the accumulated metrics; accuracies are reported as percentages.
  epoch_train_loss <- train_loss$result()
  epoch_train_acc <- train_accuracy$result() * 100
  epoch_test_loss <- test_loss$result()
  epoch_test_acc <- test_accuracy$result() * 100

  cat(sprintf("Epoch %d", epoch), "\n")
  cat(sprintf("Loss: %f", epoch_train_loss), "\n")
  cat(sprintf("Accuracy: %f", epoch_train_acc), "\n")
  cat(sprintf("Test Loss: %f", epoch_test_loss), "\n")
  cat(sprintf("Test Accuracy: %f", epoch_test_acc), "\n")
}
Epoch 1 
Loss: 0.139603 
Accuracy: 95.820000 
Test Loss: 0.064584 
Test Accuracy: 97.839996 

The image classifier is now trained to ~98% accuracy on this dataset. To learn more, read the TensorFlow tutorials.