library(tensorflow)
library(keras)
set.seed(1234)
Timeseries classification with a Transformer model
Introduction
This is the Transformer architecture from Attention Is All You Need, applied to timeseries instead of natural language.
This example requires TensorFlow 2.4 or higher.
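If you want to confirm the runtime before running the rest of the example, a quick optional check can be done with tf_version() from the tensorflow package loaded above; this snippet is an addition, not part of the original example:

# Optional: stop early if the installed TensorFlow is older than 2.4
stopifnot(tensorflow::tf_version() >= "2.4")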
Load the dataset
We are going to use the same dataset and preprocessing as the TimeSeries Classification from Scratch example.
<- "https://raw.githubusercontent.com/hfawaz/cd-diagram/master/FordA"
url
<- "FordA_TRAIN.tsv" %>%
train_df get_file(., file.path(url, .)) %>%
::read_tsv(col_names = FALSE)
readr<- as.matrix(train_df[, -1])
x_train <- as.matrix(train_df[, 1])
y_train
<- "FordA_TEST.tsv" %>%
test_df get_file(., file.path(url, .)) %>%
::read_tsv(col_names = FALSE)
readr<- as.matrix(test_df[, -1])
x_test <- as.matrix(test_df[, 1])
y_test
<- length(unique(y_train))
n_classes
<- sample(nrow(x_train))
shuffle_ind <- x_train[shuffle_ind, , drop = FALSE]
x_train <- y_train[shuffle_ind, , drop = FALSE]
y_train
== -1] <- 0
y_train[y_train == -1] <- 0
y_test [y_test
dim(x_train) <- c(dim(x_train), 1)
dim(x_test) <- c(dim(x_test), 1)
Build a model
Our model processes a tensor of shape (batch size, sequence length, features), where sequence length is the number of time steps and features is the number of input timeseries observed at each step (here, a single series, so features = 1).
You can replace your classification RNN layers with this one: the inputs are fully compatible!
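As a quick sanity check (an optional snippet added here), the arrays prepared above should already have this (samples, sequence length, features) layout, with 500 time steps and a single feature per step:

# Expect three dimensions: (number of series, 500, 1)
dim(x_train)
dim(x_test)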
We include residual connections, layer normalization, and dropout. The resulting layer can be stacked multiple times. The projection layers are implemented through layer_conv_1d().
transformer_encoder <- function(inputs,
                                head_size,
                                num_heads,
                                ff_dim,
                                dropout = 0) {
  # Attention and Normalization
  attention_layer <-
    layer_multi_head_attention(key_dim = head_size,
                               num_heads = num_heads,
                               dropout = dropout)

  n_features <- dim(inputs) %>% tail(1)

  x <- inputs %>%
    attention_layer(., .) %>%
    layer_dropout(dropout) %>%
    layer_layer_normalization(epsilon = 1e-6)

  res <- x + inputs

  # Feed Forward Part
  x <- res %>%
    layer_conv_1d(ff_dim, kernel_size = 1, activation = "relu") %>%
    layer_dropout(dropout) %>%
    layer_conv_1d(n_features, kernel_size = 1) %>%
    layer_layer_normalization(epsilon = 1e-6)

  # return output + residual
  x + res
}
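As a quick illustration of why these blocks stack cleanly (an optional check, not part of the training code; the 500-step shape is just the FordA sequence length used above), the encoder maps a symbolic input to an output of the same shape:

# The encoder preserves the (sequence length, features) shape of its input
demo_input <- layer_input(shape = c(500, 1))
demo_output <- demo_input %>%
  transformer_encoder(head_size = 256, num_heads = 4, ff_dim = 4)
demo_output$shape # expected: (None, 500, 1)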
build_model <- function(input_shape,
                        head_size,
                        num_heads,
                        ff_dim,
                        num_transformer_blocks,
                        mlp_units,
                        dropout = 0,
                        mlp_dropout = 0) {
  inputs <- layer_input(input_shape)

  x <- inputs
  for (i in 1:num_transformer_blocks) {
    x <- x %>%
      transformer_encoder(
        head_size = head_size,
        num_heads = num_heads,
        ff_dim = ff_dim,
        dropout = dropout
      )
  }

  x <- x %>%
    layer_global_average_pooling_1d(data_format = "channels_first")

  for (dim in mlp_units) {
    x <- x %>%
      layer_dense(dim, activation = "relu") %>%
      layer_dropout(mlp_dropout)
  }

  outputs <- x %>%
    layer_dense(n_classes, activation = "softmax")

  keras_model(inputs, outputs)
}
Train and evaluate
input_shape <- dim(x_train)[-1] # drop batch dim

model <- build_model(
  input_shape,
  head_size = 256,
  num_heads = 4,
  ff_dim = 4,
  num_transformer_blocks = 4,
  mlp_units = c(128),
  mlp_dropout = 0.4,
  dropout = 0.25
)

model %>% compile(
  loss = "sparse_categorical_crossentropy",
  optimizer = optimizer_adam(learning_rate = 1e-4),
  metrics = c("sparse_categorical_accuracy")
)
model
Model: "model"
____________________________________________________________________________
Layer (type) Output Shape Param # Connected to
============================================================================
input_1 (InputLayer) [(None, 500, 1 0 []
)]
multi_head_attention (M (None, 500, 1) 7169 ['input_1[0][0]',
ultiHeadAttention) 'input_1[0][0]']
dropout (Dropout) (None, 500, 1) 0 ['multi_head_attention[0][
0]']
layer_normalization (La (None, 500, 1) 2 ['dropout[0][0]']
yerNormalization)
tf.math.add (TFOpLambda (None, 500, 1) 0 ['layer_normalization[0][0
) ]',
'input_1[0][0]']
conv1d_1 (Conv1D) (None, 500, 4) 8 ['tf.math.add[0][0]']
dropout_1 (Dropout) (None, 500, 4) 0 ['conv1d_1[0][0]']
conv1d (Conv1D) (None, 500, 1) 5 ['dropout_1[0][0]']
layer_normalization_1 ( (None, 500, 1) 2 ['conv1d[0][0]']
LayerNormalization)
tf.math.add_1 (TFOpLamb (None, 500, 1) 0 ['layer_normalization_1[0]
da) [0]',
'tf.math.add[0][0]']
multi_head_attention_1 (None, 500, 1) 7169 ['tf.math.add_1[0][0]',
(MultiHeadAttention) 'tf.math.add_1[0][0]']
dropout_2 (Dropout) (None, 500, 1) 0 ['multi_head_attention_1[0
][0]']
layer_normalization_2 ( (None, 500, 1) 2 ['dropout_2[0][0]']
LayerNormalization)
tf.math.add_2 (TFOpLamb (None, 500, 1) 0 ['layer_normalization_2[0]
da) [0]',
'tf.math.add_1[0][0]']
conv1d_3 (Conv1D) (None, 500, 4) 8 ['tf.math.add_2[0][0]']
dropout_3 (Dropout) (None, 500, 4) 0 ['conv1d_3[0][0]']
conv1d_2 (Conv1D) (None, 500, 1) 5 ['dropout_3[0][0]']
layer_normalization_3 ( (None, 500, 1) 2 ['conv1d_2[0][0]']
LayerNormalization)
tf.math.add_3 (TFOpLamb (None, 500, 1) 0 ['layer_normalization_3[0]
da) [0]',
'tf.math.add_2[0][0]']
multi_head_attention_2 (None, 500, 1) 7169 ['tf.math.add_3[0][0]',
(MultiHeadAttention) 'tf.math.add_3[0][0]']
dropout_4 (Dropout) (None, 500, 1) 0 ['multi_head_attention_2[0
][0]']
layer_normalization_4 ( (None, 500, 1) 2 ['dropout_4[0][0]']
LayerNormalization)
tf.math.add_4 (TFOpLamb (None, 500, 1) 0 ['layer_normalization_4[0]
da) [0]',
'tf.math.add_3[0][0]']
conv1d_5 (Conv1D) (None, 500, 4) 8 ['tf.math.add_4[0][0]']
dropout_5 (Dropout) (None, 500, 4) 0 ['conv1d_5[0][0]']
conv1d_4 (Conv1D) (None, 500, 1) 5 ['dropout_5[0][0]']
layer_normalization_5 ( (None, 500, 1) 2 ['conv1d_4[0][0]']
LayerNormalization)
tf.math.add_5 (TFOpLamb (None, 500, 1) 0 ['layer_normalization_5[0]
da) [0]',
'tf.math.add_4[0][0]']
multi_head_attention_3 (None, 500, 1) 7169 ['tf.math.add_5[0][0]',
(MultiHeadAttention) 'tf.math.add_5[0][0]']
dropout_6 (Dropout) (None, 500, 1) 0 ['multi_head_attention_3[0
][0]']
layer_normalization_6 ( (None, 500, 1) 2 ['dropout_6[0][0]']
LayerNormalization)
tf.math.add_6 (TFOpLamb (None, 500, 1) 0 ['layer_normalization_6[0]
da) [0]',
'tf.math.add_5[0][0]']
conv1d_7 (Conv1D) (None, 500, 4) 8 ['tf.math.add_6[0][0]']
dropout_7 (Dropout) (None, 500, 4) 0 ['conv1d_7[0][0]']
conv1d_6 (Conv1D) (None, 500, 1) 5 ['dropout_7[0][0]']
layer_normalization_7 ( (None, 500, 1) 2 ['conv1d_6[0][0]']
LayerNormalization)
tf.math.add_7 (TFOpLamb (None, 500, 1) 0 ['layer_normalization_7[0]
da) [0]',
'tf.math.add_6[0][0]']
global_average_pooling1 (None, 500) 0 ['tf.math.add_7[0][0]']
d (GlobalAveragePooling
1D)
dense (Dense) (None, 128) 64128 ['global_average_pooling1d
[0][0]']
dropout_8 (Dropout) (None, 128) 0 ['dense[0][0]']
dense_1 (Dense) (None, 2) 258 ['dropout_8[0][0]']
============================================================================
Total params: 93,130
Trainable params: 93,130
Non-trainable params: 0
____________________________________________________________________________
callbacks <- list(
  callback_early_stopping(patience = 10, restore_best_weights = TRUE))

history <- model %>%
  fit(
    x_train,
    y_train,
    batch_size = 64,
    epochs = 200,
    callbacks = callbacks,
    validation_split = 0.2
  )
model %>% evaluate(x_test, y_test, verbose = 1)
loss sparse_categorical_accuracy
0.3437351 0.8522727
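If you also want hard class labels rather than only the aggregate metrics, a small usage sketch (an addition, not part of the original example) is to take the argmax of the softmax probabilities returned by predict():

# Class probabilities, one row per test series
pred_probs <- model %>% predict(x_test)
# 0-based labels, matching the recoded y_test
pred_labels <- apply(pred_probs, 1, which.max) - 1
mean(pred_labels == as.numeric(y_test)) # should agree with the accuracy above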
Conclusions
In about 110-120 epochs (25 s each on Colab), the model reaches a training accuracy of ~0.95, a validation accuracy of ~0.84, and a test accuracy of ~0.85, without any hyperparameter tuning. And that is for a model with fewer than 100k parameters. Of course, parameter count and accuracy could be improved further by a hyperparameter search, a more sophisticated learning rate schedule, or a different optimizer.
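For example, one lightweight schedule (shown here only as a sketch; it was not used in the run above) is to halve the learning rate whenever the validation loss plateaus, via callback_reduce_lr_on_plateau():

# Hypothetical callback list adding a simple learning rate schedule
tuned_callbacks <- list(
  callback_early_stopping(patience = 10, restore_best_weights = TRUE),
  callback_reduce_lr_on_plateau(monitor = "val_loss", factor = 0.5,
                                patience = 5, min_lr = 1e-6)
)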
You can use the trained model hosted on Hugging Face Hub and try the demo on Hugging Face Spaces.