The Sequential model
Setup
library(tensorflow)
library(keras)
When to use a Sequential model
A Sequential model is appropriate for a plain stack of layers where each layer has exactly one input tensor and one output tensor.
Schematically, the following Sequential model:
# Define Sequential model with 3 layers
model <- keras_model_sequential() %>%
  layer_dense(2, activation = "relu", name = "layer1") %>%
  layer_dense(3, activation = "relu", name = "layer2") %>%
  layer_dense(4, name = "layer3")
# Call model on a test input
x <- tf$ones(shape(3, 3))
y <- model(x)
is equivalent to this function:
# Create 3 layers
layer1 <- layer_dense(units = 2, activation = "relu", name = "layer1")
layer2 <- layer_dense(units = 3, activation = "relu", name = "layer2")
layer3 <- layer_dense(units = 4, name = "layer3")
# Call layers on a test input
x <- tf$ones(shape(3, 3))
y <- layer3(layer2(layer1(x)))
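As a quick sanity check of this equivalence, the sketch below wraps the same three layer objects in a Sequential model and compares its output with the nested calls. It assumes that keras_model_sequential() accepts a layers list and that as.array() converts an eager tensor to an R array; the names seq_out and manual_out are illustrative only.

```r
library(tensorflow)
library(keras)

# Create the three layers once so both code paths share the same weights
layer1 <- layer_dense(units = 2, activation = "relu", name = "layer1")
layer2 <- layer_dense(units = 3, activation = "relu", name = "layer2")
layer3 <- layer_dense(units = 4, name = "layer3")

# Wrap the same layer objects in a Sequential model (assumed `layers` argument)
model <- keras_model_sequential(layers = list(layer1, layer2, layer3))

x <- tf$ones(shape(3, 3))
seq_out <- as.array(model(x))                      # first call creates the weights
manual_out <- as.array(layer3(layer2(layer1(x))))  # reuses the weights created above

# Both paths apply the same weights in the same order
isTRUE(all.equal(seq_out, manual_out))
```

Because the layers are shared rather than copied, any difference between the two outputs would point to a difference in how the layers are chained, not in their weights.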
A Sequential model is not appropriate when:
- Your model has multiple inputs or multiple outputs
- Any of your layers has multiple inputs or multiple outputs
- You need to do layer sharing
- You want non-linear topology (e.g. a residual connection, a multi-branch model)
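For contrast, here is a minimal sketch of a model with a residual connection, which a Sequential model cannot express. It uses the functional API instead (layer_input(), layer_add(), keras_model()); the layer sizes and variable names are illustrative assumptions, not taken from this guide.

```r
library(tensorflow)
library(keras)

# A residual connection adds a block's input back to its output,
# so the graph is no longer a plain stack of layers.
inputs <- layer_input(shape = c(8))
hidden <- inputs %>% layer_dense(8, activation = "relu")
residual <- layer_add(list(inputs, hidden))  # one layer receiving two inputs
outputs <- residual %>% layer_dense(1)

model <- keras_model(inputs = inputs, outputs = outputs)

# Call the model on a test batch of 2 samples with 8 features each
y <- model(tf$ones(shape(2, 8)))
dim(y)
```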
Creating a Sequential model
You can create a Sequential model by piping a model through a series of layers.
model <- keras_model_sequential() %>%
  layer_dense(2, activation = "relu") %>%
  layer_dense(3, activation = "relu") %>%
  layer_dense(4)
Its layers are accessible via the layers attribute:
model$layers
[[1]]
<keras.src.layers.core.dense.Dense object at 0x7efc088a7280>
[[2]]
<keras.src.layers.core.dense.Dense object at 0x7efc088a6770>
[[3]]
<keras.src.layers.core.dense.Dense object at 0x7efc09c9ff40>
You can also create a Sequential model incrementally:
model <- keras_model_sequential()
model %>% layer_dense(2, activation = "relu")
model %>% layer_dense(3, activation = "relu")
model %>% layer_dense(4)
Note that there’s also a corresponding pop_layer() function to remove layers: a Sequential model behaves very much like a stack of layers.
model %>% pop_layer()
length(model$layers)  # 2
[1] 2
Also note that the Sequential constructor accepts a name argument, just like any layer or model in Keras. This is useful to annotate TensorBoard graphs with semantically meaningful names.
model <- keras_model_sequential(name = "my_sequential")
model %>% layer_dense(2, activation = "relu", name = "layer1")
model %>% layer_dense(3, activation = "relu", name = "layer2")
model %>% layer_dense(4, name = "layer3")
Specifying the input shape in advance
Generally, all layers in Keras need to know the shape of their inputs in order to be able to create their weights. So when you create a layer like this, initially, it has no weights:
layer <- layer_dense(units = 3)
layer$weights  # Empty
list()
It creates its weights the first time it is called on an input, since the shape of the weights depends on the shape of the inputs:
# Call layer on a test input
x <- tf$ones(shape(1, 4))
y <- layer(x)
layer$weights  # Now it has weights, of shape (4, 3) and (3,)
[[1]]
<tf.Variable 'dense_6/kernel:0' shape=(4, 3) dtype=float32, numpy=
array([[ 0.8861673 , 0.7246642 , 0.29001963],
[ 0.2592212 , 0.8790369 , 0.43319035],
[ 0.7657926 , -0.67766863, 0.00220335],
[ 0.53810084, -0.4877279 , 0.55877864]], dtype=float32)>
[[2]]
<tf.Variable 'dense_6/bias:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>
Naturally, this also applies to Sequential models. When you instantiate a Sequential model without an input shape, it isn’t “built”: it has no weights (and calling model$weights results in an error stating just this). The weights are created when the model first sees some input data:
model <- keras_model_sequential() %>%
  layer_dense(2, activation = "relu") %>%
  layer_dense(3, activation = "relu") %>%
  layer_dense(4)
# No weights at this stage!
# At this point, you can't do this:
try(model$weights)
Error in py_get_attr_impl(x, name, silent) :
ValueError: Weights for model 'sequential_3' have not yet been created. Weights are created when the model is first called on inputs or `build()` is called with an `input_shape`.
Run `reticulate::py_last_error()` for details.
# The model summary is also not available:
summary(model)
Model: <no summary available, model was not built>
# Call the model on a test input
x <- tf$ones(shape(1, 4))
y <- model(x)
cat("Number of weights after calling the model:", length(model$weights), "\n")  # 6
Number of weights after calling the model: 6
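The error message above also mentions build() as a second way to create the weights. The following is a minimal sketch of that route; it assumes the model's Python build() method can be called through the $ operator and that shape(NULL, 4) describes a batch of any size with 4 features.

```r
library(tensorflow)
library(keras)

model <- keras_model_sequential() %>%
  layer_dense(2, activation = "relu") %>%
  layer_dense(3, activation = "relu") %>%
  layer_dense(4)

# Build with an explicit input shape instead of calling the model on data:
# NULL leaves the batch size unspecified, 4 is the number of input features.
model$build(input_shape = shape(NULL, 4))

length(model$weights)  # one kernel and one bias per dense layer
```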
Once a model is “built”, you can call its summary() method to display its contents (the summary() method is also called by the default print() method):
summary(model)
Model: "sequential_3"
____________________________________________________________________________
Layer (type) Output Shape Param #
============================================================================
dense_9 (Dense) (1, 2) 10
dense_8 (Dense) (1, 3) 9
dense_7 (Dense) (1, 4) 16
============================================================================
Total params: 35 (140.00 Byte)
Trainable params: 35 (140.00 Byte)
Non-trainable params: 0 (0.00 Byte)
____________________________________________________________________________
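The parameter counts in this summary can be reproduced by hand: a dense layer holds one weight per input/output pair plus one bias per output unit. The sketch below recomputes them in base R; the helper name dense_params is hypothetical and exists only for this illustration.

```r
# Parameters of a dense layer: a kernel of n_in x n_out weights plus n_out biases
dense_params <- function(n_in, n_out) n_in * n_out + n_out

# Sizes from the summary above: 4 input features -> 2 -> 3 -> 4 units
sizes <- c(4, 2, 3, 4)
params <- dense_params(head(sizes, -1), tail(sizes, -1))

params       # 10 9 16
sum(params)  # 35
```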
However, it can be very useful when building a Sequential model incrementally to be able to display the summary of the model so far, including the current output shape. In this case, you should start your model by passing an input_shape argument to your model, so that it knows its input shape from the start:
model <- keras_model_sequential(input_shape = c(4))
model %>% layer_dense(2, activation = "relu")
model
Model: "sequential_4"
____________________________________________________________________________
Layer (type) Output Shape Param #
============================================================================
dense_10 (Dense) (None, 2) 10
============================================================================
Total params: 10 (40.00 Byte)
Trainable params: 10 (40.00 Byte)
Non-trainable params: 0 (0.00 Byte)
____________________________________________________________________________
Models built with a predefined input shape like this always have weights (even before seeing any data) and always have a defined output shape.
In general, it’s a recommended best practice to always specify the input shape of a Sequential model in advance if you know what it is.
A common debugging workflow: %>% + summary()
When building a new Sequential architecture, it’s useful to incrementally stack layers and print model summaries. For instance, this enables you to monitor how a stack of Conv2D and MaxPooling2D layers is downsampling image feature maps:
model <- keras_model_sequential(input_shape = c(250, 250, 3)) # 250x250 RGB images
model %>%
  layer_conv_2d(32, 5, strides = 2, activation = "relu") %>%
  layer_conv_2d(32, 3, activation = "relu") %>%
  layer_max_pooling_2d(3)
# Can you guess what the current output shape is at this point? Probably not.
# Let's just print it:
model
Model: "sequential_5"
____________________________________________________________________________
Layer (type) Output Shape Param #
============================================================================
conv2d_1 (Conv2D) (None, 123, 123, 32) 2432
conv2d (Conv2D) (None, 121, 121, 32) 9248
max_pooling2d (MaxPooling2D) (None, 40, 40, 32) 0
============================================================================
Total params: 11680 (45.62 KB)
Trainable params: 11680 (45.62 KB)
Non-trainable params: 0 (0.00 Byte)
____________________________________________________________________________
# The answer was: (40, 40, 32), so we can keep downsampling...
model %>%
  layer_conv_2d(32, 3, activation = "relu") %>%
  layer_conv_2d(32, 3, activation = "relu") %>%
  layer_max_pooling_2d(3) %>%
  layer_conv_2d(32, 3, activation = "relu") %>%
  layer_conv_2d(32, 3, activation = "relu") %>%
  layer_max_pooling_2d(2)
# And now?
model
Model: "sequential_5"
____________________________________________________________________________
Layer (type) Output Shape Param #
============================================================================
conv2d_1 (Conv2D) (None, 123, 123, 32) 2432
conv2d (Conv2D) (None, 121, 121, 32) 9248
max_pooling2d (MaxPooling2D) (None, 40, 40, 32) 0
conv2d_5 (Conv2D) (None, 38, 38, 32) 9248
conv2d_4 (Conv2D) (None, 36, 36, 32) 9248
max_pooling2d_2 (MaxPooling2D) (None, 12, 12, 32) 0
conv2d_3 (Conv2D) (None, 10, 10, 32) 9248
conv2d_2 (Conv2D) (None, 8, 8, 32) 9248
max_pooling2d_1 (MaxPooling2D) (None, 4, 4, 32) 0
============================================================================
Total params: 48672 (190.12 KB)
Trainable params: 48672 (190.12 KB)
Non-trainable params: 0 (0.00 Byte)
____________________________________________________________________________
# Now that we have 4x4 feature maps, time to apply global max pooling.
model %>% layer_global_max_pooling_2d()
# Finally, we add a classification layer.
model %>% layer_dense(10)
Very practical, right?
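If you prefer to predict the shapes before printing the summary, the spatial sizes above follow from simple arithmetic. The sketch below is base R only and assumes the layers use their defaults: "valid" padding (no padding) for the convolutions, and a pooling stride equal to the pool size. The helper names conv_out and pool_out are hypothetical.

```r
# Output size along one spatial dimension when no padding is applied
conv_out <- function(n, kernel, stride = 1) floor((n - kernel) / stride) + 1
# Pooling with a stride equal to the pool size
pool_out <- function(n, pool) floor((n - pool) / pool) + 1

n <- 250
n <- conv_out(n, kernel = 5, stride = 2)  # 123
n <- conv_out(n, kernel = 3)              # 121
n <- pool_out(n, pool = 3)                # 40
n <- conv_out(n, kernel = 3)              # 38
n <- conv_out(n, kernel = 3)              # 36
n <- pool_out(n, pool = 3)                # 12
n <- conv_out(n, kernel = 3)              # 10
n <- conv_out(n, kernel = 3)              # 8
n <- pool_out(n, pool = 2)                # 4
n
```

Each value matches the height and width reported in the corresponding row of the model summary.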
What to do once you have a model
Once your model architecture is ready, you will want to:
- Train your model, evaluate it, and run inference. See our guide to training & evaluation with the built-in loops.
- Save your model to disk and restore it. See our guide to serialization & saving.
- Speed up model training by leveraging multiple GPUs. See our guide to multi-GPU and distributed training.
Feature extraction with a Sequential model
Once a Sequential model has been built, it behaves like a Functional API model. This means that every layer has an input and output attribute. These attributes can be used to do neat things, like quickly creating a model that extracts the outputs of all intermediate layers in a Sequential model:
initial_model <-
  keras_model_sequential(input_shape = c(250, 250, 3)) %>%
  layer_conv_2d(32, 5, strides = 2, activation = "relu") %>%
  layer_conv_2d(32, 3, activation = "relu") %>%
  layer_conv_2d(32, 3, activation = "relu")
feature_extractor <- keras_model(
  inputs = initial_model$inputs,
  outputs = lapply(initial_model$layers, \(layer) layer$output)
)
# Call feature extractor on test input.
x <- tf$ones(shape(1, 250, 250, 3))
features <- feature_extractor(x)
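To make the extracted features easier to work with, the sketch below rebuilds the same extractor and labels each output with the name of the layer that produced it. It assumes that calling a multi-output model returns an R list of tensors, that each layer exposes a name attribute, and that dim() works on an eager tensor.

```r
library(tensorflow)
library(keras)

initial_model <-
  keras_model_sequential(input_shape = c(250, 250, 3)) %>%
  layer_conv_2d(32, 5, strides = 2, activation = "relu") %>%
  layer_conv_2d(32, 3, activation = "relu") %>%
  layer_conv_2d(32, 3, activation = "relu")

feature_extractor <- keras_model(
  inputs = initial_model$inputs,
  outputs = lapply(initial_model$layers, \(layer) layer$output)
)

features <- feature_extractor(tf$ones(shape(1, 250, 250, 3)))

# Label each feature map with the name of the layer that produced it
names(features) <- sapply(initial_model$layers, \(layer) layer$name)

length(features)       # one tensor per layer of initial_model
lapply(features, dim)  # the spatial size shrinks from layer to layer
```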
Here’s a similar example that only extracts features from one layer:
initial_model <-
  keras_model_sequential(input_shape = c(250, 250, 3)) %>%
  layer_conv_2d(32, 5, strides = 2, activation = "relu") %>%
  layer_conv_2d(32, 3, activation = "relu", name = "my_intermediate_layer") %>%
  layer_conv_2d(32, 3, activation = "relu")
feature_extractor <- keras_model(
  inputs = initial_model$inputs,
  outputs = get_layer(initial_model, name = "my_intermediate_layer")$output
)
# Call feature extractor on test input.
x <- tf$ones(shape(1, 250, 250, 3))
features <- feature_extractor(x)
Transfer learning with a Sequential model
Transfer learning consists of freezing the bottom layers in a model and only training the top layers. If you aren’t familiar with it, make sure to read our guide to transfer learning.
Here are two common transfer learning blueprints involving Sequential models.
First, let’s say that you have a Sequential model, and you want to freeze all layers except the last one. In this case, you would simply iterate over model$layers and set layer$trainable = FALSE on each layer, except the last one. Like this:
model <- keras_model_sequential(input_shape = c(784)) %>%
  layer_dense(32, activation = 'relu') %>%
  layer_dense(32, activation = 'relu') %>%
  layer_dense(32, activation = 'relu') %>%
  layer_dense(10)
# Presumably you would want to first load pre-trained weights.
model$load_weights(...)
# Freeze all layers except the last one.
for (layer in head(model$layers, -1))
  layer$trainable <- FALSE
# can also just call: freeze_weights(model, to = -2)
# Recompile and train (this will only update the weights of the last layer).
model %>% compile(...)
model %>% fit(...)
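Before compiling, it can be worth confirming that the freeze took effect. The sketch below repeats the loop on a freshly initialized model (no pre-trained weights are loaded here) and inspects each layer's trainable flag along with the model's trainable_weights.

```r
library(keras)

model <- keras_model_sequential(input_shape = c(784)) %>%
  layer_dense(32, activation = "relu") %>%
  layer_dense(32, activation = "relu") %>%
  layer_dense(32, activation = "relu") %>%
  layer_dense(10)

# Freeze all layers except the last one
for (layer in head(model$layers, -1))
  layer$trainable <- FALSE

# Inspect the result: only the last layer should still be trainable
trainable_flags <- sapply(model$layers, \(layer) layer$trainable)
trainable_flags

# Only the kernel and bias of the last layer remain trainable
length(model$trainable_weights)
```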
Another common blueprint is to use a Sequential model to stack a pre-trained model and some freshly initialized classification layers. Like this:
# Load a convolutional base with pre-trained weights
base_model <- application_xception(
  weights = 'imagenet',
  include_top = FALSE,
  pooling = 'avg')
# Freeze the base model
base_model$trainable <- FALSE
# Use a Sequential model to add a trainable classifier on top
model <- keras_model_sequential() %>%
  base_model() %>%
  layer_dense(1000)
# Compile & train
model %>% compile(...)
model %>% fit(...)
If you do transfer learning, you will probably find yourself frequently using these two patterns.
That’s about all you need to know about Sequential models!
To find out more about building models in Keras, see:
Environment Details
tensorflow::tf_config()
TensorFlow v2.13.0 (~/.virtualenvs/r-tensorflow-website/lib/python3.10/site-packages/tensorflow)
Python v3.10 (~/.virtualenvs/r-tensorflow-website/bin/python)
sessionInfo()
R version 4.3.1 (2023-06-16)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 22.04.3 LTS
Matrix products: default
BLAS: /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3
LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.20.so; LAPACK version 3.10.0
locale:
[1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
[3] LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8
[5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8
[7] LC_PAPER=en_US.UTF-8 LC_NAME=C
[9] LC_ADDRESS=C LC_TELEPHONE=C
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
time zone: America/New_York
tzcode source: system (glibc)
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] keras_2.13.0.9000 tensorflow_2.13.0.9000
loaded via a namespace (and not attached):
[1] vctrs_0.6.3 cli_3.6.1 knitr_1.43
[4] zeallot_0.1.0 rlang_1.1.1 xfun_0.40
[7] png_0.1-8 generics_0.1.3 jsonlite_1.8.7
[10] glue_1.6.2 htmltools_0.5.6 fansi_1.0.4
[13] rmarkdown_2.24 grid_4.3.1 tfruns_1.5.1
[16] evaluate_0.21 tibble_3.2.1 base64enc_0.1-3
[19] fastmap_1.1.1 yaml_2.3.7 lifecycle_1.0.3
[22] whisker_0.4.1 compiler_4.3.1 htmlwidgets_1.6.2
[25] Rcpp_1.0.11 pkgconfig_2.0.3 rstudioapi_0.15.0
[28] lattice_0.21-8 digest_0.6.33 R6_2.5.1
[31] reticulate_1.31.0.9000 utf8_1.2.3 pillar_1.9.0
[34] magrittr_2.0.3 Matrix_1.5-4.1 tools_4.3.1
system2(reticulate::py_exe(), c("-m pip freeze"), stdout = TRUE) |> writeLines()
absl-py==1.4.0
array-record==0.4.1
asttokens==2.2.1
astunparse==1.6.3
backcall==0.2.0
bleach==6.0.0
cachetools==5.3.1
certifi==2023.7.22
charset-normalizer==3.2.0
click==8.1.7
decorator==5.1.1
dm-tree==0.1.8
etils==1.4.1
executing==1.2.0
flatbuffers==23.5.26
gast==0.4.0
google-auth==2.22.0
google-auth-oauthlib==1.0.0
google-pasta==0.2.0
googleapis-common-protos==1.60.0
grpcio==1.57.0
h5py==3.9.0
idna==3.4
importlib-resources==6.0.1
ipython==8.14.0
jedi==0.19.0
kaggle==1.5.16
keras==2.13.1
keras-tuner==1.3.5
kt-legacy==1.0.5
libclang==16.0.6
Markdown==3.4.4
MarkupSafe==2.1.3
matplotlib-inline==0.1.6
numpy==1.24.3
nvidia-cublas-cu11==11.11.3.6
nvidia-cudnn-cu11==8.6.0.163
oauthlib==3.2.2
opt-einsum==3.3.0
packaging==23.1
pandas==2.0.3
parso==0.8.3
pexpect==4.8.0
pickleshare==0.7.5
Pillow==10.0.0
promise==2.3
prompt-toolkit==3.0.39
protobuf==3.20.3
psutil==5.9.5
ptyprocess==0.7.0
pure-eval==0.2.2
pyasn1==0.5.0
pyasn1-modules==0.3.0
pydot==1.4.2
Pygments==2.16.1
pyparsing==3.1.1
python-dateutil==2.8.2
python-slugify==8.0.1
pytz==2023.3
requests==2.31.0
requests-oauthlib==1.3.1
rsa==4.9
scipy==1.11.2
six==1.16.0
stack-data==0.6.2
tensorboard==2.13.0
tensorboard-data-server==0.7.1
tensorflow==2.13.0
tensorflow-datasets==4.9.2
tensorflow-estimator==2.13.0
tensorflow-hub==0.14.0
tensorflow-io-gcs-filesystem==0.33.0
tensorflow-metadata==1.14.0
termcolor==2.3.0
text-unidecode==1.3
toml==0.10.2
tqdm==4.66.1
traitlets==5.9.0
typing_extensions==4.5.0
tzdata==2023.3
urllib3==1.26.16
wcwidth==0.2.6
webencodings==0.5.1
Werkzeug==2.3.7
wrapt==1.15.0
zipp==3.16.2
TF Devices:
- PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')
- PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')
CPU cores: 12
Date rendered: 2023-08-28
Page render time: 4 seconds