vq_vae

    This is the companion code to the post “Discrete Representation Learning with VQ-VAE and TensorFlow Probability” on the TensorFlow for R blog.

    https://blogs.rstudio.com/tensorflow/posts/2019-01-24-vq-vae/
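
    The script trains a VQ-VAE (a convolutional encoder, a vector quantizer with EMA codebook updates, and a Bernoulli decoder) on Kuzushiji-MNIST, writing reconstructions and samples from the prior to PNG files after every epoch.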

    library(keras)
    use_implementation("tensorflow")
    library(tensorflow)
    tfe_enable_eager_execution(device_policy = "silent")
    
    use_session_with_seed(7778,
                          disable_gpu = FALSE,
                          disable_parallel_cpu = FALSE)
    
    tfp <- import("tensorflow_probability")
    tfd <- tfp$distributions
    
    library(tfdatasets)
    library(dplyr)
    library(glue)
    library(curry)
    
    # EMA helpers, used below to update the codebook
    moving_averages <- tf$python$training$moving_averages
    
    
    # Utilities --------------------------------------------------------
    
    visualize_images <-
      function(dataset,
               epoch,
               reconstructed_images,
               random_images) {
        write_png(dataset, epoch, "reconstruction", reconstructed_images)
        write_png(dataset, epoch, "random", random_images)
        
      }
    
    write_png <- function(dataset, epoch, desc, images) {
      # images arrive as an eager tensor; convert to an R array for plotting
      images <- as.array(images)
      png(paste0(dataset, "_epoch_", epoch, "_", desc, ".png"))
      par(mfcol = c(8, 8))
      par(mar = c(0.5, 0.5, 0.5, 0.5),
          xaxs = 'i',
          yaxs = 'i')
      for (i in 1:64) {
        img <- images[i, , , 1]
        img <- t(apply(img, 2, rev))
        image(
          1:28,
          1:28,
          img * 255, # images are in [0, 1]; image() rescales to the color range anyway
          col = gray((0:255) / 255),
          xaxt = 'n',
          yaxt = 'n'
        )
      }
      dev.off()
      
    }
    
    
    # Setup and preprocessing -------------------------------------------------
    
    np <- import("numpy")
    
    # download from: https://github.com/rois-codh/kmnist
    kuzushiji <- np$load("kmnist-train-imgs.npz")
    kuzushiji <- kuzushiji$get("arr_0")
    
    # add a channel dimension and scale to [0, 1]
    train_images <- kuzushiji %>%
      k_expand_dims() %>%
      k_cast(dtype = "float32")
    train_images <- train_images / 255
    
    buffer_size <- 60000
    batch_size <- 64
    num_examples_to_generate <- batch_size
    
    # the final partial batch is dropped (drop_remainder = TRUE below)
    batches_per_epoch <- trunc(buffer_size / batch_size)
    
    train_dataset <- tensor_slices_dataset(train_images) %>%
      dataset_shuffle(buffer_size) %>%
      dataset_batch(batch_size, drop_remainder = TRUE)
    
    # test
    iter <- make_iterator_one_shot(train_dataset)
    batch <-  iterator_get_next(iter)
    batch %>% dim()
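    # expected: 64 28 28 1  (batch_size, height, width, channels)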
    
    # Params ------------------------------------------------------------------
    
    learning_rate <- 0.001
    latent_size <- 1            # number of discrete latent codes per image
    num_codes <- 64L            # number of entries in the codebook
    code_size <- 16L            # dimensionality of a codebook entry
    base_depth <- 32            # base number of convolution filters
    activation <- "elu"
    beta <- 0.25                # weight of the commitment loss
    decay <- 0.99               # decay for the EMA codebook updates
    input_shape <- c(28, 28, 1)
    
    # Models -------------------------------------------------------------------
    
    # conv / deconv constructors sharing common defaults (via curry::set_defaults)
    default_conv <-
      set_defaults(layer_conv_2d, list(padding = "same", activation = activation))
    default_deconv <-
      set_defaults(layer_conv_2d_transpose,
                   list(padding = "same", activation = activation))
    
    # Encoder ------------------------------------------------------------------
    
    encoder_model <- function(name = NULL,
                              code_size) {
      
      keras_model_custom(name = name, function(self) {
        self$conv1 <- default_conv(filters = base_depth, kernel_size = 5)
        self$conv2 <-
          default_conv(filters = base_depth,
                       kernel_size = 5,
                       strides = 2)
        self$conv3 <-
          default_conv(filters = 2 * base_depth, kernel_size = 5)
        self$conv4 <-
          default_conv(
            filters = 2 * base_depth,
            kernel_size = 5,
            strides = 2
          )
        self$conv5 <-
          default_conv(
            filters = 4 * latent_size,
            kernel_size = 7,
            padding = "valid"
          )
        self$flatten <- layer_flatten()
        self$dense <- layer_dense(units = latent_size * code_size)
        self$reshape <-
          layer_reshape(target_shape = c(latent_size, code_size))
        
        function (x, mask = NULL) {
          x %>%
            # output shape:  bs 28 28 32
            self$conv1() %>%
            # output shape:  bs 14 14 32
            self$conv2() %>%
            # output shape:  bs 14 14 64
            self$conv3() %>%
            # output shape:  bs 7 7 64
            self$conv4() %>%
            # output shape:  bs 1 1 4
            self$conv5() %>%
            # output shape:  bs 4
            self$flatten() %>%
            # output shape:  bs 16
            self$dense() %>%
            # output shape:  bs 1 16
            self$reshape()
        }
        
      })
    }
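    
    # quick test, in the spirit of the batch check above: the encoder compresses
    # each image into latent_size codes of dimension code_size
    # (a throwaway instance; the one used for training is created further down)
    test_encoder <- encoder_model(code_size = code_size)
    test_encoder(batch) %>% dim()  # expected: 64 1 16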
    
    
    # Decoder ------------------------------------------------------------------
    
    decoder_model <- function(name = NULL,
                              input_size,
                              output_shape) {
      
      keras_model_custom(name = name, function(self) {
        self$reshape1 <- layer_reshape(target_shape = c(1, 1, input_size))
        self$deconv1 <-
          default_deconv(
            filters = 2 * base_depth,
            kernel_size = 7,
            padding = "valid"
          )
        self$deconv2 <-
          default_deconv(filters = 2 * base_depth, kernel_size = 5)
        self$deconv3 <-
          default_deconv(
            filters = 2 * base_depth,
            kernel_size = 5,
            strides = 2
          )
        self$deconv4 <-
          default_deconv(filters = base_depth, kernel_size = 5)
        self$deconv5 <-
          default_deconv(filters = base_depth,
                         kernel_size = 5,
                         strides = 2)
        self$deconv6 <-
          default_deconv(filters = base_depth, kernel_size = 5)
        self$conv1 <-
          default_conv(filters = output_shape[3],
                       kernel_size = 5,
                       activation = "linear")
        
        function (x, mask = NULL) {
          x <- x %>%
            # output shape:  bs 1 1 16
            self$reshape1() %>%
            # output shape:  bs 7 7 64
            self$deconv1() %>%
            # output shape:  bs 7 7 64
            self$deconv2() %>%
            # output shape:  bs 14 14 64
            self$deconv3() %>%
            # output shape:  bs 14 14 32
            self$deconv4() %>%
            # output shape:  bs 28 28 32
            self$deconv5() %>%
            # output shape:  bs 28 28 32
            self$deconv6() %>%
            # output shape:  bs 28 28 1
            self$conv1()
          # a pixelwise Bernoulli over the image; Independent folds the image
          # dimensions into a single event
          tfd$Independent(tfd$Bernoulli(logits = x),
                          reinterpreted_batch_ndims = length(output_shape))
        }
      })
    }
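    
    # quick test (reusing test_encoder from above): the decoder maps codes to a
    # pixelwise Bernoulli distribution over full-size images
    test_decoder <- decoder_model(input_size = latent_size * code_size,
                                  output_shape = input_shape)
    test_decoder(test_encoder(batch))$mean() %>% dim()  # expected: 64 28 28 1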
    
    # Vector quantizer -------------------------------------------------------------------
    
    vector_quantizer_model <- 
      function(name = NULL, num_codes, code_size) {
        
        keras_model_custom(name = name, function(self) {
          self$num_codes <- num_codes
          self$code_size <- code_size
          self$codebook <- tf$get_variable("codebook",
                                           shape = c(num_codes, code_size),
                                           dtype = tf$float32)
          self$ema_count <- tf$get_variable(
            name = "ema_count",
            shape = c(num_codes),
            initializer = tf$constant_initializer(0),
            trainable = FALSE
          )
          self$ema_means <- tf$get_variable(
            name = "ema_means",
            initializer = self$codebook$initialized_value(),
            trainable = FALSE
          )
          
          function (x, mask = NULL) {
    
            # bs * 1 * num_codes
            distances <- tf$norm(tf$expand_dims(x, axis = 2L) -
                                   tf$reshape(self$codebook,
                                              c(
                                                1L, 1L, self$num_codes, self$code_size
                                              )),
                                 axis = 3L)
            
            # bs * 1
            assignments <- tf$argmin(distances, axis = 2L)
            
            # bs * 1 * num_codes
            one_hot_assignments <-
              tf$one_hot(assignments, depth = self$num_codes)
            
            # bs * 1 * code_size
            nearest_codebook_entries <- tf$reduce_sum(
              # bs * 1 * num_codes * 1
              tf$expand_dims(one_hot_assignments, -1L) *
                # 1 * 1 * num_codes * code_size
                tf$reshape(self$codebook, c(
                  1L, 1L, self$num_codes, self$code_size
                )),
              axis = 2L
            )
            
            list(nearest_codebook_entries, one_hot_assignments)
          }
        })
      }
    
    
    # Update codebook ------------------------------------------------------
    
    update_ema <- function(vector_quantizer,
                           one_hot_assignments,
                           codes,
                           decay) {
      # EMA of per-code assignment counts; shape: (num_codes)
      updated_ema_count <- moving_averages$assign_moving_average(
        vector_quantizer$ema_count,
        tf$reduce_sum(one_hot_assignments, axis = c(0L, 1L)),
        decay,
        zero_debias = FALSE
      )
      
      # EMA of per-code sums of assigned encoder outputs; shape: (num_codes, code_size)
      updated_ema_means <- moving_averages$assign_moving_average(
        vector_quantizer$ema_means,
        # selects all assigned values (masking out the others) and sums them up over the batch
        # (will be divided by count later)
        tf$reduce_sum(
          tf$expand_dims(codes, 2L) *
            tf$expand_dims(one_hot_assignments, 3L),
          axis = c(0L, 1L)
        ),
        decay,
        zero_debias = FALSE
      )
      
      # Add small value to avoid dividing by zero
      updated_ema_count <- updated_ema_count + 1e-5
      updated_ema_means <-
        updated_ema_means / tf$expand_dims(updated_ema_count, axis = -1L)
      
      tf$assign(vector_quantizer$codebook, updated_ema_means)
    }
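    
    # assign_moving_average applies the standard EMA rule to a variable; schematically:
    #   ema_new <- decay * ema_old + (1 - decay) * value
    # Dividing the EMA of summed codes by the EMA of counts thus keeps each
    # codebook entry at a smoothed mean of the encoder outputs assigned to it.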
    
    
    # Training setup -----------------------------------------------------------
    
    encoder <- encoder_model(code_size = code_size)
    decoder <- decoder_model(input_size = latent_size * code_size,
                             output_shape = input_shape)
    
    vector_quantizer <-
      vector_quantizer_model(num_codes = num_codes, code_size = code_size)
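    
    # quick test: quantization preserves the code shape and additionally yields
    # one-hot assignments over the codebook
    c(nearest, assignments) %<-% vector_quantizer(encoder(batch))
    nearest %>% dim()      # expected: 64 1 16  (batch, latent_size, code_size)
    assignments %>% dim()  # expected: 64 1 64  (batch, latent_size, num_codes)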
    
    optimizer <- tf$train$AdamOptimizer(learning_rate = learning_rate)
    
    checkpoint_dir <- "./vq_vae_checkpoints"
    
    checkpoint_prefix <- file.path(checkpoint_dir, "ckpt")
    checkpoint <-
      tf$train$Checkpoint(
        optimizer = optimizer,
        encoder = encoder,
        decoder = decoder,
        vector_quantizer_model = vector_quantizer
      )
    
    checkpoint$save(file_prefix = checkpoint_prefix)
    
    # Training loop -----------------------------------------------------------
    
    num_epochs <- 20
    
    for (epoch in seq_len(num_epochs)) {
      
      iter <- make_iterator_one_shot(train_dataset)
      
      total_loss <- 0
      reconstruction_loss_total <- 0
      commitment_loss_total <- 0
      prior_loss_total <- 0
      
      until_out_of_range({
        
        x <-  iterator_get_next(iter)
        
        with(tf$GradientTape(persistent = TRUE) %as% tape, {
          
          codes <- encoder(x)
          c(nearest_codebook_entries, one_hot_assignments) %<-% vector_quantizer(codes)
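          # straight-through estimator: the forward pass uses the quantized codes,
          # while gradients flow back to the encoder as if unquantized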
          codes_straight_through <- codes + tf$stop_gradient(nearest_codebook_entries - codes)
          decoder_distribution <- decoder(codes_straight_through)
          
          reconstruction_loss <-
            -tf$reduce_mean(decoder_distribution$log_prob(x))
          
          commitment_loss <- tf$reduce_mean(tf$square(codes - tf$stop_gradient(nearest_codebook_entries)))
          
          # uniform prior over the codebook; with total_count = 1, the
          # Multinomial scores one-hot assignments like a categorical
          prior_dist <- tfd$Multinomial(total_count = 1,
                                        logits = tf$zeros(c(latent_size, num_codes)))
          prior_loss <- -tf$reduce_mean(tf$reduce_sum(prior_dist$log_prob(one_hot_assignments), 1L))
          
          # total loss: reconstruction + weighted commitment + prior; the
          # codebook gets no gradients (it is refreshed via EMA below)
          loss <-
            reconstruction_loss + beta * commitment_loss + prior_loss
          
        })
        
        encoder_gradients <- tape$gradient(loss, encoder$variables)
        decoder_gradients <- tape$gradient(loss, decoder$variables)
        
        optimizer$apply_gradients(purrr::transpose(list(
          encoder_gradients, encoder$variables
        )),
        global_step = tf$train$get_or_create_global_step())
        optimizer$apply_gradients(purrr::transpose(list(
          decoder_gradients, decoder$variables
        )),
        global_step = tf$train$get_or_create_global_step())
        
        # the optimizer never touches the codebook; update it from the EMA
        # statistics instead
        update_ema(vector_quantizer,
                   one_hot_assignments,
                   codes,
                   decay)
        
        total_loss <- total_loss + loss
        reconstruction_loss_total <-
          reconstruction_loss_total + reconstruction_loss
        commitment_loss_total <- commitment_loss_total + commitment_loss
        prior_loss_total <- prior_loss_total + prior_loss
        
      })
      
      checkpoint$save(file_prefix = checkpoint_prefix)
      
      cat(
        glue(
          "Losses (epoch {epoch}):",
          "  {(as.numeric(total_loss)/batches_per_epoch) %>% round(4)} loss",
          "  {(as.numeric(reconstruction_loss_total)/batches_per_epoch) %>% round(4)} reconstruction_loss",
          "  {(as.numeric(commitment_loss_total)/batches_per_epoch) %>% round(4)} commitment_loss",
          "  {(as.numeric(prior_loss_total)/batches_per_epoch) %>% round(4)} prior_loss"
        ),
        "\n"
      )
      
      # visualize example images every epoch (adjust the condition to change frequency)
      if (TRUE) {
        reconstructed_images <- decoder_distribution$mean()
        # sample one-hot codes from the prior and look up the corresponding
        # codebook entries; result shape: (64, 1, 16)
        prior_samples <- tf$reduce_sum(
          # selects one of the codes (masking out 63 of 64 codes)
          # (bs, 1, 64, 1)
          tf$expand_dims(prior_dist$sample(num_examples_to_generate), -1L) *
            # (1, 1, 64, 16)
            tf$reshape(vector_quantizer$codebook,
                       c(1L, 1L, num_codes, code_size)),
          axis = 2L
        )
        decoded_distribution_given_random_prior <-
          decoder(prior_samples)
        random_images <- decoded_distribution_given_random_prior$mean()
        visualize_images("k", epoch, reconstructed_images, random_images)
      }
    }
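    
    # Each epoch writes two 8 x 8 grids to the working directory:
    # <dataset>_epoch_<n>_reconstruction.png and <dataset>_epoch_<n>_random.png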