Writing Custom Keras Wrappers

Overview

Use cases for custom wrappers arise less often than for custom models or custom layers. In contrast to standalone layers, custom wrappers modify the behavior of an underlying layer.

Currently Keras provides two specialized wrappers, bidirectional and time_distributed. The R6 class KerasWrapper allows subclasses to implement specialized layer-wrapping logic.

Creating a Custom Wrapper

Here is a simple subclass that adds a weight and a loss to a wrapped layer.

CustomWrapper <- R6::R6Class(
  "CustomWrapper",
  
  inherit = KerasWrapper,
  
  public = list(
    weight_shape = NULL,
    weight_init = NULL,
    custom_weight = NULL,
    
    initialize = function(weight_shape, weight_init) {
      self$weight_shape <- weight_shape
      self$weight_init <- weight_init
    },
    
    build = function(input_shape) {
      
      # call this before doing any customization
      super$build(input_shape)
      
      self$custom_weight <- super$add_weight(
        name = "custom_weight",
        shape = self$weight_shape,
        initializer = self$weight_init,
        trainable = TRUE
      )
      
      regularizer <- k_log(self$custom_weight)
      super$add_loss(regularizer)
      
    }
  )
)

Instantiating a Custom Wrapper

Just like custom layers have instantiator functions, create an instantiator for the CustomWrapper class.

wrapper_custom <-
  function(object,
           layer,
           weight_shape,
           weight_init,
           input_shape = NULL) {
    create_wrapper(
      CustomWrapper,
      object,
      list(
        layer = layer,
        weight_shape = weight_shape,
        weight_init = weight_init,
        input_shape = input_shape
      )
    )
  }

Using the Custom Wrapper

Now you can use the wrapper in a Keras model like one of the existing wrappers.

model <- keras_model_sequential() %>%
  wrapper_custom(
    layer = layer_dense(units = 4),
    weight_shape = shape(1),
    weight_init = initializer_he_normal(),
    input_shape = shape(2)
  ) %>%
  wrapper_custom(
    layer = layer_dense(units = 2),
    weight_shape = shape(1),
    weight_init = initializer_he_normal()
  ) %>%
  layer_dense(units = 1)
  
model %>% compile(optimizer = "adam", loss = "mse")
  
model %>% fit(
  x = matrix(1:10, ncol = 2),
  y = matrix(1:5, ncol = 1),
  batch_size = 1,
  epochs = 1
)