Customizing what happens in fit()

Introduction

When you’re doing supervised learning, you can use fit() and everything works smoothly.

When you need to write your own training loop from scratch, you can use tf$GradientTape() and take control of every little detail.

But what if you need a custom training algorithm and still want to benefit from the convenient features of fit(), such as callbacks, built-in distribution support, or step fusing?

A core principle of Keras is progressive disclosure of complexity. You should always be able to get into lower-level workflows in a gradual way. You shouldn’t fall off a cliff if the high-level functionality doesn’t exactly match your use case. You should be able to gain more control over the small details while retaining a commensurate amount of high-level convenience.

When you need to customize what fit() does, you should override the train_step() method of your model class. This is the function that fit() calls for every batch of data. You will then be able to call fit() as usual, and it will run your own learning algorithm.
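The contract between fit() and train_step() can be sketched in plain R. The toy_fit() and train_step() functions below are hypothetical and deliberately simplified (no callbacks, distribution, or step fusing), not Keras's actual implementation. They only illustrate the division of labor: the outer loop over epochs and batches belongs to fit(), while the per-batch update and the returned named list of metric values belong to you.

```r
# Hypothetical, simplified stand-in for fit(); NOT the real Keras implementation.
# It loops over epochs and batches and calls `train_step()` once per batch.
toy_fit <- function(train_step, batches, epochs = 1) {
  history <- vector("list", epochs)
  for (epoch in seq_len(epochs)) {
    for (batch in batches) {
      logs <- train_step(batch)   # user-defined logic runs here
    }
    history[[epoch]] <- logs      # keep the logs returned for the last batch
  }
  history
}

# A toy train_step: one gradient-descent update of a single weight `w`
# for the model y_pred = w * x, using mean squared error.
w <- 0
train_step <- function(data) {
  x <- data[[1]]
  y <- data[[2]]
  y_pred <- w * x
  grad <- mean(2 * (y_pred - y) * x)   # d/dw of mean((w * x - y)^2)
  w <<- w - 0.1 * grad                 # plain SGD update with learning rate 0.1
  list(loss = mean((y - y_pred)^2))    # named list of metric values
}

# Two batches of (x, y) pairs generated from y = 3 * x
batches <- list(list(c(1, 2), c(3, 6)), list(c(0.5, 1.5), c(1.5, 4.5)))
history <- toy_fit(train_step, batches, epochs = 20)
```

After 20 epochs, w is very close to 3, and history holds one named list of logs per epoch, much like the per-epoch lines that fit() prints.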

Note that this pattern does not prevent you from building models with the Functional API. You can do this whether you’re building Sequential models, Functional API models, or subclassed models.

Let’s see how that works.

Setup

Requires TensorFlow 2.2 or later.

library(tensorflow)
library(keras)

A first simple example

Let’s start from a simple example:

  • We create a new model class by calling new_model_class().
  • We just override the method train_step(data).
  • We return a dictionary mapping metric names (including the loss) to their current value.

The input argument data is what gets passed to fit() as training data:

  • If you pass arrays, by calling fit(x, y, ...), then data will be the tuple (x, y), which reaches R as an unnamed list of two elements.
  • If you pass a tf$data$Dataset, by calling fit(dataset, ...), then data will be whatever dataset yields at each batch.
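If you prefer not to rely on zeallot's %<-%, or want to handle every shape of data in one place, the unpacking can be written in base R. unpack_data() below is a hypothetical helper, not part of keras, and the plain numeric vectors stand in for tensors. A bare tensor rather than a list is what a dataset yielding only inputs produces, as in the GAN example later on.

```r
# Hypothetical helper (not part of keras): split the `data` argument that
# fit() passes to train_step() into x, y and sample_weight.
unpack_data <- function(data) {
  if (!is.list(data)) {
    # e.g. a dataset that yields only inputs
    return(list(x = data, y = NULL, sample_weight = NULL))
  }
  list(
    x = data[[1]],
    y = if (length(data) >= 2) data[[2]] else NULL,
    sample_weight = if (length(data) >= 3) data[[3]] else NULL
  )
}

# Stand-ins for what fit() might pass to train_step()
d2 <- unpack_data(list(c(1, 2), c(0, 1)))            # fit(x, y)
d3 <- unpack_data(list(c(1, 2), c(0, 1), c(1, 5)))  # fit(x, y, sample_weight = sw)
d1 <- unpack_data(c(1, 2))                           # dataset yielding x only
```

Inside train_step() you would then write d <- unpack_data(data) and use d$x, d$y and d$sample_weight.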

In the body of the train_step method, we implement a regular training update, similar to what you are already familiar with. Importantly, we compute the loss via self$compiled_loss, which wraps the loss function(s) that were passed to compile().

Similarly, we call self$compiled_metrics$update_state(y, y_pred) to update the state of the metrics that were passed in compile(), and we query results from self$metrics at the end to retrieve their current value.

CustomModel <- new_model_class(
  classname = "CustomModel",
  train_step = function(data) {
    # Unpack the data. Its structure depends on your model and
    # on what you pass to `fit()`.
    c(x, y) %<-% data
    
    with(tf$GradientTape() %as% tape, {
      y_pred <- self(x, training = TRUE)  # Forward pass
      # Compute the loss value
      # (the loss function is configured in `compile()`)
      loss <-
        self$compiled_loss(y, y_pred, regularization_losses = self$losses)
    })
    
    # Compute gradients
    trainable_vars <- self$trainable_variables
    gradients <- tape$gradient(loss, trainable_vars)
    # Update weights
    self$optimizer$apply_gradients(zip_lists(gradients, trainable_vars))
    # Update metrics (includes the metric that tracks the loss)
    self$compiled_metrics$update_state(y, y_pred)
    
    # Return a named list mapping metric names to current value
    results <- list()
    for (m in self$metrics)
      results[[m$name]] <- m$result()
    results
  }
)

Let’s try this out:

# Construct and compile an instance of CustomModel
inputs <- layer_input(shape(32))
outputs <- inputs %>% layer_dense(1)
model <- CustomModel(inputs, outputs)
model %>% compile(optimizer = "adam",
                  loss = "mse",
                  metrics = "mae")

# Just use `fit` as usual
x <- k_random_uniform(c(1000, 32))
y <- k_random_uniform(c(1000, 1))
model %>% fit(x, y, epochs = 3)
Epoch 1/3
32/32 - 1s - loss: 2.1307 - mae: 1.3566 - 514ms/epoch - 16ms/step
Epoch 2/3
32/32 - 0s - loss: 1.0457 - mae: 0.8961 - 78ms/epoch - 2ms/step
Epoch 3/3
32/32 - 0s - loss: 0.5223 - mae: 0.5961 - 47ms/epoch - 1ms/step

Going lower-level

Naturally, you could just skip passing a loss function in compile(), and instead do everything manually in train_step. Likewise for metrics.

Here’s a lower-level example that only uses compile() to configure the optimizer:

  • We start by creating Metric instances to track our loss and an MAE score.
  • We implement a custom train_step() that updates the state of these metrics (by calling update_state() on them), then queries them (via result()) to return their current average value, to be displayed by the progress bar and passed to any callback.
  • Note that we would need to call reset_states() on our metrics between each epoch! Otherwise, calling result() would return an average since the start of training, whereas we usually work with per-epoch averages. Thankfully, the framework can do that for us: just list any metric you want to reset in the metrics active property of the model. The model will call reset_states() on any object listed there at the beginning of each fit() epoch or at the beginning of a call to evaluate().

loss_tracker <- metric_mean(name = "loss")
mae_metric <- metric_mean_absolute_error(name = "mae")

CustomModel <- new_model_class(
  classname = "CustomModel",
  train_step = function(data) {
    c(x, y) %<-% data
    
    with(tf$GradientTape() %as% tape, {
      y_pred <- self(x, training = TRUE)  # Forward pass
      # Compute our own loss
      loss <- keras$losses$mean_squared_error(y, y_pred)
    })
    
    # Compute gradients
    trainable_vars <- self$trainable_variables
    gradients <- tape$gradient(loss, trainable_vars)
    
    # Update weights
    self$optimizer$apply_gradients(zip_lists(gradients, trainable_vars))
    
    # Compute our own metrics
    loss_tracker$update_state(loss)
    mae_metric$update_state(y, y_pred)
    list(loss = loss_tracker$result(), 
         mae = mae_metric$result())
  },
  
  metrics = mark_active(function() {
    # We list our `Metric` objects here so that `reset_states()` can be
    # called automatically at the start of each epoch
    # or at the start of `evaluate()`.
    # If you don't implement this active property, you have to call
    # `reset_states()` yourself at the time of your choosing.
    list(loss_tracker, mae_metric)
  })
)


# Construct an instance of CustomModel
inputs <- layer_input(shape(32))
outputs <- inputs %>% layer_dense(1)
model <- CustomModel(inputs, outputs)

# We don't pass a loss or metrics here.
model %>% compile(optimizer = "adam")

# Just use `fit` as usual -- you can use callbacks, etc.
x <- k_random_uniform(c(1000, 32))
y <- k_random_uniform(c(1000, 1))
model %>% fit(x, y, epochs = 5)
Epoch 1/5
32/32 - 0s - loss: 0.5503 - mae: 0.6102 - 341ms/epoch - 11ms/step
Epoch 2/5
32/32 - 0s - loss: 0.2866 - mae: 0.4246 - 64ms/epoch - 2ms/step
Epoch 3/5
32/32 - 0s - loss: 0.2568 - mae: 0.4032 - 56ms/epoch - 2ms/step
Epoch 4/5
32/32 - 0s - loss: 0.2500 - mae: 0.3975 - 54ms/epoch - 2ms/step
Epoch 5/5
32/32 - 0s - loss: 0.2437 - mae: 0.3927 - 53ms/epoch - 2ms/step
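Why the reset matters: a mean metric accumulates a running total and a count, so unless both are cleared, result() reports the average since training began. make_mean_tracker() below is a hypothetical plain-R imitation of the update_state() / result() / reset_states() interface, not the Keras Metric class itself, and the loss values are made up for illustration.

```r
# Hypothetical plain-R imitation of a mean metric such as metric_mean():
# it keeps a running total and count, like a stateful Keras Metric.
make_mean_tracker <- function() {
  total <- 0
  count <- 0
  list(
    update_state = function(value) {
      total <<- total + sum(value)
      count <<- count + length(value)
      invisible(NULL)
    },
    result = function() if (count == 0) 0 else total / count,
    reset_states = function() {
      total <<- 0
      count <<- 0
      invisible(NULL)
    }
  )
}

epoch_losses <- list(c(4, 2), c(1, 1))   # per-batch losses for two epochs

# Without resetting, the value after epoch 2 averages over both epochs
no_reset <- make_mean_tracker()
for (losses in epoch_losses) for (l in losses) no_reset$update_state(l)
no_reset$result()   # (4 + 2 + 1 + 1) / 4 = 2

# Resetting at the start of each epoch gives per-epoch averages
tracker <- make_mean_tracker()
per_epoch <- numeric(0)
for (losses in epoch_losses) {
  tracker$reset_states()
  for (l in losses) tracker$update_state(l)
  per_epoch <- c(per_epoch, tracker$result())
}
per_epoch           # 3 1
```

Listing loss_tracker and mae_metric in the metrics active property is what lets fit() perform the reset_states() call at the start of each epoch for you.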

Supporting sample_weight & class_weight

You may have noticed that our first basic example didn’t make any mention of sample weighting. If you want to support the fit() arguments sample_weight and class_weight, you’d simply do the following:

  • Unpack sample_weight from the data argument
  • Pass it to compiled_loss & compiled_metrics (of course, you could also just apply it manually if you don’t rely on compile() for losses & metrics)
  • That’s it. That’s the list.

CustomModel <- new_model_class(
  classname = "CustomModel",
  train_step = function(data) {
    # Unpack the data. Its structure depends on your model and on what you pass
    # to `fit()`. A third element in `data` is optional, but if present it's
    # assigned to sample_weight. If a third element is missing, sample_weight
    # defaults to NULL.
    c(x, y, sample_weight = NULL) %<-% data
    
    with(tf$GradientTape() %as% tape, {
      y_pred <- self(x, training = TRUE)  # Forward pass
      # Compute the loss value.
      # The loss function is configured in `compile()`.
      loss <- self$compiled_loss(y,
                                 y_pred,
                                 sample_weight = sample_weight,
                                 regularization_losses = self$losses)
    })
    
    # Compute gradients
    trainable_vars <- self$trainable_variables
    gradients <- tape$gradient(loss, trainable_vars)
    
    # Update weights
    self$optimizer$apply_gradients(zip_lists(gradients, trainable_vars))
    
    # Update the metrics.
    # Metrics are configured in `compile()`.
    self$compiled_metrics$update_state(y, y_pred, sample_weight = sample_weight)
    
    # Return a named list mapping metric names to current value.
    # Note that it will include the loss (tracked in self$metrics).
    results <- list()
    for (m in self$metrics)
      results[[m$name]] <- m$result()
    results
  }
)


# Construct and compile an instance of CustomModel

inputs <- layer_input(shape(32))
outputs <- inputs %>% layer_dense(1)
model <- CustomModel(inputs, outputs)
model %>% compile(optimizer = "adam",
                  loss = "mse",
                  metrics = "mae")

# You can now use sample_weight argument

x <- k_random_uniform(c(1000, 32))
y <- k_random_uniform(c(1000, 1))
sw <- k_random_uniform(c(1000, 1))
model %>% fit(x, y, sample_weight = sw, epochs = 3)
Epoch 1/3
32/32 - 0s - loss: 0.5403 - mae: 0.9025 - 375ms/epoch - 12ms/step
Epoch 2/3
32/32 - 0s - loss: 0.2226 - mae: 0.5340 - 46ms/epoch - 1ms/step
Epoch 3/3
32/32 - 0s - loss: 0.1318 - mae: 0.4008 - 45ms/epoch - 1ms/step
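The train_step() above only unpacks sample_weight, yet it also supports class_weight: fit() converts class_weight into per-sample weights before calling train_step(). The plain-R sketch below shows that lookup for integer class labels; class_weights_to_sample_weights() is a hypothetical helper for illustration, not a keras function. The last line shows how a mean metric such as MAE combines per-sample values with their weights, dividing by the sum of the weights; the error values are made up.

```r
# Hypothetical helper: map each sample's class label to its class weight,
# which is roughly what fit() does before calling train_step().
class_weights_to_sample_weights <- function(y, class_weight) {
  # class_weight is a named list such as list("0" = 1, "1" = 5)
  unname(unlist(class_weight[as.character(y)]))
}

y <- c(0, 1, 1, 0)
sw <- class_weights_to_sample_weights(y, list("0" = 1, "1" = 5))
sw   # 1 5 5 1

# A weighted mean metric divides the weighted sum by the sum of the weights
abs_err <- c(0.2, 0.1, 0.3, 0.4)
weighted_mae <- sum(sw * abs_err) / sum(sw)
```

Because the weights reach train_step() through the same third element of data, no separate code path is needed for class_weight.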

Providing your own evaluation step

What if you want to do the same for calls to evaluate()? Then you would override test_step() in exactly the same way. Here’s what it looks like:

CustomModel <- new_model_class(
  classname = "CustomModel",
  test_step = function(data) {
    # Unpack the data
    c(x, y) %<-% data
    # Compute predictions
    y_pred <- self(x, training = FALSE)
    # Updates the metrics tracking the loss
    self$compiled_loss(y, y_pred, regularization_losses = self$losses)
    # Update the metrics.
    self$compiled_metrics$update_state(y, y_pred)
    # Return a named list mapping metric names to current value.
    # Note that it will include the loss (tracked in self$metrics).
    results <- list()
    for (m in self$metrics)
      results[[m$name]] <- m$result()
    results
  }
)

# Construct an instance of CustomModel
inputs <- layer_input(shape(32))
outputs <- inputs %>% layer_dense(1)
model <- CustomModel(inputs, outputs)
model %>% compile(loss = "mse", metrics = "mae")

# Evaluate with our custom test_step
x <- k_random_uniform(c(1000, 32))
y <- k_random_uniform(c(1000, 1))
model %>% evaluate(x, y)
32/32 - 0s - loss: 1.2278 - mae: 1.0050 - 133ms/epoch - 4ms/step
    loss      mae 
1.227754 1.005029 

Wrapping up: an end-to-end GAN example

Let’s walk through an end-to-end example that leverages everything you just learned.

Let’s consider:

  • A generator network meant to generate 28x28x1 images.
  • A discriminator network meant to classify 28x28x1 images into two classes (“fake” and “real”).
  • One optimizer for each.
  • A loss function to train the discriminator.

# Create the discriminator
discriminator <-
  keras_model_sequential(name = "discriminator",
                         input_shape = c(28, 28, 1)) %>%
  layer_conv_2d(64, c(3, 3), strides = c(2, 2), padding = "same") %>%
  layer_activation_leaky_relu(alpha = 0.2) %>%
  layer_conv_2d(128, c(3, 3), strides = c(2, 2), padding = "same") %>%
  layer_activation_leaky_relu(alpha = 0.2) %>%
  layer_global_max_pooling_2d() %>%
  layer_dense(1)

# Create the generator
latent_dim <- 128
generator <- 
  keras_model_sequential(name = "generator",
                         input_shape = c(latent_dim)) %>%
  # We want to generate 7 * 7 * 128 coefficients to reshape into a 7x7x128 map
  layer_dense(7 * 7 * 128) %>%
  layer_activation_leaky_relu(alpha = 0.2) %>%
  layer_reshape(c(7, 7, 128)) %>%
  layer_conv_2d_transpose(128, c(4, 4), strides = c(2, 2), padding = "same") %>%
  layer_activation_leaky_relu(alpha = 0.2) %>%
  layer_conv_2d_transpose(128, c(4, 4), strides = c(2, 2), padding = "same") %>%
  layer_activation_leaky_relu(alpha = 0.2) %>%
  layer_conv_2d(1, c(7, 7), padding = "same", activation = "sigmoid")

Here’s a feature-complete GAN class that overrides compile() to use its own signature and implements the entire GAN algorithm in train_step:

GAN <- new_model_class(
  classname = "GAN",
  initialize = function(discriminator, generator, latent_dim) {
    super$initialize()
    self$discriminator <- discriminator
    self$generator <- generator
    self$latent_dim <- as.integer(latent_dim)
  },
  
  compile = function(d_optimizer, g_optimizer, loss_fn) {
    super$compile()
    self$d_optimizer <- d_optimizer
    self$g_optimizer <- g_optimizer
    self$loss_fn <- loss_fn
  },
  
  
  train_step = function(real_images) {
    # Sample random points in the latent space
    batch_size <- tf$shape(real_images)[1]
    random_latent_vectors <-
      tf$random$normal(shape = c(batch_size, self$latent_dim))
    
    # Decode them to fake images
    generated_images <- self$generator(random_latent_vectors)
    
    # Combine them with real images
    combined_images <-
      tf$concat(list(generated_images, real_images),
                axis = 0L)
    
    # Assemble labels discriminating real from fake images
    labels <-
      tf$concat(list(tf$ones(c(batch_size, 1L)),
                     tf$zeros(c(batch_size, 1L))),
                axis = 0L)
    
    # Add random noise to the labels - important trick!
    labels <- labels + tf$random$uniform(tf$shape(labels), maxval = 0.05)
    
    # Train the discriminator
    with(tf$GradientTape() %as% tape, {
      predictions <- self$discriminator(combined_images)
      d_loss <- self$loss_fn(labels, predictions)
    })
    grads <- tape$gradient(d_loss, self$discriminator$trainable_weights)
    self$d_optimizer$apply_gradients(
      zip_lists(grads, self$discriminator$trainable_weights))
    
    # Sample random points in the latent space
    random_latent_vectors <-
      tf$random$normal(shape = c(batch_size, self$latent_dim))
    
    # Assemble labels that say "all real images"
    misleading_labels <- tf$zeros(c(batch_size, 1L))
    
    # Train the generator (note that we should *not* update the weights
    # of the discriminator)!
    with(tf$GradientTape() %as% tape, {
      predictions <- self$discriminator(self$generator(random_latent_vectors))
      g_loss <- self$loss_fn(misleading_labels, predictions)
    })
    grads <- tape$gradient(g_loss, self$generator$trainable_weights)
    self$g_optimizer$apply_gradients(
      zip_lists(grads, self$generator$trainable_weights))
    
    list(d_loss = d_loss, g_loss = g_loss)
  }
)

Let’s test-drive it:

library(tfdatasets)
# Prepare the dataset. We use both the training & test MNIST digits.

batch_size <- 64
all_digits <- dataset_mnist() %>%
  { k_concatenate(list(.$train$x, .$test$x), axis = 1) } %>%
  k_cast("float32") %>%
  { . / 255 } %>%
  k_reshape(c(-1, 28, 28, 1))


dataset <- tensor_slices_dataset(all_digits) %>%
  dataset_shuffle(buffer_size = 1024) %>%
  dataset_batch(batch_size)

gan <-
  GAN(discriminator = discriminator,
      generator = generator,
      latent_dim = latent_dim)
gan %>% compile(
  d_optimizer = optimizer_adam(learning_rate = 0.0003),
  g_optimizer = optimizer_adam(learning_rate = 0.0003),
  loss_fn = loss_binary_crossentropy(from_logits = TRUE)
)

# To limit the execution time, we only train on 100 batches. You can train on
# the entire dataset. You will need about 20 epochs to get nice results.
gan %>% fit(dataset %>% dataset_take(100), epochs = 1)
100/100 - 4s - d_loss: 0.2360 - g_loss: 1.2044 - 4s/epoch - 44ms/step

Happy training!

Environment Details

tensorflow::tf_config()
TensorFlow v2.13.0 (~/.virtualenvs/r-tensorflow-website/lib/python3.10/site-packages/tensorflow)
Python v3.10 (~/.virtualenvs/r-tensorflow-website/bin/python)
sessionInfo()
R version 4.3.1 (2023-06-16)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 22.04.3 LTS

Matrix products: default
BLAS:   /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3 
LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.20.so;  LAPACK version 3.10.0

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
 [3] LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C            
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       

time zone: America/New_York
tzcode source: system (glibc)

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] tfdatasets_2.9.0.9000  keras_2.13.0.9000      tensorflow_2.13.0.9000

loaded via a namespace (and not attached):
 [1] vctrs_0.6.3            cli_3.6.1              knitr_1.43            
 [4] zeallot_0.1.0          rlang_1.1.1            xfun_0.40             
 [7] png_0.1-8              generics_0.1.3         jsonlite_1.8.7        
[10] glue_1.6.2             htmltools_0.5.6        fansi_1.0.4           
[13] rmarkdown_2.24         grid_4.3.1             tfruns_1.5.1          
[16] evaluate_0.21          tibble_3.2.1           base64enc_0.1-3       
[19] fastmap_1.1.1          yaml_2.3.7             lifecycle_1.0.3       
[22] whisker_0.4.1          compiler_4.3.1         htmlwidgets_1.6.2     
[25] Rcpp_1.0.11            pkgconfig_2.0.3        rstudioapi_0.15.0     
[28] lattice_0.21-8         digest_0.6.33          R6_2.5.1              
[31] tidyselect_1.2.0       reticulate_1.31.0.9000 utf8_1.2.3            
[34] pillar_1.9.0           magrittr_2.0.3         Matrix_1.5-4.1        
[37] tools_4.3.1           
system2(reticulate::py_exe(), c("-m pip freeze"), stdout = TRUE) |> writeLines()
absl-py==1.4.0
array-record==0.4.1
asttokens==2.2.1
astunparse==1.6.3
backcall==0.2.0
bleach==6.0.0
cachetools==5.3.1
certifi==2023.7.22
charset-normalizer==3.2.0
click==8.1.7
decorator==5.1.1
dm-tree==0.1.8
etils==1.4.1
executing==1.2.0
flatbuffers==23.5.26
gast==0.4.0
google-auth==2.22.0
google-auth-oauthlib==1.0.0
google-pasta==0.2.0
googleapis-common-protos==1.60.0
grpcio==1.57.0
h5py==3.9.0
idna==3.4
importlib-resources==6.0.1
ipython==8.14.0
jedi==0.19.0
kaggle==1.5.16
keras==2.13.1
keras-tuner==1.3.5
kt-legacy==1.0.5
libclang==16.0.6
Markdown==3.4.4
MarkupSafe==2.1.3
matplotlib-inline==0.1.6
numpy==1.24.3
nvidia-cublas-cu11==11.11.3.6
nvidia-cudnn-cu11==8.6.0.163
oauthlib==3.2.2
opt-einsum==3.3.0
packaging==23.1
pandas==2.0.3
parso==0.8.3
pexpect==4.8.0
pickleshare==0.7.5
Pillow==10.0.0
promise==2.3
prompt-toolkit==3.0.39
protobuf==3.20.3
psutil==5.9.5
ptyprocess==0.7.0
pure-eval==0.2.2
pyasn1==0.5.0
pyasn1-modules==0.3.0
pydot==1.4.2
Pygments==2.16.1
pyparsing==3.1.1
python-dateutil==2.8.2
python-slugify==8.0.1
pytz==2023.3
requests==2.31.0
requests-oauthlib==1.3.1
rsa==4.9
scipy==1.11.2
six==1.16.0
stack-data==0.6.2
tensorboard==2.13.0
tensorboard-data-server==0.7.1
tensorflow==2.13.0
tensorflow-datasets==4.9.2
tensorflow-estimator==2.13.0
tensorflow-hub==0.14.0
tensorflow-io-gcs-filesystem==0.33.0
tensorflow-metadata==1.14.0
termcolor==2.3.0
text-unidecode==1.3
toml==0.10.2
tqdm==4.66.1
traitlets==5.9.0
typing_extensions==4.5.0
tzdata==2023.3
urllib3==1.26.16
wcwidth==0.2.6
webencodings==0.5.1
Werkzeug==2.3.7
wrapt==1.15.0
zipp==3.16.2
TF Devices:
-  PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU') 
-  PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU') 
CPU cores: 12 
Date rendered: 2023-08-28 
Page render time: 12 seconds