Trains a two-branch recurrent network on the bAbI dataset

nlp
Trains a two-branch recurrent network on the bAbI dataset for reading comprehension.

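Encodes a story and a question in two separate branches. The question is summarised by an LSTM, repeated across every story position, and added element-wise to the embedded story. A second LSTM then reads this merged sequence, and a softmax layer predicts a single-word answer. The same architecture can be trained on any of the bAbI tasks listed below.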

The results are comparable to those for an LSTM model provided in Weston et al.: “Towards AI-Complete Question Answering: A Set of Prerequisite Toy Tasks” http://arxiv.org/abs/1502.05698

| Task Number | FB LSTM Baseline | Keras QA |
|---|---|---|
| QA1 - Single Supporting Fact | 50 | 100.0 |
| QA2 - Two Supporting Facts | 20 | 50.0 |
| QA3 - Three Supporting Facts | 20 | 20.5 |
| QA4 - Two Arg. Relations | 61 | 62.9 |
| QA5 - Three Arg. Relations | 70 | 61.9 |
| QA6 - Yes/No Questions | 48 | 50.7 |
| QA7 - Counting | 49 | 78.9 |
| QA8 - Lists/Sets | 45 | 77.2 |
| QA9 - Simple Negation | 64 | 64.0 |
| QA10 - Indefinite Knowledge | 44 | 47.7 |
| QA11 - Basic Coreference | 72 | 74.9 |
| QA12 - Conjunction | 74 | 76.4 |
| QA13 - Compound Coreference | 94 | 94.4 |
| QA14 - Time Reasoning | 27 | 34.8 |
| QA15 - Basic Deduction | 21 | 32.4 |
| QA16 - Basic Induction | 23 | 50.6 |
| QA17 - Positional Reasoning | 51 | 49.1 |
| QA18 - Size Reasoning | 52 | 90.8 |
| QA19 - Path Finding | 8 | 9.0 |
| QA20 - Agent’s Motivations | 91 | 90.7 |

For the resources related to the bAbI project, refer to: https://research.facebook.com/researchers/1543934539189348

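Notes:

- The story and question branches use separate embedding layers, so word vectors are not shared between them.
- The question is encoded on its own and merged with the story by element-wise addition, which is a simple example of combining two network branches.
- By default the script trains on the 10,000-sample version of QA1 (`en-10k`). Edit the task selection in the Data Preparation section to try other tasks or the 1,000-sample `en` set.
- `parse_stories()` accepts `only_supporting = TRUE`, which shortens each story to just the statements listed as supporting facts for the question.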

library(keras)
library(readr)
library(stringr)
library(purrr)
library(tibble)
library(dplyr)

Attaching package: 'dplyr'
The following objects are masked from 'package:stats':

    filter, lag
The following objects are masked from 'package:base':

    intersect, setdiff, setequal, union
# Function definition -----------------------------------------------------

# Split text into word and punctuation tokens.
#
# A space is inserted in front of every run of punctuation characters, so
# "Where is Mary?" becomes "Where is Mary ?" and the "?" ends up as its own
# token. The text is then split on single spaces, and the empty strings
# produced by repeated or leading spaces are dropped.
#
# x: character vector; all elements are tokenised and flattened together.
# Returns a character vector of tokens.
tokenize_words <- function(x) {
  spaced <- str_replace_all(x, "([[:punct:]]+)", " \\1")
  tokens <- unlist(str_split(spaced, " "))
  tokens[tokens != ""]
}

# Parse raw bAbI task lines into one row per question.
#
# Every line has the form "<nid> <text>". Statement lines carry plain text;
# question lines carry "<question>\t<answer>\t<supporting nids>". A new
# story is assumed to start wherever the line number is 1 (bAbI convention).
#
# lines: character vector of raw lines (e.g. from read_lines()). Blank lines
#   are ignored.
# only_supporting: if TRUE, keep only the statements listed as supporting
#   facts for each question; otherwise keep every statement that precedes
#   the question within its story.
#
# Returns a tibble with columns
#   id       - integer row index (1..n)
#   question - list of character token vectors
#   answer   - character answer
#   story    - list of character token vectors (the story context)
# Questions that end up with no context statements are dropped.
parse_stories <- function(lines, only_supporting = FALSE) {
  # Blank lines would break the "<nid> <text>" split below.
  lines <- lines[nzchar(lines)]
  if (length(lines) == 0) {
    return(tibble(
      id = integer(),
      question = list(),
      answer = character(),
      story = list()
    ))
  }

  parsed <- lines %>%
    str_split(" ", n = 2) %>%
    map_dfr(~tibble(nid = as.integer(.x[[1]]), line = .x[[2]])) %>%
    mutate(
      split = map(line, ~str_split(.x, "\t")[[1]]),
      q = map_chr(split, ~.x[1]),
      a = map_chr(split, ~.x[2]),
      supporting = map(
        split,
        ~.x[3] %>% str_split(" ") %>% unlist() %>% as.integer()
      ),
      # Numbering restarts at 1 for every story.
      story_id = cumsum(nid == 1L)
    ) %>%
    select(-split, -line)

  # One row per story holding all of its statements. Joining questions to
  # this table is many-to-one, which avoids the many-to-many join (and the
  # dplyr >= 1.1 "unexpected many-to-many relationship" warning) that a
  # one-row-per-statement table would cause.
  statements <- parsed %>%
    filter(is.na(a)) %>%
    group_by(story_id) %>%
    summarise(stmt_nid = list(nid), stmt_text = list(q), .groups = "drop")

  parsed %>%
    filter(!is.na(a)) %>%
    inner_join(statements, by = "story_id") %>%
    mutate(
      context = pmap(
        list(nid, supporting, stmt_nid, stmt_text),
        function(qid, sup, s_nid, s_text) {
          # Only statements that appear before the question are visible.
          keep <- s_nid < qid
          if (only_supporting) {
            keep <- keep & s_nid %in% sup
          }
          s_text[keep]
        }
      )
    ) %>%
    filter(lengths(context) > 0) %>%
    arrange(story_id, nid) %>%
    mutate(
      question = map(q, ~tokenize_words(.x)),
      answer = a,
      story = map(context, ~tokenize_words(paste(.x, collapse = " "))),
      id = row_number()
    ) %>%
    select(id, question, answer, story)
}

# Convert parsed questions into padded integer matrices for the model.
#
# Each token is mapped to its 1-based position in `vocab`, which leaves
# index 0 free for the padding inserted by pad_sequences().
#
# data: tibble from parse_stories() (list columns `question` and `story`,
#   character column `answer`).
# vocab: character vector of unique tokens.
# story_maxlen, query_maxlen: padded lengths passed to pad_sequences().
#
# Returns a list with
#   questions - padded question index matrix (query_maxlen columns)
#   stories   - padded story index matrix (story_maxlen columns)
#   answers   - one-hot integer matrix, nrow(data) x (length(vocab) + 1);
#               the first column (named "") corresponds to padding index 0.
# Errors if any question/story token is missing from `vocab`.
vectorize_stories <- function(data, vocab, story_maxlen, query_maxlen) {
  to_indices <- function(tokens) {
    idx <- match(tokens, vocab)
    if (anyNA(idx)) {
      stop(
        "Tokens not found in vocab: ",
        paste(unique(tokens[is.na(idx)]), collapse = ", "),
        call. = FALSE
      )
    }
    idx
  }

  questions <- map(data$question, to_indices)
  stories <- map(data$story, to_indices)

  # "" represents padding. outer() always returns an
  # nrow(data) x (length(vocab) + 1) matrix, even for 0 or 1 rows, where
  # sapply() would simplify to a list or a plain vector.
  answer_levels <- c("", vocab)
  answers <- outer(data$answer, answer_levels, `==`)
  storage.mode(answers) <- "integer"
  dimnames(answers) <- list(NULL, answer_levels)

  list(
    questions = pad_sequences(questions, maxlen = query_maxlen),
    stories   = pad_sequences(stories, maxlen = story_maxlen),
    answers   = answers
  )
}

# Parameters --------------------------------------------------------------

# Longest story (in tokens) to keep; Inf disables the length filter.
max_length <- Inf
# Output size of both embedding layers and units of both LSTM layers.
embed_hidden_size <- 50
batch_size <- 32
epochs <- 40

# Data Preparation --------------------------------------------------------

# Fetch the bAbI tasks archive with keras::get_file(), which stores it in
# the local Keras cache and returns the path to the downloaded file.
path <- get_file(
  origin = "https://s3.amazonaws.com/text-datasets/babi_tasks_1-20_v1-2.tar.gz",
  fname = "babi-tasks-v1-2.tar.gz"
)
Loaded Tensorflow version 2.9.1
# Unpack the archive into a directory named after it (".tar.gz" replaced by
# "/"), then point `path` at that directory. The trailing "/" lets `path`
# be used directly as a prefix in the data file-name template.
path <- local({
  archive <- path
  target_dir <- str_replace(archive, fixed(".tar.gz"), "/")
  untar(archive, exdir = target_dir)
  target_dir
})

# Task selection. Each task comes in a 1,000-sample ("en") and a
# 10,000-sample ("en-10k") variant. Example task names:
#   qa1_single-supporting-fact, qa2_two-supporting-facts
babi_variant <- "en-10k"
babi_task <- "qa1_single-supporting-fact"

# sprintf() template: the first %s is the extraction directory (with a
# trailing "/"), the second is the split name ("train" or "test").
challenge <- file.path(
  "%stasks_1-20_v1-2", babi_variant, paste0(babi_task, "_%s.txt")
)

# Read one split of the selected task, parse it into one row per question,
# and drop examples whose story is longer than `max_len` tokens.
#
# split_name: "train" or "test" (second %s of `template`).
# template: sprintf() template for the data file path.
# data_dir: directory prefix substituted for the first %s of `template`.
# max_len: maximum story length in tokens.
read_babi_split <- function(split_name, template, data_dir, max_len) {
  read_lines(sprintf(template, data_dir, split_name)) %>%
    parse_stories() %>%
    filter(lengths(story) <= max_len)
}

train <- read_babi_split("train", challenge, path, max_length)
test <- read_babi_split("test", challenge, path, max_length)

# Vocabulary: every distinct token from the train and test questions,
# answers and stories. sort() orders strings by the session's collation
# locale, so the token-to-index mapping can differ between locales.
all_data <- bind_rows(train, test)
vocab <- sort(unique(c(
  unlist(all_data$question),
  all_data$answer,
  unlist(all_data$story)
)))

# Index 0 is reserved for the padding added by pad_sequences(), hence + 1.
vocab_size <- length(vocab) + 1
# Inputs are padded to the longest story / question seen in either split.
story_maxlen <- max(lengths(all_data$story))
query_maxlen <- max(lengths(all_data$question))

# Integer-encoded, padded versions of both splits.
train_vec <- vectorize_stories(train, vocab, story_maxlen, query_maxlen)
test_vec <- vectorize_stories(test, vocab, story_maxlen, query_maxlen)

# Defining the model ------------------------------------------------------

# Two-branch model: the question is summarised by an LSTM and added to every
# position of the embedded story; a second LSTM reads the merged sequence
# and a softmax layer predicts the answer token.
#
# NOTE: Keras auto-generates layer names (embedding, embedding_1, lstm,
# lstm_1, ...) from per-type counters, so restructuring these statements can
# change the names shown in the summary below.
#
# NOTE(review): layer_embedding() is called without mask_zero = TRUE, so
# padded positions (index 0) are not masked and are processed like any
# other token.

# Story branch: token ids padded to story_maxlen -> embedding -> dropout.
# The story is not run through a recurrent layer before the merge.
sentence <- layer_input(shape = c(story_maxlen), dtype = "int32")
encoded_sentence <- sentence %>% 
  layer_embedding(input_dim = vocab_size, output_dim = embed_hidden_size) %>%
  layer_dropout(rate = 0.3)

# Question branch: its own embedding (weights are not shared with the story
# branch), an LSTM that reduces the question to a single vector of length
# embed_hidden_size, and repeat_vector() to copy it story_maxlen times so
# its shape matches the story branch for the addition below.
question <- layer_input(shape = c(query_maxlen), dtype = "int32")
encoded_question <- question %>%
  layer_embedding(input_dim = vocab_size, output_dim = embed_hidden_size) %>%
  layer_dropout(rate = 0.3) %>%
  layer_lstm(units = embed_hidden_size) %>%
  layer_repeat_vector(n = story_maxlen)

# Merge by element-wise addition, then summarise the merged sequence with a
# second LSTM.
merged <- list(encoded_sentence, encoded_question) %>%
  layer_add() %>%
  layer_lstm(units = embed_hidden_size) %>%
  layer_dropout(rate = 0.3)

# One softmax unit per vocabulary index (vocab_size = length(vocab) + 1,
# index 0 being the padding slot).
preds <- merged %>%
  layer_dense(units = vocab_size, activation = "softmax")

# Inputs must be supplied in this order: stories first, then questions.
model <- keras_model(inputs = list(sentence, question), outputs = preds)
# categorical_crossentropy expects one-hot targets, which is the format of
# the `answers` matrix built by vectorize_stories().
model %>% compile(
  optimizer = "adam",
  loss = "categorical_crossentropy",
  metrics = "accuracy"
)

# Print the model summary (layer names, output shapes, parameter counts).
print(model)
Model: "model"
____________________________________________________________________________
 Layer (type)            Output Shape    Param #  Connected to              
============================================================================
 input_2 (InputLayer)    [(None, 4)]     0        []                        
 embedding_1 (Embedding)  (None, 4, 50)  1100     ['input_2[0][0]']         
 input_1 (InputLayer)    [(None, 68)]    0        []                        
 dropout_1 (Dropout)     (None, 4, 50)   0        ['embedding_1[0][0]']     
 embedding (Embedding)   (None, 68, 50)  1100     ['input_1[0][0]']         
 lstm (LSTM)             (None, 50)      20200    ['dropout_1[0][0]']       
 dropout (Dropout)       (None, 68, 50)  0        ['embedding[0][0]']       
 repeat_vector (RepeatVe  (None, 68, 50)  0       ['lstm[0][0]']            
 ctor)                                                                      
 add (Add)               (None, 68, 50)  0        ['dropout[0][0]',         
                                                   'repeat_vector[0][0]']   
 lstm_1 (LSTM)           (None, 50)      20200    ['add[0][0]']             
 dropout_2 (Dropout)     (None, 50)      0        ['lstm_1[0][0]']          
 dense (Dense)           (None, 22)      1122     ['dropout_2[0][0]']       
============================================================================
Total params: 43,722
Trainable params: 43,722
Non-trainable params: 0
____________________________________________________________________________
# Training ----------------------------------------------------------------

# Inputs follow the order given to keras_model(inputs = ...): stories first,
# then questions. validation_split holds out the last 5% of the training
# rows for validation.
model %>% fit(
  x = list(train_vec$stories, train_vec$questions),
  y = train_vec$answers,
  epochs = epochs,
  batch_size = batch_size,
  validation_split = 0.05
)

# Score the trained model on the test split.
evaluation <- evaluate(
  model,
  x = list(test_vec$stories, test_vec$questions),
  y = test_vec$answers,
  batch_size = batch_size
)

# Named vector with the test loss and accuracy.
print(evaluation)
      loss   accuracy 
0.03284298 0.99100000