DCGAN to generate face images
A simple DCGAN trained using fit() by overriding train_step, on CelebA images.
Setup

```r
library(tensorflow)
library(keras)
```
Prepare CelebA data
We’ll use face images from the CelebA dataset, resized to 64x64.
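Downloading from Google Drive is done via the gdown Python package through reticulate. If gdown isn't already available, one way to install it (an assumption about your Python setup, not part of the original example):

```r
# Install the Python "gdown" package into reticulate's active environment
reticulate::py_install("gdown")
```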
```r
dataset_path <- fs::path("~/datasets/celeba_gan")
output <- fs::path(dataset_path, "data.zip") # download target for the zip archive

if (!fs::dir_exists(dataset_path)) {
  fs::dir_create(dataset_path)
  url <- "https://drive.google.com/uc?id=1O7m1010EJjLE5QxLZiM9Fpjs7Oj6e684"
  reticulate::import("gdown")$download(url, output, quiet = TRUE)
  unzip(output, exdir = fs::path_dir(output))
}

dataset_path <- fs::path(fs::path_dir(output), "img_align_celeba")
```
Create a dataset from our folder:
<- image_dataset_from_directory(
dataset
dataset_path, image_size = c(64, 64),
label_mode = NULL,
batch_size = 32
)<- dataset$apply(tf$data$experimental$ignore_errors(
dataset log_warning=FALSE
))
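At this point the pipeline should yield unlabeled float32 batches of shape (batch, 64, 64, 3). A quick way to confirm (a sanity check we add here, not part of the original example):

```r
dataset$element_spec
#> TensorSpec(shape=(None, 64, 64, 3), dtype=tf.float32, name=None)
# (exact printed form may vary across TF versions)
```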
Let’s display a sample image:
```r
dataset %>%
  reticulate::as_iterator() %>%
  reticulate::iter_next() %>%
  as.array() %>%
  { .[1, , , ] } %>%
  as.raster(max = 255) %>%
  plot()
```
Create the discriminator
It maps a 64x64 image to a binary classification score.
```r
discriminator <-
  keras_model_sequential(name = "discriminator",
                         input_shape = shape(64, 64, 3)) %>%
  layer_conv_2d(64, kernel_size = 4, strides = 2, padding = "same") %>%
  layer_activation_leaky_relu(alpha = 0.2) %>%
  layer_conv_2d(128, kernel_size = 4, strides = 2, padding = "same") %>%
  layer_activation_leaky_relu(alpha = 0.2) %>%
  layer_conv_2d(128, kernel_size = 4, strides = 2, padding = "same") %>%
  layer_activation_leaky_relu(alpha = 0.2) %>%
  layer_flatten() %>%
  layer_dropout(0.2) %>%
  layer_dense(1, activation = "sigmoid")

summary(discriminator)
```
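As a quick sanity check (not from the original example), the discriminator should map any batch of 64x64 RGB inputs to one sigmoid score per image:

```r
x <- tf$random$uniform(shape(2, 64, 64, 3)) # two random "images"
discriminator(x) # tensor of shape (2, 1), values in [0, 1]
```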
Create the generator
It mirrors the discriminator, replacing conv_2d layers with conv_2d_transpose layers.
```r
latent_dim <- 128L

generator <-
  keras_model_sequential(input_shape = shape(latent_dim), name = "generator") %>%
  layer_dense(8 * 8 * 128) %>%
  layer_reshape(shape(8, 8, 128)) %>%
  layer_conv_2d_transpose(128, kernel_size = 4, strides = 2, padding = "same") %>%
  layer_activation_leaky_relu(alpha = 0.2) %>%
  layer_conv_2d_transpose(256, kernel_size = 4, strides = 2, padding = "same") %>%
  layer_activation_leaky_relu(alpha = 0.2) %>%
  layer_conv_2d_transpose(512, kernel_size = 4, strides = 2, padding = "same") %>%
  layer_activation_leaky_relu(alpha = 0.2) %>%
  layer_conv_2d(3, kernel_size = 5, padding = "same", activation = "sigmoid")

summary(generator)
```
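Likewise (again a check we add, not part of the original example), decoding a batch of latent vectors should yield 64x64 RGB images, with values in [0, 1] thanks to the final sigmoid:

```r
z <- tf$random$normal(shape(2, latent_dim)) # two random latent vectors
generator(z) # tensor of shape (2, 64, 64, 3)
```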
Override train_step
```r
gan <- new_model_class(
  "gan",

  initialize = function(discriminator, generator, latent_dim) {
    super()$`__init__`()
    self$discriminator <- discriminator
    self$generator <- generator
    self$latent_dim <- latent_dim
    self$rescale <- layer_rescaling(scale = 1 / 255)
  },

  compile = function(d_optimizer, g_optimizer, loss_fn) {
    super()$compile()
    self$d_optimizer <- d_optimizer
    self$g_optimizer <- g_optimizer
    self$loss_fn <- loss_fn
    self$d_loss_metric <- tf$keras$metrics$Mean(name = "d_loss")
    self$g_loss_metric <- tf$keras$metrics$Mean(name = "g_loss")
  },

  metrics = mark_active(function() {
    list(self$d_loss_metric, self$g_loss_metric)
  }),

  train_step = function(real_images) {
    real_images <- self$rescale(real_images)

    # Sample random points in the latent space
    batch_size <- tf$shape(real_images)[1]
    random_latent_vectors <- tf$random$normal(
      shape = reticulate::tuple(batch_size, self$latent_dim)
    )

    # Decode them to fake images
    generated_images <- self$generator(random_latent_vectors)

    # Combine them with real images
    combined_images <- tf$concat(list(generated_images, real_images),
                                 axis = 0L)

    # Assemble labels discriminating real from fake images
    labels <- tf$concat(
      list(tf$ones(reticulate::tuple(batch_size, 1L)),
           tf$zeros(reticulate::tuple(batch_size, 1L))),
      axis = 0L
    )

    # Add random noise to the labels - important trick!
    labels <- labels + 0.05 * tf$random$uniform(tf$shape(labels))

    # Train the discriminator
    with(tf$GradientTape() %as% tape, {
      predictions <- self$discriminator(combined_images)
      d_loss <- self$loss_fn(labels, predictions)
    })
    grads <- tape$gradient(d_loss, self$discriminator$trainable_weights)
    self$d_optimizer$apply_gradients(
      zip_lists(grads, self$discriminator$trainable_weights)
    )

    # Sample random points in the latent space
    random_latent_vectors <- tf$random$normal(
      shape = reticulate::tuple(batch_size, self$latent_dim)
    )

    # Assemble labels that say "all real images"
    misleading_labels <- tf$zeros(reticulate::tuple(batch_size, 1L))

    # Train the generator (note that we should *not* update the weights
    # of the discriminator!)
    with(tf$GradientTape() %as% tape, {
      predictions <- self$discriminator(self$generator(random_latent_vectors))
      g_loss <- self$loss_fn(misleading_labels, predictions)
    })
    grads <- tape$gradient(g_loss, self$generator$trainable_weights)
    self$g_optimizer$apply_gradients(
      zip_lists(grads, self$generator$trainable_weights)
    )

    # Update metrics
    self$d_loss_metric$update_state(d_loss)
    self$g_loss_metric$update_state(g_loss)

    list("d_loss" = self$d_loss_metric$result(),
         "g_loss" = self$g_loss_metric$result())
  }
)
```
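The one helper above that may be unfamiliar is zip_lists() from the keras package: it pairs each gradient with its corresponding weight, which is the structure apply_gradients() expects. A minimal illustration:

```r
# zip_lists() interleaves parallel lists element-by-element
zip_lists(list("g1", "g2"), list("w1", "w2"))
# -> list(list("g1", "w1"), list("g2", "w2"))
```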
Create a callback that periodically saves generated images
```r
gan_monitor <- new_callback_class(
  "gan_monitor",
  initialize = function(num_img = 3, latent_dim = 128L) {
    self$num_img <- num_img
    self$latent_dim <- as.integer(latent_dim)
    if (!fs::dir_exists("dcgan")) fs::dir_create("dcgan")
  },
  on_epoch_end = function(epoch, logs) {
    random_latent_vectors <-
      tf$random$normal(shape = shape(self$num_img, self$latent_dim))
    generated_images <- self$model$generator(random_latent_vectors)
    generated_images <- tf$clip_by_value(generated_images * 255, 0, 255)
    generated_images <- as.array(generated_images)
    for (i in seq_len(self$num_img)) {
      image_array_save(
        generated_images[i, , , ],
        sprintf("dcgan/generated_img_%03d_%d.png", epoch, i),
        scale = FALSE
      )
    }
  }
)
```
Train the end-to-end model
```r
epochs <- 15 # In practice, use ~100 epochs

gan <- gan(discriminator = discriminator,
           generator = generator,
           latent_dim = latent_dim)

gan %>% compile(
  d_optimizer = optimizer_adam(learning_rate = 1e-4),
  g_optimizer = optimizer_adam(learning_rate = 1e-4),
  loss_fn = loss_binary_crossentropy()
)

gan %>% fit(
  dataset,
  epochs = epochs,
  callbacks = list(gan_monitor(num_img = 10, latent_dim = latent_dim))
)
```
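Once training has finished, new faces can be sampled from the generator alone. A minimal sketch (not part of the original example), reusing the gan model trained above:

```r
random_latent_vectors <- tf$random$normal(shape(1, latent_dim))
img <- as.array(gan$generator(random_latent_vectors))
plot(as.raster(img[1, , , ])) # generator output is already in [0, 1]
```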
Some of the last generated images, from around epoch 15; each row is an epoch (results keep improving after that):
```r
grid <- expand.grid(1:10, 0:14)
knitr::include_graphics(
  sprintf("dcgan/generated_img_%03d_%d.png", grid[[2]], grid[[1]])
)
```