Image Captioning

Setup

library(tensorflow)
library(keras)
library(tfdatasets)

Download the dataset
We will be using the Flickr8K dataset for this tutorial. This dataset comprises over 8,000 images, each paired with five different captions.
flickr_images <- get_file(
  "flickr8k.zip",
  "https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip"
)

flickr_text <- get_file(
  "flickr8k_text.zip",
  "https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip"
)
if (!fs::dir_exists(fs::path(fs::path_dir(flickr_text), "Flicker8k_Dataset"))) {
  unzip(flickr_images, exdir = fs::path_dir(flickr_images))
  unzip(flickr_text, exdir = fs::path_dir(flickr_text))
}
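Before going further, it can help to glance at the raw annotation file. Each line of Flickr8k.token.txt pairs an image file name and a caption index (joined by #) with a tab-separated caption, which is exactly the format the parsing code in the next section assumes. A quick peek:

# Look at the first few annotation lines
# (format: "<image>.jpg#<caption id>\t<caption>")
fs::path(fs::path_dir(flickr_text), "Flickr8k.token.txt") %>%
  readr::read_lines(n_max = 3)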
# Path to the images
IMAGES_PATH <- "Flicker8k_Dataset"

# Desired image dimensions
IMAGE_SIZE <- shape(299, 299)

# Vocabulary size
VOCAB_SIZE <- 10000

# Fixed length allowed for any sequence
SEQ_LENGTH <- 25

# Dimension for the image embeddings and token embeddings
EMBED_DIM <- 512

# Per-layer units in the feed-forward network
FF_DIM <- 512

# Other training parameters
BATCH_SIZE <- 64
EPOCHS <- 30
AUTOTUNE <- tf$data$AUTOTUNE
Preparing the dataset
captions <- fs::path(fs::path_dir(flickr_text), "Flickr8k.token.txt") %>%
  readr::read_delim(
    col_names = c("img", "caption"),
    delim = "\t"
  ) %>%
  tidyr::separate(img, into = c("img", "caption_id"), sep = "#") %>%
  dplyr::select(img, caption) %>%
  dplyr::group_by(img) %>%
  dplyr::summarise(caption = list(caption)) %>%
  dplyr::mutate(img = fs::path(fs::path_dir(flickr_text), "Flicker8k_Dataset", img))

train <- fs::path(fs::path_dir(flickr_text), "Flickr_8k.trainImages.txt") %>%
  readr::read_lines()

valid <- fs::path(fs::path_dir(flickr_text), "Flickr_8k.devImages.txt") %>%
  readr::read_lines()

test <- fs::path(fs::path_dir(flickr_text), "Flickr_8k.testImages.txt") %>%
  readr::read_lines()

train_data <- captions %>%
  dplyr::filter(fs::path_file(img) %in% train)

valid_data <- captions %>%
  dplyr::filter(fs::path_file(img) %in% test)

dplyr::n_distinct(train_data$img)
dplyr::n_distinct(valid_data$img)
Vectorizing the text data
We’ll use the layer_text_vectorization() layer to vectorize the text data, that is, to turn the original strings into integer sequences where each integer represents the index of a word in a vocabulary. We will use a custom string standardization scheme (strip punctuation characters except < and >) and the default splitting scheme (split on whitespace).
<- c("!", "\\", "\"", "#", "$", "%", "&", "'", "(", ")", "*",
punctuation "+", ",", "-", ".", "/", ":", ";", "=", "?", "@", "[",
"\\", "\\", "]", "^", "_", "`", "{", "|", "}", "~")
<- reticulate::import("re")
re <- punctuation %>%
punctuation_group sapply(re$escape) %>%
paste0(collapse = "") %>%
sprintf("[%s]", .)
<- function(input_string) {
custom_standardization <- tf$strings$lower(input_string)
lowercase $strings$regex_replace(lowercase, punctuation_group, "")
tf
}
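As a quick check of the standardization (the expected output is an assumption based on the regex above: lower-cased text with the listed punctuation removed):

# Illustrative input string; the result should be a string tensor
# along the lines of "a man riding a bike"
custom_standardization(as_tensor("A man, riding a bike!"))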
vectorization <- layer_text_vectorization(
  max_tokens = VOCAB_SIZE,
  output_mode = "int",
  output_sequence_length = SEQ_LENGTH,
  standardize = custom_standardization
)
vectorization %>% adapt(unlist(train_data$caption))
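After adapt(), the layer maps a caption to a fixed-length integer sequence. A quick look with a made-up caption (the actual ids depend on the adapted vocabulary):

# Returns an integer tensor of shape (1, SEQ_LENGTH), zero-padded at the end
vectorization(list("<start> a dog runs across the grass <end>"))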
# Data augmentation for image data
image_augmentation <- keras_model_sequential() %>%
  layer_random_flip("horizontal") %>%
  layer_random_rotation(0.2) %>%
  layer_random_contrast(0.3)
Building a TensorFlow dataset pipeline for training
We will generate pairs of images and corresponding captions using a tf$data$Dataset object. The pipeline consists of two steps:
- Read the image from the disk
- Tokenize all five captions corresponding to the image
decode_and_resize <- function(img_path) {
  img_path %>%
    tf$io$read_file() %>%
    tf$image$decode_jpeg(channels = 3) %>%
    tf$image$resize(IMAGE_SIZE) %>%
    tf$image$convert_image_dtype(tf$float32)
}

process_input <- function(img_path, captions) {
  reticulate::tuple(
    decode_and_resize(img_path),
    vectorization(captions)
  )
}

make_dataset <- function(data) {
  data %>% unname() %>%
    tensor_slices_dataset() %>%
    dataset_shuffle(nrow(data)) %>%
    dataset_map(process_input, num_parallel_calls = AUTOTUNE) %>%
    dataset_batch(BATCH_SIZE) %>%
    dataset_prefetch(AUTOTUNE)
}

# Pass the list of images and the list of corresponding captions
train_dataset <- make_dataset(train_data)
valid_dataset <- make_dataset(valid_data)
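A minimal sanity check of the pipeline above (optional): pull a single batch and confirm the shapes the model will see. With BATCH_SIZE = 64 we would expect images of roughly (64, 299, 299, 3) and captions of roughly (64, 5, 25).

example_batch <- train_dataset %>%
  reticulate::as_iterator() %>%
  reticulate::iter_next()
example_batch[[1]]$shape  # batch of images
example_batch[[2]]$shape  # batch of tokenized captions, 5 per image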
Building the model
Our image captioning architecture consists of three models:
- A CNN: used to extract the image features
- A TransformerEncoder: The extracted image features are then passed to a Transformer-based encoder that generates a new representation of the inputs
- A TransformerDecoder: This model takes the encoder output and the text data (sequences) as inputs and tries to learn to generate the caption.
get_cnn_model <- function() {
  base_model <- application_efficientnet_b0(
    input_shape = c(IMAGE_SIZE, 3),
    include_top = FALSE,
    weights = "imagenet"
  )
  # We freeze our feature extractor
  base_model$trainable <- FALSE
  base_model_out <- base_model$output %>%
    layer_reshape(target_shape = c(-1, tail(dim(base_model$output), 1)))
  keras_model(base_model$input, base_model_out)
}
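If you want to check what the encoder will receive, you can build the feature extractor and look at its output shape. For EfficientNetB0 at 299 x 299 we would expect a 10 x 10 grid of 1280-dimensional features, flattened to 100 positions; the exact numbers depend on the backbone.

# Expected to be roughly (NULL, 100, 1280)
get_cnn_model()$output_shape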
transformer_encoder_block <- new_layer_class(
  "transformer_encoder_block",
  initialize = function(embed_dim, dense_dim, num_heads, ...) {
    super()$`__init__`(...)
    self$embed_dim <- embed_dim
    self$dense_dim <- dense_dim
    self$num_heads <- num_heads
    self$attention_1 <- layer_multi_head_attention(
      num_heads = num_heads, key_dim = embed_dim, dropout = 0.0
    )
    self$layernorm_1 <- layer_normalization()
    self$layernorm_2 <- layer_normalization()
    self$dense_1 <- layer_dense(units = embed_dim, activation = "relu")
  },
  call = function(inputs, training, mask = NULL) {
    inputs <- self$layernorm_1(inputs)
    inputs <- self$dense_1(inputs)

    attention_output_1 <- self$attention_1(
      query = inputs,
      value = inputs,
      key = inputs,
      attention_mask = NULL,
      training = training
    )
    out_1 <- self$layernorm_2(inputs + attention_output_1)
    out_1
  }
)
positional_embedding <- new_layer_class(
  "positional_embedding",
  initialize = function(sequence_length, vocab_size, embed_dim, ...) {
    super()$`__init__`(...)
    self$token_embeddings <- layer_embedding(
      input_dim = vocab_size, output_dim = embed_dim
    )
    self$position_embeddings <- layer_embedding(
      input_dim = sequence_length, output_dim = embed_dim
    )
    self$sequence_length <- sequence_length
    self$vocab_size <- vocab_size
    self$embed_dim <- embed_dim
    self$embed_scale <- tf$math$sqrt(tf$cast(embed_dim, tf$float32))
  },
  call = function(inputs) {
    length <- tail(dim(inputs), 1)
    positions <- tf$range(start = 0L, limit = length, delta = 1L)
    embedded_tokens <- self$token_embeddings(inputs)
    embedded_tokens <- embedded_tokens * self$embed_scale
    embedded_positions <- self$position_embeddings(positions)
    embedded_tokens + embedded_positions
  },
  compute_mask = function(inputs, mask) {
    tf$math$not_equal(inputs, 0L)
  }
)
transformer_decoder_block <- new_layer_class(
  "transformer_decoder_block",
  initialize = function(embed_dim, ff_dim, num_heads, ...) {
    super()$`__init__`(...)
    self$embed_dim <- embed_dim
    self$ff_dim <- ff_dim
    self$num_heads <- num_heads
    self$attention_1 <- layer_multi_head_attention(
      num_heads = num_heads, key_dim = embed_dim, dropout = 0.1
    )
    self$attention_2 <- layer_multi_head_attention(
      num_heads = num_heads, key_dim = embed_dim, dropout = 0.1
    )
    self$ffn_layer_1 <- layer_dense(units = ff_dim, activation = "relu")
    self$ffn_layer_2 <- layer_dense(units = embed_dim)

    self$layernorm_1 <- layer_normalization()
    self$layernorm_2 <- layer_normalization()
    self$layernorm_3 <- layer_normalization()

    self$embedding <- positional_embedding(
      embed_dim = EMBED_DIM, sequence_length = SEQ_LENGTH, vocab_size = VOCAB_SIZE
    )
    self$out <- layer_dense(units = VOCAB_SIZE, activation = "softmax")

    self$dropout_1 <- layer_dropout(rate = 0.3)
    self$dropout_2 <- layer_dropout(rate = 0.5)
    self$supports_masking <- TRUE
  },
  call = function(inputs, encoder_outputs, training, mask = NULL) {
    inputs <- self$embedding(inputs)
    causal_mask <- self$get_causal_attention_mask(inputs)

    # Default to no padding mask when `mask` is not provided
    padding_mask <- NULL
    combined_mask <- NULL
    if (!is.null(mask)) {
      padding_mask <- tf$cast(mask[, , tf$newaxis], dtype = tf$int32)
      combined_mask <- tf$cast(mask[, tf$newaxis, ], dtype = tf$int32)
      combined_mask <- tf$minimum(combined_mask, causal_mask)
    }

    attention_output_1 <- self$attention_1(
      query = inputs,
      value = inputs,
      key = inputs,
      attention_mask = combined_mask,
      training = training
    )
    out_1 <- self$layernorm_1(inputs + attention_output_1)

    attention_output_2 <- self$attention_2(
      query = out_1,
      value = encoder_outputs,
      key = encoder_outputs,
      attention_mask = padding_mask,
      training = training
    )
    out_2 <- self$layernorm_2(out_1 + attention_output_2)

    ffn_out <- self$ffn_layer_1(out_2)
    ffn_out <- self$dropout_1(ffn_out, training = training)
    ffn_out <- self$ffn_layer_2(ffn_out)

    ffn_out <- self$layernorm_3(ffn_out + out_2, training = training)
    ffn_out <- self$dropout_2(ffn_out, training = training)
    preds <- self$out(ffn_out)
    preds
  },
  get_causal_attention_mask = function(inputs) {
    input_shape <- tf$shape(inputs)
    batch_size <- input_shape[1]
    sequence_length <- input_shape[2]
    i <- tf$range(sequence_length)[, tf$newaxis]
    j <- tf$range(sequence_length)
    mask <- tf$cast(i >= j, dtype = "int32")
    mask <- tf$reshape(mask, list(1L, input_shape[2], input_shape[2]))
    mult <- tf$concat(list(
      tf$expand_dims(batch_size, -1L),
      as_tensor(c(1L, 1L), dtype = tf$int32)
    ), axis = 0L)
    tf$tile(mask, mult)
  }
)
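The causal mask built in get_causal_attention_mask() is lower triangular: token t may only attend to tokens 1..t. A standalone 4-token illustration of the same i >= j comparison used above:

i <- tf$range(4L)[, tf$newaxis]
j <- tf$range(4L)
# Rows are query positions, columns are key positions; 1 means "may attend"
tf$cast(i >= j, "int32")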
image_captioning_model <- new_model_class(
  "image_captioning_model",
  initialize = function(cnn_model, encoder, decoder, num_captions_per_image = 5,
                        image_aug = NULL) {
    super()$`__init__`()
    self$cnn_model <- cnn_model
    self$encoder <- encoder
    self$decoder <- decoder
    self$loss_tracker <- metric_mean(name = "loss")
    self$acc_tracker <- metric_mean(name = "accuracy")
    self$num_captions_per_image <- num_captions_per_image
    self$image_aug <- image_aug
  },
  calculate_loss = function(y_true, y_pred, mask) {
    loss <- self$loss(y_true, y_pred)
    mask <- tf$cast(mask, dtype = loss$dtype)
    loss <- loss * mask
    tf$reduce_sum(loss) / tf$reduce_sum(mask)
  },
  calculate_accuracy = function(y_true, y_pred, mask) {
    accuracy <- tf$equal(y_true, tf$argmax(y_pred, axis = 2L))
    accuracy <- tf$math$logical_and(mask, accuracy)
    accuracy <- tf$cast(accuracy, dtype = tf$float32)
    mask <- tf$cast(mask, dtype = tf$float32)
    tf$reduce_sum(accuracy) / tf$reduce_sum(mask)
  },
  .compute_caption_loss_and_acc = function(img_embed, batch_seq, training = TRUE) {
    encoder_out <- self$encoder(img_embed, training = training)
    batch_seq_inp <- batch_seq[, NULL:-2]
    batch_seq_true <- batch_seq[, 2:NULL]
    mask <- tf$math$not_equal(batch_seq_true, 0L)
    batch_seq_pred <- self$decoder(
      batch_seq_inp, encoder_out, training = training, mask = mask
    )
    loss <- self$calculate_loss(batch_seq_true, batch_seq_pred, mask)
    acc <- self$calculate_accuracy(batch_seq_true, batch_seq_pred, mask)
    list(loss, acc)
  },
  train_step = function(batch_data) {
    batch_img <- batch_data[[1]]
    batch_seq <- batch_data[[2]]
    batch_loss <- 0
    batch_acc <- 0

    if (!is.null(self$image_aug)) {
      batch_img <- self$image_aug(batch_img)
    }

    # 1. Get image embeddings
    img_embed <- self$cnn_model(batch_img)

    # 2. Pass each of the five captions one by one to the decoder
    # along with the encoder outputs and compute the loss as well as accuracy
    # for each caption.
    for (i in seq_len(self$num_captions_per_image)) {
      with(tf$GradientTape() %as% tape, {
        c(loss, acc) %<-% self$.compute_caption_loss_and_acc(
          img_embed, batch_seq[, i, ], training = TRUE
        )

        # 3. Update loss and accuracy
        batch_loss <- batch_loss + loss
        batch_acc <- batch_acc + acc
      })

      # 4. Get the list of all the trainable weights
      train_vars <- c(self$encoder$trainable_variables,
                      self$decoder$trainable_variables)

      # 5. Get the gradients
      grads <- tape$gradient(loss, train_vars)

      # 6. Update the trainable weights
      self$optimizer$apply_gradients(zip_lists(grads, train_vars))
    }

    # 7. Update the trackers
    batch_acc <- batch_acc / self$num_captions_per_image
    self$loss_tracker$update_state(batch_loss)
    self$acc_tracker$update_state(batch_acc)

    # 8. Return the loss and accuracy values
    list(
      loss = self$loss_tracker$result(),
      acc = self$acc_tracker$result()
    )
  },
  test_step = function(batch_data) {
    batch_img <- batch_data[[1]]
    batch_seq <- batch_data[[2]]
    batch_loss <- 0
    batch_acc <- 0

    # 1. Get image embeddings
    img_embed <- self$cnn_model(batch_img)

    # 2. Pass each of the five captions one by one to the decoder
    # along with the encoder outputs and compute the loss as well as accuracy
    # for each caption. No gradient tape is needed at evaluation time.
    for (i in seq_len(self$num_captions_per_image)) {
      c(loss, acc) %<-% self$.compute_caption_loss_and_acc(
        img_embed, batch_seq[, i, ], training = FALSE
      )

      # 3. Update loss and accuracy
      batch_loss <- batch_loss + loss
      batch_acc <- batch_acc + acc
    }
    batch_acc <- batch_acc / self$num_captions_per_image

    # 4. Update the trackers
    self$loss_tracker$update_state(batch_loss)
    self$acc_tracker$update_state(batch_acc)

    # 5. Return the loss and accuracy values
    list(
      "loss" = self$loss_tracker$result(),
      "acc" = self$acc_tracker$result()
    )
  },
  metrics = mark_active(function() {
    # We need to list our metrics here so `reset_states()` can be
    # called automatically.
    list(self$loss_tracker, self$acc_tracker)
  })
)
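The .compute_caption_loss_and_acc() method uses teacher forcing: batch_seq_inp drops the last token and batch_seq_true drops the first, so each position is trained to predict the next token, and padded positions are masked out of the loss. In plain R terms, on a made-up sequence of hypothetical token ids:

# <start> a dog <end> <pad> as hypothetical ids
toy_seq <- c(2, 7, 9, 3, 0)
head(toy_seq, -1)  # decoder input:      2 7 9 3
tail(toy_seq, -1)  # prediction targets: 7 9 3 0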
cnn_model <- get_cnn_model()
encoder <- transformer_encoder_block(embed_dim = EMBED_DIM, dense_dim = FF_DIM, num_heads = 1)
decoder <- transformer_decoder_block(embed_dim = EMBED_DIM, ff_dim = FF_DIM, num_heads = 2)

caption_model <- image_captioning_model(
  cnn_model = cnn_model,
  encoder = encoder,
  decoder = decoder,
  image_aug = image_augmentation
)
Model training
# Define the loss function
cross_entropy <- loss_sparse_categorical_crossentropy(
  from_logits = FALSE, reduction = "none"
)

# EarlyStopping criteria
early_stopping <- callback_early_stopping(patience = 3, restore_best_weights = TRUE)
# Learning Rate Scheduler for the optimizer
lr_schedule <- new_learning_rate_schedule_class(
  "lr_schedule",
  initialize = function(post_warmup_learning_rate, warmup_steps) {
    super()$`__init__`()
    self$post_warmup_learning_rate <- post_warmup_learning_rate
    self$warmup_steps <- warmup_steps
  },
  call = function(step) {
    global_step <- tf$cast(step, tf$float32)
    warmup_steps <- tf$cast(self$warmup_steps, tf$float32)
    warmup_progress <- global_step / warmup_steps
    warmup_learning_rate <- self$post_warmup_learning_rate * warmup_progress
    tf$cond(
      global_step < warmup_steps,
      function() warmup_learning_rate,
      function() self$post_warmup_learning_rate
    )
  }
)

# Create a learning rate schedule
num_train_steps <- length(train_dataset) * EPOCHS
num_warmup_steps <- num_train_steps %/% 15
lr <- lr_schedule(post_warmup_learning_rate = 1e-4, warmup_steps = num_warmup_steps)
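To eyeball the warmup behaviour (assuming the schedule instance is callable from R the way its Python counterpart is), you can evaluate it at a few steps: the rate ramps up linearly from 0 over the warmup steps, then stays at post_warmup_learning_rate.

lr(0L)                     # 0 at the very first step
lr(num_warmup_steps)       # approximately 1e-4 once warmup ends
lr(num_warmup_steps * 10)  # stays at 1e-4 afterwards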
# Compile the model
caption_model %>% compile(
  optimizer = optimizer_adam(learning_rate = lr),
  loss = cross_entropy
)

# Fit the model
caption_model %>% fit(
  train_dataset,
  epochs = EPOCHS,
  validation_data = valid_dataset,
  callbacks = list(early_stopping)
)
Check sample predictions
vocab <- get_vocabulary(vectorization)
max_decoded_sentence_length <- SEQ_LENGTH - 1
valid_images <- valid_data$img

generate_caption <- function() {
  # Select a random image from the validation dataset
  sample_img <- sample(valid_images, 1)

  # Read the image from the disk
  sample_img <- decode_and_resize(sample_img)
  img <- as.array(tf$clip_by_value(sample_img, 0, 255))
  img %>% as.raster(max = 255) %>% plot()

  # Pass the image to the CNN
  img <- tf$expand_dims(sample_img, 0L)
  img <- caption_model$cnn_model(img)

  # Pass the image features to the Transformer encoder
  encoded_img <- caption_model$encoder(img, training = FALSE)

  # Generate the caption using the Transformer decoder
  decoded_caption <- "<start> "
  for (i in seq_len(max_decoded_sentence_length)) {
    tokenized_caption <- vectorization(list(decoded_caption))
    mask <- tf$math$not_equal(tokenized_caption, 0L)
    predictions <- caption_model$decoder(
      tokenized_caption, encoded_img, training = FALSE, mask = mask
    )
    sampled_token_index <- tf$argmax(predictions[1, i, ])
    sampled_token <- vocab[as.integer(sampled_token_index) + 1]
    if (sampled_token == "<end>") {
      break
    }
    decoded_caption <- paste(decoded_caption, sampled_token, sep = " ")
  }
  cat("Predicted Caption: ", decoded_caption)
}
# Check predictions for a few samples
generate_caption()
generate_caption()
generate_caption()
End Notes
We saw that the model starts to generate reasonable captions after a few epochs. To keep this example easily runnable, we have trained it with a few constraints, like a minimal number of attention heads. To improve the predictions, you can try changing these training settings and find a good model for your use case.