nmt_attention

    This is the companion code to the post “Attention-based Neural Machine Translation with Keras” on the TensorFlow for R blog.

    https://blogs.rstudio.com/tensorflow/posts/2018-07-30-attention-layer/

    library(tensorflow)
    library(keras)
    library(tfdatasets)
    
    library(purrr)
    library(stringr)
    library(reshape2)
    library(viridis)
    library(ggplot2)
    library(tibble)
    
    
    # Preprocessing -----------------------------------------------------------
    
    # Assumes you've downloaded and unzipped one of the bilingual datasets offered at
    # http://www.manythings.org/anki/ and placed it in a directory named "data".
    # This example translates English to Dutch.
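
    # A minimal sketch of one way to fetch the file; the zip name "nld-eng.zip"
    # follows the naming scheme used on manythings.org, but is an assumption
    # (adjust it if the site layout has changed).
    if (!file.exists(file.path("data", "nld.txt"))) {
      dir.create("data", showWarnings = FALSE)
      download.file("http://www.manythings.org/anki/nld-eng.zip",
                    destfile = file.path("data", "nld-eng.zip"), mode = "wb")
      unzip(file.path("data", "nld-eng.zip"), exdir = "data")
    }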
    
    filepath <- file.path("data", "nld.txt")
    
    lines <- readLines(filepath, n = 10000)
    sentences <- str_split(lines, "\t")
    
    space_before_punct <- function(sentence) {
      str_replace_all(sentence, "([?.!])", " \\1")
    }
    
    replace_special_chars <- function(sentence) {
      str_replace_all(sentence, "[^a-zA-Z?.!,¿]+", " ")
    }
    
    add_tokens <- function(sentence) {
      paste0("<start> ", sentence, " <stop>")
    }
    add_tokens <- Vectorize(add_tokens, USE.NAMES = FALSE)
    
    preprocess_sentence <- compose(add_tokens,
                                   str_squish,
                                   replace_special_chars,
                                   space_before_punct)
    
    word_pairs <- map(sentences, preprocess_sentence)
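
    # Illustration of what the preprocessing does (an assumed example pair, not
    # necessarily present in the file):
    # preprocess_sentence(c("Go on.", "Ga door."))
    # #> [1] "<start> Go on . <stop>" "<start> Ga door . <stop>"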
    
    create_index <- function(sentences) {
      unique_words <- sentences %>% unlist() %>% paste(collapse = " ") %>%
        str_split(pattern = " ") %>% .[[1]] %>% unique() %>% sort()
      index <- data.frame(
        word = unique_words,
        index = 1:length(unique_words),
        stringsAsFactors = FALSE
      ) %>%
        add_row(word = "<pad>",
                index = 0,
                .before = 1)
      index
    }
    
    word2index <- function(word, index_df) {
      index_df[index_df$word == word, "index"]
    }
    index2word <- function(index, index_df) {
      index_df[index_df$index == index, "word"]
    }
    
    src_index <- create_index(map(word_pairs, ~ .[[1]]))
    target_index <- create_index(map(word_pairs, ~ .[[2]]))
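
    # Round-trip illustration: word2index() and index2word() invert each other,
    # e.g. index2word(word2index("<start>", src_index), src_index) gives back "<start>".
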
    sentence2digits <- function(sentence, index_df) {
      map((sentence %>% str_split(pattern = " "))[[1]], function(word)
        word2index(word, index_df))
    }
    
    sentlist2diglist <- function(sentence_list, index_df) {
      map(sentence_list, function(sentence)
        sentence2digits(sentence, index_df))
    }
    
    src_diglist <-
      sentlist2diglist(map(word_pairs, ~ .[[1]]), src_index)
    src_maxlen <- map(src_diglist, length) %>% unlist() %>% max()
    src_matrix <-
      pad_sequences(src_diglist, maxlen = src_maxlen,  padding = "post")
    
    target_diglist <-
      sentlist2diglist(map(word_pairs, ~ .[[2]]), target_index)
    target_maxlen <- map(target_diglist, length) %>% unlist() %>% max()
    target_matrix <-
      pad_sequences(target_diglist, maxlen = target_maxlen, padding = "post")
    
    
    
    # Train-test-split --------------------------------------------------------
    
    train_indices <-
      sample(nrow(src_matrix), size = nrow(src_matrix) * 0.8)
    
    validation_indices <- setdiff(1:nrow(src_matrix), train_indices)
    
    x_train <- src_matrix[train_indices,]
    y_train <- target_matrix[train_indices,]
    
    x_valid <- src_matrix[validation_indices,]
    y_valid <- target_matrix[validation_indices,]
    
    buffer_size <- nrow(x_train)
    
    # kept for convenience, so we can get a glimpse of translation performance
    # during training
    train_sentences <- sentences[train_indices]
    validation_sentences <- sentences[validation_indices]
    validation_sample <- sample(validation_sentences, 5)
    
    
    
    # Hyperparameters / variables ---------------------------------------------
    
    batch_size <- 32
    embedding_dim <- 64
    gru_units <- 256
    
    src_vocab_size <- nrow(src_index)
    target_vocab_size <- nrow(target_index)
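
    # Note: each index data frame includes the "<pad>" row (index 0), so nrow()
    # equals the highest word index plus one, which is exactly the input_dim the
    # embedding layers below need.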
    
    
    # Create datasets ---------------------------------------------------------
    
    train_dataset <-
      tensor_slices_dataset(keras_array(list(x_train, y_train)))  %>%
      dataset_shuffle(buffer_size = buffer_size) %>%
      dataset_batch(batch_size, drop_remainder = TRUE)
    
    validation_dataset <-
      tensor_slices_dataset(keras_array(list(x_valid, y_valid))) %>%
      dataset_shuffle(buffer_size = buffer_size) %>%
      dataset_batch(batch_size, drop_remainder = TRUE)
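
    # Each dataset element is a list of two integer tensors of shape
    # (batch_size, src_maxlen) and (batch_size, target_maxlen), respectively;
    # drop_remainder = TRUE keeps every batch at exactly batch_size rows.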
    
    
    # Attention encoder -------------------------------------------------------
    
    
    attention_encoder <-
      function(gru_units,
               embedding_dim,
               src_vocab_size,
               name = NULL) {
        keras_model_custom(name = name, function(self) {
          self$embedding <-
            layer_embedding(input_dim = src_vocab_size,
                            output_dim = embedding_dim)
          self$gru <-
            layer_gru(
              units = gru_units,
              return_sequences = TRUE,
              return_state = TRUE
            )
          
          function(inputs, mask = NULL) {
            x <- inputs[[1]]
            hidden <- inputs[[2]]
            
            x <- self$embedding(x)
            c(output, state) %<-% self$gru(x, initial_state = hidden)
            
            list(output, state)
          }
        })
      }
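
    # Given inputs of shape (batch_size, src_maxlen) and an initial hidden state of
    # shape (batch_size, gru_units), the encoder returns the full sequence of GRU
    # outputs, shape (batch_size, src_maxlen, gru_units), and the final hidden
    # state, shape (batch_size, gru_units).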
    
    
    
    # Attention decoder -------------------------------------------------------
    
    
    attention_decoder <-
      function(object,
               gru_units,
               embedding_dim,
               target_vocab_size,
               name = NULL) {
        keras_model_custom(name = name, function(self) {
          self$gru <-
            layer_gru(
              units = gru_units,
              return_sequences = TRUE,
              return_state = TRUE
            )
          self$embedding <-
            layer_embedding(input_dim = target_vocab_size, output_dim = embedding_dim)
          gru_units <- gru_units
          self$fc <- layer_dense(units = target_vocab_size)
          self$W1 <- layer_dense(units = gru_units)
          self$W2 <- layer_dense(units = gru_units)
          self$V <- layer_dense(units = 1L)
          
          function(inputs, mask = NULL) {
            x <- inputs[[1]]               # current target token: (batch_size, 1)
            hidden <- inputs[[2]]          # decoder hidden state: (batch_size, gru_units)
            encoder_output <- inputs[[3]]  # encoder outputs: (batch_size, src_maxlen, gru_units)
            
            # add a time axis so the hidden state broadcasts against encoder_output
            hidden_with_time_axis <- k_expand_dims(hidden, 2)
            
            # Bahdanau-style additive attention score: (batch_size, src_maxlen, 1)
            score <-
              self$V(k_tanh(
                self$W1(encoder_output) + self$W2(hidden_with_time_axis)
              ))
            
            # normalize the scores over the source time axis
            attention_weights <- k_softmax(score, axis = 2)
            
            # context vector: attention-weighted sum of the encoder outputs,
            # shape (batch_size, gru_units)
            context_vector <- attention_weights * encoder_output
            context_vector <- k_sum(context_vector, axis = 2)
            
            # embed the current target token: (batch_size, 1, embedding_dim)
            x <- self$embedding(x)
            
            # prepend the context vector to the embedded token:
            # (batch_size, 1, gru_units + embedding_dim)
            x <-
              k_concatenate(list(k_expand_dims(context_vector, 2), x), axis = 3)
            
            c(output, state) %<-% self$gru(x)
            
            # drop the (length-1) time dimension before projecting to vocabulary logits
            output <- k_reshape(output, c(-1, gru_units))
            
            x <- self$fc(output)
            
            list(x, state, attention_weights)
            
          }
          
        })
      }
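
    # At each decoding step the decoder consumes the previous target token, its
    # current hidden state, and the full encoder output; it returns vocabulary
    # logits for the next token, the new hidden state, and the attention weights
    # over the source positions.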
    
    
    # The model ---------------------------------------------------------------
    
    encoder <- attention_encoder(
      gru_units = gru_units,
      embedding_dim = embedding_dim,
      src_vocab_size = src_vocab_size
    )
    
    decoder <- attention_decoder(
      gru_units = gru_units,
      embedding_dim = embedding_dim,
      target_vocab_size = target_vocab_size
    )
    
    optimizer <- tf$optimizers$Adam()
    
    # cross-entropy loss that masks out padding positions (index 0), so padded
    # timesteps do not contribute to the loss
    cx_loss <- function(y_true, y_pred) {
      mask <- ifelse(y_true == 0L, 0, 1)
      loss <-
        tf$nn$sparse_softmax_cross_entropy_with_logits(labels = y_true,
                                                       logits = y_pred) * mask
      tf$reduce_mean(loss)
    }
    
    
    
    # Inference / translation functions ---------------------------------------
    # these are defined before the training loop because we call them during
    # training to watch how the network learns
    
    evaluate <-
      function(sentence) {
        attention_matrix <-
          matrix(0, nrow = target_maxlen, ncol = src_maxlen)
        
        sentence <- preprocess_sentence(sentence)
        input <- sentence2digits(sentence, src_index)
        input <-
          pad_sequences(list(input), maxlen = src_maxlen,  padding = "post")
        input <- k_constant(input)
        
        result <- ""
        
        hidden <- k_zeros(c(1, gru_units))
        c(enc_output, enc_hidden) %<-% encoder(list(input, hidden))
        
        dec_hidden <- enc_hidden
        dec_input <-
          k_expand_dims(list(word2index("<start>", target_index)))
        
        for (t in seq_len(target_maxlen - 1)) {
          c(preds, dec_hidden, attention_weights) %<-%
            decoder(list(dec_input, dec_hidden, enc_output))
          attention_weights <- k_reshape(attention_weights, c(-1))
          attention_matrix[t,] <- attention_weights %>% as.double()
          
          pred_idx <-
            tf$compat$v1$multinomial(k_exp(preds), num_samples = 1L)[1, 1] %>% as.double()
          pred_word <- index2word(pred_idx, target_index)
          
          if (pred_word == '<stop>') {
            result <-
              paste0(result, pred_word)
            return (list(result, sentence, attention_matrix))
          } else {
            result <-
              paste0(result, pred_word, " ")
            dec_input <- k_expand_dims(list(pred_idx))
          }
        }
        list(str_trim(result), sentence, attention_matrix)
      }
    
    plot_attention <-
      function(attention_matrix,
               words_sentence,
               words_result) {
        melted <- melt(attention_matrix)
        ggplot(data = melted, aes(
          x = factor(Var2),
          y = factor(Var1),
          fill = value
        )) +
          geom_tile() + scale_fill_viridis() + guides(fill = FALSE) +
          theme(axis.ticks = element_blank()) +
          xlab("") +
          ylab("") +
          scale_x_discrete(labels = words_sentence, position = "top") +
          scale_y_discrete(labels = words_result) +
          theme(aspect.ratio = 1)
      }
    
    
    translate <- function(sentence) {
      c(result, sentence, attention_matrix) %<-% evaluate(sentence)
      print(paste0("Input: ",  sentence))
      print(paste0("Predicted translation: ", result))
      attention_matrix <-
        attention_matrix[1:length(str_split(result, " ")[[1]]),
                         1:length(str_split(sentence, " ")[[1]])]
      plot_attention(attention_matrix,
                     str_split(sentence, " ")[[1]],
                     str_split(result, " ")[[1]])
    }
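
    # Note: translate() looks every word up in the training vocabulary, so a
    # sentence containing a word the model has never seen will not translate
    # cleanly (the word2index() lookup comes back empty). After some training,
    # a quick check might look like this (an assumed example sentence):
    # translate("It is cold.")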
    
    # Training loop -----------------------------------------------------------
    
    
    n_epochs <- 50
    
    encoder_init_hidden <- k_zeros(c(batch_size, gru_units))
    
    for (epoch in seq_len(n_epochs)) {
      total_loss <- 0
      iteration <- 0
      
      iter <- make_iterator_one_shot(train_dataset)
      
      until_out_of_range({
        batch <- iterator_get_next(iter)
        loss <- 0
        x <- batch[[1]]
        y <- batch[[2]]
        iteration <- iteration + 1
    
        with(tf$GradientTape() %as% tape, {
          c(enc_output, enc_hidden) %<-% encoder(list(x, encoder_init_hidden))
          
          dec_hidden <- enc_hidden
          # seed the decoder with the <start> token for every sentence in the batch
          dec_input <-
            k_expand_dims(rep(list(
              word2index("<start>", target_index)
            ), batch_size))
          
          
          for (t in seq_len(target_maxlen - 1)) {
            c(preds, dec_hidden, weights) %<-%
              decoder(list(dec_input, dec_hidden, enc_output))
            loss <- loss + cx_loss(y[, t], preds)
            
            # teacher forcing: feed the ground-truth token as the next decoder input
            dec_input <- k_expand_dims(y[, t])
          }
        })
        total_loss <-
          total_loss + loss / k_cast_to_floatx(dim(y)[2])
        
        paste0(
          "Batch loss (epoch/batch): ",
          epoch,
          "/",
          iteration,
          ": ",
          (loss / k_cast_to_floatx(dim(y)[2])) %>% as.double() %>% round(4),
          "\n"
        ) %>% print()
        
        variables <- c(encoder$variables, decoder$variables)
        gradients <- tape$gradient(loss, variables)
        
        optimizer$apply_gradients(purrr::transpose(list(gradients, variables)))
        
      })
      
      paste0(
        "Total loss (epoch): ",
        epoch,
        ": ",
        (total_loss / k_cast_to_floatx(buffer_size)) %>% as.double() %>% round(4),
        "\n"
      ) %>% print()
      
      walk(train_sentences[1:5], function(pair)
        translate(pair[1]))
      walk(validation_sample, function(pair)
        translate(pair[1]))
    }
    
    # plot the attention matrix for a single training sentence
    example_sentence <- train_sentences[[1]][1]
    translate(example_sentence)