DCGAN to generate face images
A simple DCGAN trained on CelebA images, using fit() with an overridden train_step.
Setup
library(tensorflow)
library(keras)
Prepare CelebA data
We’ll use face images from the CelebA dataset, resized to 64x64.
dataset_path <- fs::path_expand("~/datasets/celeba_gan")
output <- fs::path(dataset_path, "data.zip")
if (!fs::dir_exists(dataset_path)) {
  fs::dir_create(dataset_path)
  url <- "https://drive.google.com/uc?id=1O7m1010EJjLE5QxLZiM9Fpjs7Oj6e684"
  # Requires the Python "gdown" module; install it once if needed with
  # reticulate::py_install("gdown", pip = TRUE)
  reticulate::import("gdown")$download(url, as.character(output), quiet = TRUE)
  unzip(output, exdir = dataset_path)
}
dataset_path <- fs::path(dataset_path, "img_align_celeba")
Create a dataset from our folder:
dataset <- image_dataset_from_directory(
  dataset_path,
  image_size = c(64, 64),
  label_mode = NULL,
  batch_size = 32
)
# Drop any files that fail to decode, rather than erroring out mid-training.
dataset <- dataset$apply(tf$data$experimental$ignore_errors(
  log_warnings = FALSE
))
Let's display a sample image:
dataset %>%
  reticulate::as_iterator() %>%
  reticulate::iter_next() %>%
  as.array() %>%
  { .[1, , , ] } %>%
  as.raster(max = 255) %>%
  plot()
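As a quick sanity check (a sketch, not part of the original pipeline), we can confirm the batch shape and that pixels arrive as raw values in [0, 255], which is why train_step below rescales them:
batch <- dataset %>%
  reticulate::as_iterator() %>%
  reticulate::iter_next() %>%
  as.array()
dim(batch)    # 32 64 64 3: (batch, height, width, channels)
range(batch)  # raw pixel values in [0, 255]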
Create the discriminator
It maps a 64x64 image to a binary classification score.
discriminator <-
  keras_model_sequential(name = "discriminator",
                         input_shape = shape(64, 64, 3)) %>%
  layer_conv_2d(64, kernel_size = 4, strides = 2, padding = "same") %>%
  layer_activation_leaky_relu(alpha = 0.2) %>%
  layer_conv_2d(128, kernel_size = 4, strides = 2, padding = "same") %>%
  layer_activation_leaky_relu(alpha = 0.2) %>%
  layer_conv_2d(128, kernel_size = 4, strides = 2, padding = "same") %>%
  layer_activation_leaky_relu(alpha = 0.2) %>%
  layer_flatten() %>%
  layer_dropout(0.2) %>%
  layer_dense(1, activation = "sigmoid")
summary(discriminator)
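A quick shape check can catch wiring mistakes early; this sketch feeds random noise images (a stand-in for real data) through the model:
fake_batch <- tf$random$uniform(shape(8, 64, 64, 3))
scores <- discriminator(fake_batch)
scores$shape  # (8, 1): one sigmoid score per image, in (0, 1)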
Create the generator
It mirrors the discriminator, replacing the conv_2d layers with conv_2d_transpose layers.
latent_dim <- 128L
generator <-
  keras_model_sequential(name = "generator",
                         input_shape = shape(latent_dim)) %>%
  layer_dense(8 * 8 * 128) %>%
  layer_reshape(shape(8, 8, 128)) %>%
  layer_conv_2d_transpose(128, kernel_size = 4, strides = 2, padding = "same") %>%
  layer_activation_leaky_relu(alpha = 0.2) %>%
  layer_conv_2d_transpose(256, kernel_size = 4, strides = 2, padding = "same") %>%
  layer_activation_leaky_relu(alpha = 0.2) %>%
  layer_conv_2d_transpose(512, kernel_size = 4, strides = 2, padding = "same") %>%
  layer_activation_leaky_relu(alpha = 0.2) %>%
  layer_conv_2d(3, kernel_size = 5, padding = "same", activation = "sigmoid")
summary(generator)
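Similarly, a sketch to verify the generator: latent vectors should decode to 64x64 RGB images with values in [0, 1], thanks to the final sigmoid:
noise <- tf$random$normal(shape(8, latent_dim))
imgs <- generator(noise)
imgs$shape             # (8, 64, 64, 3)
range(as.array(imgs))  # within [0, 1]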
Override train_step
GAN <- new_model_class(
  classname = "GAN",

  initialize = function(discriminator, generator, latent_dim) {
    super$initialize()
    self$discriminator <- discriminator
    self$generator <- generator
    self$latent_dim <- latent_dim
    self$rescale <- layer_rescaling(scale = 1 / 255)
  },

  compile = function(d_optimizer, g_optimizer, loss_fn) {
    super$compile()
    self$d_optimizer <- d_optimizer
    self$g_optimizer <- g_optimizer
    self$loss_fn <- loss_fn
    self$d_loss_metric <- metric_mean(name = "d_loss")
    self$g_loss_metric <- metric_mean(name = "g_loss")
  },

  metrics = mark_active(function() {
    list(self$d_loss_metric, self$g_loss_metric)
  }),

  train_step = function(real_images) {
    # Inputs arrive as raw [0, 255] pixels; scale them to [0, 1] to match
    # the generator's sigmoid output.
    real_images <- self$rescale(real_images)

    # Sample random points in the latent space
    batch_size <- tf$shape(real_images)[1]
    random_latent_vectors <- tf$random$normal(
      shape = reticulate::tuple(batch_size, self$latent_dim)
    )

    # Decode them to fake images
    generated_images <- self$generator(random_latent_vectors)

    # Combine them with real images
    combined_images <- tf$concat(list(generated_images, real_images), axis = 0L)

    # Assemble labels discriminating real from fake images
    # (1 = fake, 0 = real in this convention)
    labels <- tf$concat(
      list(
        tf$ones(reticulate::tuple(batch_size, 1L)),
        tf$zeros(reticulate::tuple(batch_size, 1L))
      ),
      axis = 0L
    )

    # Add random noise to the labels - an important trick that keeps the
    # discriminator from becoming overconfident and starving the generator
    # of gradient signal.
    labels <- labels + 0.05 * tf$random$uniform(tf$shape(labels))

    # Train the discriminator
    with(tf$GradientTape() %as% tape, {
      predictions <- self$discriminator(combined_images)
      d_loss <- self$loss_fn(labels, predictions)
    })
    grads <- tape$gradient(d_loss, self$discriminator$trainable_weights)
    self$d_optimizer$apply_gradients(
      zip_lists(grads, self$discriminator$trainable_weights)
    )

    # Sample random points in the latent space
    random_latent_vectors <- tf$random$normal(
      shape = reticulate::tuple(batch_size, self$latent_dim)
    )

    # Assemble labels that say "all real images"
    misleading_labels <- tf$zeros(reticulate::tuple(batch_size, 1L))

    # Train the generator (note that we should *not* update the weights
    # of the discriminator)! Only the generator's gradients are applied below.
    with(tf$GradientTape() %as% tape, {
      predictions <- self$discriminator(self$generator(random_latent_vectors))
      g_loss <- self$loss_fn(misleading_labels, predictions)
    })
    grads <- tape$gradient(g_loss, self$generator$trainable_weights)
    self$g_optimizer$apply_gradients(
      zip_lists(grads, self$generator$trainable_weights)
    )

    # Update metrics
    self$d_loss_metric$update_state(d_loss)
    self$g_loss_metric$update_state(g_loss)
    list(
      d_loss = self$d_loss_metric$result(),
      g_loss = self$g_loss_metric$result()
    )
  }
)
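Before a long run, it can be worth exercising the custom step on one batch; a minimal sketch (check_gan is a hypothetical throwaway name, and note that this applies one real optimizer update to the shared discriminator and generator weights):
check_gan <- GAN(discriminator = discriminator, generator = generator,
                 latent_dim = latent_dim)
check_gan %>% compile(
  d_optimizer = optimizer_adam(learning_rate = 1e-4),
  g_optimizer = optimizer_adam(learning_rate = 1e-4),
  loss_fn = loss_binary_crossentropy()
)
batch <- dataset %>% reticulate::as_iterator() %>% reticulate::iter_next()
check_gan$train_step(batch)  # should return finite d_loss and g_loss values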
Create a callback that periodically saves generated images
GANMonitor <- new_callback_class(
  classname = "GANMonitor",

  initialize = function(num_img = 3, latent_dim = 128L) {
    self$num_img <- num_img
    self$latent_dim <- as.integer(latent_dim)
    if (!fs::dir_exists("dcgan")) fs::dir_create("dcgan")
  },

  on_epoch_end = function(epoch, logs) {
    # `epoch` is 0-based; the filenames written here are read back at the end.
    random_latent_vectors <-
      tf$random$normal(shape = shape(self$num_img, self$latent_dim))
    generated_images <- self$model$generator(random_latent_vectors)
    # generator outputs are in [0, 1]; map them to [0, 255] for saving
    generated_images <- tf$clip_by_value(generated_images * 255, 0, 255)
    generated_images <- as.array(generated_images)
    for (i in seq_len(self$num_img)) {
      image_array_save(
        generated_images[i, , , ],
        sprintf("dcgan/generated_img_%03d_%d.png", epoch, i),
        scale = FALSE
      )
    }
  }
)
Train the end-to-end model
epochs <- 15 # In practice, use ~100 epochs
gan <- GAN(discriminator = discriminator, generator = generator,
           latent_dim = latent_dim)
gan %>% compile(
  d_optimizer = optimizer_adam(learning_rate = 1e-4),
  g_optimizer = optimizer_adam(learning_rate = 1e-4),
  loss_fn = loss_binary_crossentropy()
)
gan %>% fit(
  dataset,
  epochs = epochs,
  callbacks = list(
    GANMonitor(num_img = 10, latent_dim = latent_dim)
  )
)
Some of the last generated images, from around epoch 15 (each row below is one epoch; results keep improving after that):
grid <- expand.grid(img = 1:10, epoch = 0:14)  # 10 images per epoch, 15 epochs
knitr::include_graphics(
  sprintf("dcgan/generated_img_%03d_%d.png", grid$epoch, grid$img)
)
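Once trained, only the generator is needed for sampling; a possible follow-up (a sketch, with "dcgan/generator" as an arbitrary save path):
save_model_tf(gan$generator, "dcgan/generator")
reloaded <- load_model_tf("dcgan/generator")
noise <- tf$random$normal(shape(16, latent_dim))
samples <- as.array(reloaded(noise))  # 16 images with values in [0, 1]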