Deep Dream
Generating Deep Dreams with Keras.
Introduction
“Deep dream” is an image-filtering technique which consists of taking an image classification model, and running gradient ascent over an input image to try to maximize the activations of specific layers (and sometimes, specific units in specific layers) for this input. It produces hallucination-like visuals.
It was first introduced by Alexander Mordvintsev from Google in July 2015.
Process:
- Load the original image.
- Define a number of processing scales (“octaves”), from smallest to largest.
- Resize the original image to the smallest scale.
- For every scale, starting with the smallest (i.e. the current one):
- Run gradient ascent
- Upscale image to the next scale
- Reinject the detail that was lost at upscaling time
- Stop when we are back to the original size. To obtain the detail lost during upscaling, we simply take the original image, shrink it down, upscale it, and compare the result to the (resized) original image, as sketched below.
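To make the detail-reinjection step concrete, here is a minimal, self-contained sketch; the tensor and its sizes are hypothetical stand-ins, not part of the example itself:
# Hypothetical stand-in for the original image: one 128 x 128 RGB picture.
library(tensorflow)
img <- tf$random$uniform(shape(1L, 128L, 128L, 3L))

# Going down one octave (scale ratio 1.4) and back up loses high-frequency
# detail; that difference is what gets reinjected after gradient ascent.
small_shape <- as.integer(c(128, 128) / 1.4)  # ~91 x 91
shrunk      <- tf$image$resize(img, small_shape)
upscaled    <- tf$image$resize(shrunk, c(128L, 128L))
lost_detail <- img - upscaled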
Setup
library(tensorflow)
library(keras)

base_image_path <- get_file("sky.jpg", "https://i.imgur.com/aGBdQyK.jpg")
result_prefix <- "sky_dream"

# These are the names of the layers
# for which we try to maximize activation,
# as well as their weight in the final loss
# we try to maximize.
# You can tweak these settings to obtain new visual effects.
layer_settings <- list(
  "mixed4" = 1.0,
  "mixed5" = 1.5,
  "mixed6" = 2.0,
  "mixed7" = 2.5
)

# Playing with these hyperparameters will also allow you to achieve new effects
step <- 0.01        # Gradient ascent step size
num_octave <- 3     # Number of scales at which to run gradient ascent
octave_scale <- 1.4 # Size ratio between scales
iterations <- 20    # Number of ascent steps per scale
max_loss <- 15.0
This is our base image:
plot_image <- function(img) {
  img %>%
    as.raster(max = 255) %>%
    plot()
}

base_image_path %>%
  image_load() %>%
  image_to_array() %>%
  plot_image()
Let’s set up some image preprocessing/deprocessing utilities:
preprocess_image <- function(image_path) {
  # Util function to open, resize and format pictures
  # into appropriate arrays.
  img <- image_path %>%
    image_load() %>%
    image_to_array()
  dim(img) <- c(1, dim(img))
  inception_v3_preprocess_input(img)
}

deprocess_image <- function(x) {
  dim(x) <- dim(x)[-1]
  # Undo inception v3 preprocessing
  x <- x / 2.0
  x <- x + 0.5
  x <- x * 255.0
  x[] <- raster::clamp(as.numeric(x), 0, 255)
  x
}
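As a quick optional sanity check (not part of the original pipeline), a preprocess/deprocess round trip should give back approximately the base image:
# Optional check: preprocessing then deprocessing approximately recovers
# the original picture.
base_image_path %>%
  preprocess_image() %>%
  deprocess_image() %>%
  plot_image()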
Compute the Deep Dream loss
First, build a feature extraction model to retrieve the activations of our target layers given an input image.
# Build an InceptionV3 model loaded with pre-trained ImageNet weights
model <- application_inception_v3(weights = "imagenet", include_top = FALSE)

# Get the symbolic outputs of each "key" layer (they have unique names).
outputs_dict <- purrr::imap(layer_settings, function(v, name) {
  layer <- get_layer(model, name)
  layer$output
})

# Set up a model that returns the activation values for every target layer
# (as a named list)
feature_extractor <- keras_model(inputs = model$inputs, outputs = outputs_dict)
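As an optional check (the names below are illustrative, not from the original example), calling the extractor on a preprocessed image returns a named list with one activation tensor per entry in layer_settings:
# Optional: inspect what the feature extractor returns for the base image.
test_img <- preprocess_image(base_image_path)
activations <- feature_extractor(test_img)
names(activations)       # "mixed4" "mixed5" "mixed6" "mixed7"
dim(activations$mixed4)  # batch x height x width x channels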
The actual loss computation is very simple:
compute_loss <- function(input_image) {
  features <- feature_extractor(input_image)
  # Initialize the loss
  loss <- tf$zeros(shape = shape())

  layer_settings %>%
    purrr::imap(function(coeff, name) {
      activation <- features[[name]]
      scaling <- tf$reduce_prod(tf$cast(tf$shape(activation), "float32"))
      # We avoid border artifacts by only involving non-border pixels in the loss.
      coeff * tf$reduce_sum(tf$square(activation[, 3:-2, 3:-2, ])) / scaling
    }) %>%
    purrr::reduce(tf$add, .init = loss)
}
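For instance (optional, with a hypothetical name for the result), evaluating the loss once on the preprocessed base image gives the starting value that gradient ascent will then push up:
# Optional: the loss on the untouched base image, before any ascent steps.
initial_loss <- compute_loss(preprocess_image(base_image_path))
cat("Initial loss:", as.numeric(initial_loss), "\n")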
Set up the gradient ascent loop for one octave
gradient_ascent_step <- tf_function(function(img, learning_rate) {
  with(tf$GradientTape() %as% tape, {
    tape$watch(img)
    loss <- compute_loss(img)
  })
  # Compute gradients.
  grads <- tape$gradient(loss, img)
  # Normalize gradients.
  grads <- grads / tf$maximum(tf$reduce_mean(tf$abs(grads)), 1e-6)
  img <- img + learning_rate * grads
  list(loss, img)
})
gradient_ascent_loop <- function(img, iterations, learning_rate, max_loss = NULL) {
  for (i in seq_len(iterations)) {
    c(loss, img) %<-% gradient_ascent_step(img, learning_rate)
    if (!is.null(max_loss) && as.logical(loss > max_loss)) {
      break
    }
    cat("... Loss value at step ", i, ": ", as.numeric(loss), "\n")
  }
  img
}
Run the training loop, iterating over different octaves
original_img <- preprocess_image(base_image_path)
original_shape <- dim(original_img)[2:3]

successive_shapes <- list(original_shape)
for (i in seq_len(num_octave - 1)) {
  shape <- as.integer(original_shape / octave_scale^i)
  successive_shapes[[i + 1]] <- shape
}
successive_shapes <- rev(successive_shapes)
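As a worked example (for a hypothetical 600 x 900 base image, with the defaults num_octave = 3 and octave_scale = 1.4), the octave schedule comes out smallest first:
# Worked example of the octave schedule for a hypothetical 600 x 900 image.
as.integer(c(600, 900) / 1.4^2) # 306 459 (smallest octave, processed first)
as.integer(c(600, 900) / 1.4)   # 428 642
c(600L, 900L)                   # original size, processed last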
shrunk_original_img <- tf$image$resize(original_img, successive_shapes[[1]])
img <- tf$identity(original_img) # Make a copy
for (i in seq_along(successive_shapes)) {
  shape <- successive_shapes[[i]]
  cat("Processing octave", i, "with shape:", shape, "\n")
  img <- tf$image$resize(img, shape)
  img <- gradient_ascent_loop(
    img, iterations = iterations, learning_rate = step, max_loss = max_loss
  )
  upscaled_shrunk_original_img <- tf$image$resize(shrunk_original_img, shape)
  same_size_original <- tf$image$resize(original_img, shape)
  lost_detail <- same_size_original - upscaled_shrunk_original_img

  img <- img + lost_detail
  shrunk_original_img <- tf$image$resize(original_img, shape)
}
Display the result.
img %>%
  as.array() %>%
  deprocess_image() %>%
  plot_image()
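Finally, the result_prefix defined in the setup can be used to write the dreamed image to disk; a minimal sketch using keras's image_array_save():
# Save the dreamed image to the working directory, reusing `result_prefix`.
img %>%
  as.array() %>%
  deprocess_image() %>%
  image_array_save(paste0(result_prefix, ".png"))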