Timeseries classification with a Transformer model

This notebook demonstrates how to do timeseries classification using a Transformer model.
Authors

Theodoros Ntakouris

terrytangyuan - R adaptation

t-kalinowski - R adaptation

Introduction

This is the Transformer architecture from Attention Is All You Need, applied to timeseries instead of natural language.

This example requires TensorFlow 2.4 or higher.
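
If you want to verify this from R, the tensorflow package reports the version of the TensorFlow installation it is bound to (a quick optional check, not part of the original example):

# TRUE when the discovered TensorFlow installation is recent enough
tensorflow::tf_version() >= "2.4"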

Load the dataset

We will use the same dataset and preprocessing as the Timeseries Classification from Scratch example.

library(tensorflow)
library(keras)
set.seed(1234)

url <- "https://raw.githubusercontent.com/hfawaz/cd-diagram/master/FordA"

# Download each split and read it as a tab-separated file. The first column
# holds the class label; the remaining columns hold the 500 timeseries values.
train_df <- "FordA_TRAIN.tsv" %>%
  get_file(., file.path(url, .)) %>%
  readr::read_tsv(col_names = FALSE)
x_train <- as.matrix(train_df[, -1])
y_train <- as.matrix(train_df[, 1])

test_df <- "FordA_TEST.tsv" %>%
  get_file(., file.path(url, .)) %>%
  readr::read_tsv(col_names = FALSE)
x_test <- as.matrix(test_df[, -1])
y_test <- as.matrix(test_df[, 1])

n_classes <- length(unique(y_train))

# Shuffle the training set
shuffle_ind <- sample(nrow(x_train))
x_train <- x_train[shuffle_ind, , drop = FALSE]
y_train <- y_train[shuffle_ind, , drop = FALSE]

# Relabel the classes from {-1, 1} to {0, 1}
y_train[y_train == -1] <- 0
y_test [y_test  == -1] <- 0

# Add a trailing channel dimension: (samples, timesteps) -> (samples, timesteps, 1)
dim(x_train) <- c(dim(x_train), 1)
dim(x_test) <- c(dim(x_test), 1)
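
Before building the model it is worth sanity-checking the prepared arrays (a short optional check; FordA series have length 500, a single channel after reshaping, and two classes):

dim(x_train)    # samples x 500 x 1
dim(x_test)     # samples x 500 x 1
table(y_train)  # labels are now 0 / 1
n_classes       # 2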

Build a model

Our model processes a tensor of shape (batch size, sequence length, features), where sequence length is the number of time steps and features is the number of channels per time step (here 1, since each input timeseries is univariate).

You can replace your classification RNN layers with this one: the inputs are fully compatible!

We include residual connections, layer normalization, and dropout. The resulting layer can be stacked multiple times. The projection layers are implemented through layer_conv_1d().

transformer_encoder <- function(inputs,
                                head_size,
                                num_heads,
                                ff_dim,
                                dropout = 0) {
  # Attention and Normalization
  attention_layer <-
    layer_multi_head_attention(key_dim = head_size,
                               num_heads = num_heads,
                               dropout = dropout)
  
  n_features <- dim(inputs) %>% tail(1)
  
  x <- inputs %>%
    attention_layer(., .) %>%
    layer_dropout(dropout) %>%
    layer_layer_normalization(epsilon = 1e-6)
  
  res <- x + inputs
  
  # Feed Forward Part
  x <- res %>%
    layer_conv_1d(ff_dim, kernel_size = 1, activation = "relu") %>%
    layer_dropout(dropout) %>%
    layer_conv_1d(n_features, kernel_size = 1) %>%
    layer_layer_normalization(epsilon = 1e-6)
  
  # return output + residual
  x + res
}
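
As a minimal sketch of the input/output compatibility mentioned above (illustrative hyperparameter values, not part of the original example), the encoder maps a (batch, timesteps, features) tensor to a tensor of the same shape, so it can be used wherever a sequence-returning RNN layer would fit:

demo_inputs <- layer_input(shape = c(500, 1))

# Both branches produce a (batch, 500, 1) output
rnn_out <- demo_inputs %>%
  layer_lstm(units = 1, return_sequences = TRUE)

encoder_out <- demo_inputs %>%
  transformer_encoder(head_size = 64, num_heads = 2, ff_dim = 4)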


build_model <- function(input_shape,
                        head_size,
                        num_heads,
                        ff_dim,
                        num_transformer_blocks,
                        mlp_units,
                        dropout = 0,
                        mlp_dropout = 0) {
  
  inputs <- layer_input(input_shape)
  
  x <- inputs
  for (i in 1:num_transformer_blocks) {
    x <- x %>%
      transformer_encoder(
        head_size = head_size,
        num_heads = num_heads,
        ff_dim = ff_dim,
        dropout = dropout
      )
  }
  
  x <- x %>% 
    layer_global_average_pooling_1d(data_format = "channels_first")
  
  for (dim in mlp_units) {
    x <- x %>%
      layer_dense(dim, activation = "relu") %>%
      layer_dropout(mlp_dropout)
  }
  
  outputs <- x %>% 
    layer_dense(n_classes, activation = "softmax")
  
  keras_model(inputs, outputs)
}

Train and evaluate

input_shape <- dim(x_train)[-1] # drop batch dim
model <- build_model(
  input_shape,
  head_size = 256,
  num_heads = 4,
  ff_dim = 4,
  num_transformer_blocks = 4,
  mlp_units = c(128),
  mlp_dropout = 0.4,
  dropout = 0.25
)
model %>% compile(
  loss = "sparse_categorical_crossentropy",
  optimizer = optimizer_adam(learning_rate = 1e-4),
  metrics = c("sparse_categorical_accuracy")
)

model
Model: "model"
____________________________________________________________________________
 Layer (type)            Output Shape    Param #  Connected to              
============================================================================
 input_1 (InputLayer)    [(None, 500, 1  0        []                        
                         )]                                                 
 multi_head_attention (M  (None, 500, 1)  7169    ['input_1[0][0]',         
 ultiHeadAttention)                                'input_1[0][0]']         
 dropout (Dropout)       (None, 500, 1)  0        ['multi_head_attention[0][
                                                  0]']                      
 layer_normalization (La  (None, 500, 1)  2       ['dropout[0][0]']         
 yerNormalization)                                                          
 tf.math.add (TFOpLambda  (None, 500, 1)  0       ['layer_normalization[0][0
 )                                                ]',                       
                                                   'input_1[0][0]']         
 conv1d_1 (Conv1D)       (None, 500, 4)  8        ['tf.math.add[0][0]']     
 dropout_1 (Dropout)     (None, 500, 4)  0        ['conv1d_1[0][0]']        
 conv1d (Conv1D)         (None, 500, 1)  5        ['dropout_1[0][0]']       
 layer_normalization_1 (  (None, 500, 1)  2       ['conv1d[0][0]']          
 LayerNormalization)                                                        
 tf.math.add_1 (TFOpLamb  (None, 500, 1)  0       ['layer_normalization_1[0]
 da)                                              [0]',                     
                                                   'tf.math.add[0][0]']     
 multi_head_attention_1   (None, 500, 1)  7169    ['tf.math.add_1[0][0]',   
 (MultiHeadAttention)                              'tf.math.add_1[0][0]']   
 dropout_2 (Dropout)     (None, 500, 1)  0        ['multi_head_attention_1[0
                                                  ][0]']                    
 layer_normalization_2 (  (None, 500, 1)  2       ['dropout_2[0][0]']       
 LayerNormalization)                                                        
 tf.math.add_2 (TFOpLamb  (None, 500, 1)  0       ['layer_normalization_2[0]
 da)                                              [0]',                     
                                                   'tf.math.add_1[0][0]']   
 conv1d_3 (Conv1D)       (None, 500, 4)  8        ['tf.math.add_2[0][0]']   
 dropout_3 (Dropout)     (None, 500, 4)  0        ['conv1d_3[0][0]']        
 conv1d_2 (Conv1D)       (None, 500, 1)  5        ['dropout_3[0][0]']       
 layer_normalization_3 (  (None, 500, 1)  2       ['conv1d_2[0][0]']        
 LayerNormalization)                                                        
 tf.math.add_3 (TFOpLamb  (None, 500, 1)  0       ['layer_normalization_3[0]
 da)                                              [0]',                     
                                                   'tf.math.add_2[0][0]']   
 multi_head_attention_2   (None, 500, 1)  7169    ['tf.math.add_3[0][0]',   
 (MultiHeadAttention)                              'tf.math.add_3[0][0]']   
 dropout_4 (Dropout)     (None, 500, 1)  0        ['multi_head_attention_2[0
                                                  ][0]']                    
 layer_normalization_4 (  (None, 500, 1)  2       ['dropout_4[0][0]']       
 LayerNormalization)                                                        
 tf.math.add_4 (TFOpLamb  (None, 500, 1)  0       ['layer_normalization_4[0]
 da)                                              [0]',                     
                                                   'tf.math.add_3[0][0]']   
 conv1d_5 (Conv1D)       (None, 500, 4)  8        ['tf.math.add_4[0][0]']   
 dropout_5 (Dropout)     (None, 500, 4)  0        ['conv1d_5[0][0]']        
 conv1d_4 (Conv1D)       (None, 500, 1)  5        ['dropout_5[0][0]']       
 layer_normalization_5 (  (None, 500, 1)  2       ['conv1d_4[0][0]']        
 LayerNormalization)                                                        
 tf.math.add_5 (TFOpLamb  (None, 500, 1)  0       ['layer_normalization_5[0]
 da)                                              [0]',                     
                                                   'tf.math.add_4[0][0]']   
 multi_head_attention_3   (None, 500, 1)  7169    ['tf.math.add_5[0][0]',   
 (MultiHeadAttention)                              'tf.math.add_5[0][0]']   
 dropout_6 (Dropout)     (None, 500, 1)  0        ['multi_head_attention_3[0
                                                  ][0]']                    
 layer_normalization_6 (  (None, 500, 1)  2       ['dropout_6[0][0]']       
 LayerNormalization)                                                        
 tf.math.add_6 (TFOpLamb  (None, 500, 1)  0       ['layer_normalization_6[0]
 da)                                              [0]',                     
                                                   'tf.math.add_5[0][0]']   
 conv1d_7 (Conv1D)       (None, 500, 4)  8        ['tf.math.add_6[0][0]']   
 dropout_7 (Dropout)     (None, 500, 4)  0        ['conv1d_7[0][0]']        
 conv1d_6 (Conv1D)       (None, 500, 1)  5        ['dropout_7[0][0]']       
 layer_normalization_7 (  (None, 500, 1)  2       ['conv1d_6[0][0]']        
 LayerNormalization)                                                        
 tf.math.add_7 (TFOpLamb  (None, 500, 1)  0       ['layer_normalization_7[0]
 da)                                              [0]',                     
                                                   'tf.math.add_6[0][0]']   
 global_average_pooling1  (None, 500)    0        ['tf.math.add_7[0][0]']   
 d (GlobalAveragePooling                                                    
 1D)                                                                        
 dense (Dense)           (None, 128)     64128    ['global_average_pooling1d
                                                  [0][0]']                  
 dropout_8 (Dropout)     (None, 128)     0        ['dense[0][0]']           
 dense_1 (Dense)         (None, 2)       258      ['dropout_8[0][0]']       
============================================================================
Total params: 93,130
Trainable params: 93,130
Non-trainable params: 0
____________________________________________________________________________
callbacks <- list(
  callback_early_stopping(patience = 10, restore_best_weights = TRUE))

history <- model %>%
  fit(
    x_train,
    y_train,
    batch_size = 64,
    epochs = 200,
    callbacks = callbacks,
    validation_split = 0.2
  )
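
The history object returned by fit() has a plot() method, which is a convenient way to inspect the training and validation curves:

plot(history)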

model %>% evaluate(x_test, y_test, verbose = 1)
                       loss sparse_categorical_accuracy 
                  0.3437351                   0.8522727 
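
Beyond the aggregate metrics, you can look at per-class behaviour on the test set (a small follow-up sketch, not part of the original example):

# Convert the softmax probabilities into hard 0/1 labels and cross-tabulate
# them against the true test labels.
pred_probs  <- model %>% predict(x_test)
pred_labels <- max.col(pred_probs) - 1   # column 1 -> class 0, column 2 -> class 1
table(predicted = pred_labels, actual = as.vector(y_test))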

Conclusions

In about 110-120 epochs (roughly 25 s each on Colab), the model reaches a training accuracy of ~0.95, a validation accuracy of ~0.84, and a test accuracy of ~0.85, without any hyperparameter tuning. And that is for a model with fewer than 100k parameters. Of course, parameter count and accuracy could be improved further by a hyperparameter search, a more sophisticated learning rate schedule, or a different optimizer.

You can use the trained model hosted on Hugging Face Hub and try the demo on Hugging Face Spaces.