library(tensorflow)
library(keras)
Working with preprocessing layers
Keras preprocessing
The Keras preprocessing layers API allows developers to build Keras-native input processing pipelines. These input processing pipelines can be used as independent preprocessing code in non-Keras workflows, combined directly with Keras models, and exported as part of a Keras SavedModel.
With Keras preprocessing layers, you can build and export models that are truly end-to-end: models that accept raw images or raw structured data as input; models that handle feature normalization or feature value indexing on their own.
Available preprocessing layers
Text preprocessing
layer_text_vectorization(): turns raw strings into an encoded representation that can be read by a layer_embedding() or layer_dense() layer.
Numerical features preprocessing
layer_normalization(): performs feature-wise normalization of input features.
layer_discretization(): turns continuous numerical features into integer categorical features.
Categorical features preprocessing
layer_category_encoding(): turns integer categorical features into one-hot, multi-hot, or count-based dense representations.
layer_hashing(): performs categorical feature hashing, also known as the “hashing trick”.
layer_string_lookup(): turns string categorical values into an encoded representation that can be read by an Embedding layer or a Dense layer.
layer_integer_lookup(): turns integer categorical values into an encoded representation that can be read by an Embedding layer or a Dense layer.
Image preprocessing
These layers are for standardizing the inputs of an image model.
layer_resizing(): resizes a batch of images to a target size.
layer_rescaling(): rescales and offsets the values of a batch of images (e.g., going from inputs in the [0, 255] range to inputs in the [0, 1] range).
layer_center_crop(): returns a center crop of a batch of images.
Image data augmentation
These layers apply random augmentation transforms to a batch of images. They are only active during training.
layer_random_crop()
layer_random_flip()
layer_random_translation()
layer_random_rotation()
layer_random_zoom()
layer_random_height()
layer_random_width()
layer_random_contrast()
The adapt() function
Some preprocessing layers have an internal state that can be computed based on a sample of the training data. The list of stateful preprocessing layers is:
layer_text_vectorization(): holds a mapping between string tokens and integer indices.
layer_string_lookup() and layer_integer_lookup(): hold a mapping between input values and integer indices.
layer_normalization(): holds the mean and standard deviation of the features.
layer_discretization(): holds information about value bucket boundaries.
Crucially, these layers are non-trainable. Their state is not set during training; it must be set before training, either by initializing them from a precomputed constant, or by “adapting” them on data.
You set the state of a preprocessing layer by exposing it to training data, via adapt():
data <- rbind(c(0.1, 0.2, 0.3),
              c(0.8, 0.9, 1.0),
              c(1.5, 1.6, 1.7))
layer <- layer_normalization()
adapt(layer, data)
normalized_data <- as.array(layer(data))

sprintf("Features mean: %.2f", mean(normalized_data))
[1] "Features mean: -0.00"
sprintf("Features std: %.2f", sd(normalized_data))
[1] "Features std: 1.06"
adapt() takes either an array or a tf_dataset. In the case of layer_string_lookup() and layer_text_vectorization(), you can also pass a character vector:
data <- c(
  "Congratulations!",
  "Today is your day.",
  "You're off to Great Places!",
  "You're off and away!",
  "You have brains in your head.",
  "You have feet in your shoes.",
  "You can steer yourself",
  "any direction you choose.",
  "You're on your own. And you know what you know.",
  "And YOU are the one who'll decide where to go."
)
layer <- layer_text_vectorization()
layer %>% adapt(data)
vectorized_text <- layer(data)
print(vectorized_text)
tf.Tensor(
[[31 0 0 0 0 0 0 0 0 0]
[15 23 3 30 0 0 0 0 0 0]
[ 4 7 6 25 19 0 0 0 0 0]
[ 4 7 5 35 0 0 0 0 0 0]
[ 2 10 34 9 3 24 0 0 0 0]
[ 2 10 27 9 3 18 0 0 0 0]
[ 2 33 17 11 0 0 0 0 0 0]
[37 28 2 32 0 0 0 0 0 0]
[ 4 22 3 20 5 2 8 14 2 8]
[ 5 2 36 16 21 12 29 13 6 26]], shape=(10, 10), dtype=int64)
In addition, adaptable layers always expose an option to directly set state via constructor arguments or weight assignment. If the intended state values are known at layer construction time, or are calculated outside of the adapt() call, they can be set without relying on the layer’s internal computation. For instance, if external vocabulary files for the layer_text_vectorization(), layer_string_lookup(), or layer_integer_lookup() layers already exist, those can be loaded directly into the lookup tables by passing a path to the vocabulary file in the layer’s constructor arguments.
Here’s an example where we instantiate a layer_string_lookup() layer with a precomputed vocabulary:
<- c("a", "b", "c", "d")
vocab <- as_tensor(rbind(c("a", "c", "d"),
data c("d", "z", "b")))
<- layer_string_lookup(vocabulary=vocab)
layer <- layer(data)
vectorized_data print(vectorized_data)
tf.Tensor(
[[1 3 4]
[4 0 2]], shape=(2, 3), dtype=int64)
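As noted above, adaptable layers can also have their state set directly through constructor arguments. For example, here is a minimal sketch (not part of the original example) that initializes layer_normalization() from precomputed statistics; it assumes your keras version exposes the mean and variance constructor arguments:

# Per-feature statistics computed elsewhere (values here are illustrative)
precomputed_mean <- c(0.8, 0.9, 1.0)
precomputed_variance <- c(0.33, 0.33, 0.33)
normalizer <- layer_normalization(mean = precomputed_mean,
                                  variance = precomputed_variance)
# No adapt() call is needed; the layer is ready to use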
Preprocessing data before the model or inside the model
There are two ways you could be using preprocessing layers:
Option 1: Make them part of the model, like this:
input <- layer_input(shape = input_shape)
output <- input %>%
  preprocessing_layer() %>%
  rest_of_the_model()
model <- keras_model(input, output)
With this option, preprocessing will happen on device, synchronously with the rest of the model execution, meaning that it will benefit from GPU acceleration. If you’re training on GPU, this is the best option for the layer_normalization() layer, and for all image preprocessing and data augmentation layers.
Option 2: Apply it to your tf_dataset, so as to obtain a dataset that yields batches of preprocessed data, like this:
library(tfdatasets)
dataset <- ... # define dataset
dataset <- dataset %>%
  dataset_map(function(x, y) list(preprocessing_layer(x), y))
With this option, your preprocessing will happen on CPU, asynchronously, and will be buffered before going into the model. In addition, if you call tfdatasets::dataset_prefetch() on your dataset, the preprocessing will happen efficiently in parallel with training:
dataset <- dataset %>%
  dataset_map(function(x, y) list(preprocessing_layer(x), y)) %>%
  dataset_prefetch()
model %>% fit(dataset)
This is the best option for layer_text_vectorization(), and all structured data preprocessing layers. It can also be a good option if you’re training on CPU and you use image preprocessing layers.
Benefits of doing preprocessing inside the model at inference time
Even if you go with option 2, you may later want to export an inference-only end-to-end model that will include the preprocessing layers. The key benefit to doing this is that it makes your model portable and it helps reduce the training/serving skew.
When all data preprocessing is part of the model, other people can load and use your model without having to be aware of how each feature is expected to be encoded & normalized. Your inference model will be able to process raw images or raw structured data, and will not require users of the model to be aware of the details of e.g. the tokenization scheme used for text, the indexing scheme used for categorical features, whether image pixel values are normalized to [-1, +1] or to [0, 1], etc. This is especially powerful if you’re exporting your model to another runtime, such as TensorFlow.js: you won’t have to reimplement your preprocessing pipeline in JavaScript.
If you initially put your preprocessing layers in your tf_dataset pipeline, you can export an inference model that packages the preprocessing. Simply instantiate a new model that chains your preprocessing layers and your training model:
input <- layer_input(shape = input_shape)
output <- input %>%
  preprocessing_layer() %>%
  training_model()
inference_model <- keras_model(input, output)
Preprocessing during multi-worker training
Preprocessing layers are compatible with the tf.distribute API for running training across multiple machines.
In general, preprocessing layers should be placed inside a strategy$scope() and called either inside or before the model, as discussed above.
with(strategy$scope(), {
  inputs <- layer_input(shape = input_shape)
  preprocessing_layer <- layer_hashing(num_bins = 10)
  dense_layer <- layer_dense(units = 16)
})
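The layers created inside the scope can then be chained into a model, still under the same scope. The following is a minimal sketch (not part of the original example); it assumes strategy was created beforehand, for instance with tf$distribute$MirroredStrategy(), and adds a layer_category_encoding() step so that the hashed integer indices can feed a dense layer:

with(strategy$scope(), {
  inputs <- layer_input(shape = input_shape)
  outputs <- inputs %>%
    layer_hashing(num_bins = 10) %>%                                     # preprocessing
    layer_category_encoding(num_tokens = 10, output_mode = "multi_hot") %>%
    layer_dense(units = 16)
  model <- keras_model(inputs, outputs)
})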
For more details, refer to the preprocessing section of the distributed input guide.
Quick recipes
Image data augmentation
Note that image data augmentation layers are only active during training (similar to the layer_dropout() layer).
library(keras)
library(tfdatasets)
# Create a data augmentation stage with horizontal flipping, rotations, zooms
data_augmentation <-
  keras_model_sequential() %>%
  layer_random_flip("horizontal") %>%
  layer_random_rotation(0.1) %>%
  layer_random_zoom(0.1)

# Load some data
c(c(x_train, y_train), ...) %<-% dataset_cifar10()
input_shape <- dim(x_train)[-1] # drop batch dim
classes <- 10

# Create a tf_dataset pipeline of augmented images (and their labels)
train_dataset <- tensor_slices_dataset(list(x_train, y_train)) %>%
  dataset_batch(16) %>%
  dataset_map( ~ list(data_augmentation(.x), .y)) # see ?purrr::map to learn about ~ notation

# Create a model and train it on the augmented image data
resnet <- application_resnet50(weights = NULL,
                               input_shape = input_shape,
                               classes = classes)
input <- layer_input(shape = input_shape)
output <- input %>%
  layer_rescaling(1 / 255) %>% # Rescale inputs
  resnet()
model <- keras_model(input, output) %>%
  compile(optimizer = "rmsprop", loss = "sparse_categorical_crossentropy") %>%
  fit(train_dataset, steps_per_epoch = 5)
Epoch 1/10
5/5 - 17s - loss: 8.5061 - 17s/epoch - 3s/step
Epoch 2/10
5/5 - 0s - loss: 7.1424 - 140ms/epoch - 28ms/step
Epoch 3/10
5/5 - 0s - loss: 3.9739 - 123ms/epoch - 25ms/step
Epoch 4/10
5/5 - 0s - loss: 3.4781 - 112ms/epoch - 22ms/step
Epoch 5/10
5/5 - 0s - loss: 3.6416 - 101ms/epoch - 20ms/step
Epoch 6/10
5/5 - 0s - loss: 3.1361 - 97ms/epoch - 19ms/step
Epoch 7/10
5/5 - 0s - loss: 2.9110 - 97ms/epoch - 19ms/step
Epoch 8/10
5/5 - 0s - loss: 3.4493 - 102ms/epoch - 20ms/step
Epoch 9/10
5/5 - 0s - loss: 3.4369 - 130ms/epoch - 26ms/step
Epoch 10/10
5/5 - 0s - loss: 3.2755 - 124ms/epoch - 25ms/step
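Because the augmentation layers are inactive at inference time, calling the data_augmentation stage outside of fit() passes images through unchanged (apart from dtype casting). As a minimal sketch (not part of the original recipe), and assuming the sequential stage can be called directly on a batch and accepts a training argument, you can force the random transforms like this:

batch <- x_train[1:4, , , , drop = FALSE] / 255
out_inference <- data_augmentation(batch)                  # no augmentation applied
out_training  <- data_augmentation(batch, training = TRUE) # random flip/rotation/zoom applied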
You can see a similar setup in action in the example image classification from scratch.
Normalizing numerical features
library(tensorflow)
library(keras)
c(c(x_train, y_train), ...) %<-% dataset_cifar10()
x_train <- x_train %>%
  array_reshape(c(dim(x_train)[1], -1L)) # flatten each case
input_shape <- dim(x_train)[-1] # keras layers automatically add the batch dim
classes <- 10

# Create a layer_normalization() layer and set its internal state using the training data
normalizer <- layer_normalization()
normalizer %>% adapt(x_train)

# Create a model that includes the normalization layer
input <- layer_input(shape = input_shape)
output <- input %>%
  normalizer() %>%
  layer_dense(classes, activation = "softmax")
model <- keras_model(input, output) %>%
  compile(optimizer = "adam",
          loss = "sparse_categorical_crossentropy")

# Train the model
model %>%
  fit(x_train, y_train)
Epoch 1/10
1563/1563 - 2s - loss: 2.1273 - 2s/epoch - 1ms/step
Epoch 2/10
1563/1563 - 2s - loss: 2.0445 - 2s/epoch - 1ms/step
Epoch 3/10
1563/1563 - 2s - loss: 2.0193 - 2s/epoch - 1ms/step
Epoch 4/10
1563/1563 - 2s - loss: 2.0165 - 2s/epoch - 1ms/step
Epoch 5/10
1563/1563 - 2s - loss: 2.0118 - 2s/epoch - 1ms/step
Epoch 6/10
1563/1563 - 2s - loss: 1.9881 - 2s/epoch - 1ms/step
Epoch 7/10
1563/1563 - 2s - loss: 1.9932 - 2s/epoch - 1ms/step
Epoch 8/10
1563/1563 - 2s - loss: 1.9925 - 2s/epoch - 1ms/step
Epoch 9/10
1563/1563 - 2s - loss: 1.9709 - 2s/epoch - 1ms/step
Epoch 10/10
1563/1563 - 2s - loss: 1.9737 - 2s/epoch - 1ms/step
Encoding string categorical features via one-hot encoding
# Define some toy data
<- as_tensor(c("a", "b", "c", "b", "c", "a")) %>%
data k_reshape(c(-1, 1)) # reshape into matrix with shape: (6, 1)
# Use layer_string_lookup() to build an index of
# the feature values and encode output.
<- layer_string_lookup(output_mode="one_hot")
lookup %>% adapt(data)
lookup
# Convert new test data (which includes unknown feature values)
= as_tensor(matrix(c("a", "b", "c", "d", "e", "")))
test_data = lookup(test_data)
encoded_data print(encoded_data)
tf.Tensor(
[[0. 0. 0. 1.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]
[1. 0. 0. 0.]], shape=(6, 4), dtype=float32)
Note that, here, index 0 is reserved for out-of-vocabulary values (values that were not seen during adapt()).
You can see the layer_string_lookup() in action in the Structured data classification from scratch example.
Encoding integer categorical features via one-hot encoding
# Define some toy data
data <- as_tensor(matrix(c(10, 20, 20, 10, 30, 0)), "int32")

# Use layer_integer_lookup() to build an
# index of the feature values and encode output.
lookup <- layer_integer_lookup(output_mode = "one_hot")
lookup %>% adapt(data)

# Convert new test data (which includes unknown feature values)
test_data <- as_tensor(matrix(c(10, 10, 20, 50, 60, 0)), "int32")
encoded_data <- lookup(test_data)
print(encoded_data)
tf.Tensor(
[[0. 0. 1. 0. 0.]
[0. 0. 1. 0. 0.]
[0. 1. 0. 0. 0.]
[1. 0. 0. 0. 0.]
[1. 0. 0. 0. 0.]
[0. 0. 0. 0. 1.]], shape=(6, 5), dtype=float32)
Note that, here, index 0 is reserved for out-of-vocabulary values (values that were not seen during adapt()), as the output above shows. You can configure how missing and out-of-vocabulary values are handled using the mask_token and oov_token constructor arguments of layer_integer_lookup().
You can see the layer_integer_lookup() in action in the example structured data classification from scratch.
Applying the hashing trick to an integer categorical feature
If you have a categorical feature that can take many different values (on the order of 10e3 or higher), where each value only appears a few times in the data, it becomes impractical and ineffective to index and one-hot encode the feature values. Instead, it can be a good idea to apply the “hashing trick”: hash the values to a vector of fixed size. This keeps the size of the feature space manageable, and removes the need for explicit indexing.
# Sample data: 10,000 random integers with values between 0 and 100,000
data <- k_random_uniform(shape = c(10000, 1), dtype = "int64")

# Use the Hashing layer to hash the values into the range [0, 64)
hasher <- layer_hashing(num_bins = 64, salt = 1337)

# Use the CategoryEncoding layer to multi-hot encode the hashed values
encoder <- layer_category_encoding(num_tokens = 64, output_mode = "multi_hot")
encoded_data <- encoder(hasher(data))
print(encoded_data$shape)
TensorShape([10000, 64])
Encoding text as a sequence of token indices
This is how you should preprocess text to be passed to an Embedding layer.
library(tensorflow)
library(tfdatasets)
library(keras)
# Define some text data to adapt the layer
adapt_data <- as_tensor(c(
  "The Brain is wider than the Sky",
  "For put them side by side",
  "The one the other will contain",
  "With ease and You beside"
))

# Create a layer_text_vectorization() layer
text_vectorizer <- layer_text_vectorization(output_mode = "int")
# Index the vocabulary via `adapt()`
text_vectorizer %>% adapt(adapt_data)

# Try out the layer
cat("Encoded text:\n",
    as.array(text_vectorizer("The Brain is deeper than the sea")))
Encoded text:
2 19 14 1 9 2 1
# Create a simple model
input <- layer_input(shape(NULL), dtype = "int64")

output <- input %>%
  layer_embedding(input_dim = text_vectorizer$vocabulary_size(),
                  output_dim = 16) %>%
  layer_gru(8) %>%
  layer_dense(1)
model <- keras_model(input, output)

# Create a labeled dataset (which includes unknown tokens)
train_dataset <- tensor_slices_dataset(list(
  c("The Brain is deeper than the sea", "for if they are held Blue to Blue"),
  c(1L, 0L)
))

# Preprocess the string inputs, turning them into int sequences
train_dataset <- train_dataset %>%
  dataset_batch(2) %>%
  dataset_map(~list(text_vectorizer(.x), .y))
# Train the model on the int sequences
cat("Training model...\n")
Training model...
model %>%
  compile(optimizer = "rmsprop", loss = "mse") %>%
  fit(train_dataset)
Epoch 1/10
1/1 - 1s - loss: 0.4794 - 1s/epoch - 1s/step
Epoch 2/10
1/1 - 0s - loss: 0.4483 - 5ms/epoch - 5ms/step
Epoch 3/10
1/1 - 0s - loss: 0.4271 - 5ms/epoch - 5ms/step
Epoch 4/10
1/1 - 0s - loss: 0.4101 - 4ms/epoch - 4ms/step
Epoch 5/10
1/1 - 0s - loss: 0.3955 - 5ms/epoch - 5ms/step
Epoch 6/10
1/1 - 0s - loss: 0.3824 - 4ms/epoch - 4ms/step
Epoch 7/10
1/1 - 0s - loss: 0.3705 - 4ms/epoch - 4ms/step
Epoch 8/10
1/1 - 0s - loss: 0.3596 - 5ms/epoch - 5ms/step
Epoch 9/10
1/1 - 0s - loss: 0.3493 - 4ms/epoch - 4ms/step
Epoch 10/10
1/1 - 0s - loss: 0.3395 - 4ms/epoch - 4ms/step
# For inference, you can export a model that accepts strings as input
input <- layer_input(shape = 1, dtype = "string")
output <- input %>%
  text_vectorizer() %>%
  model()
end_to_end_model <- keras_model(input, output)
# Call the end-to-end model on test data (which includes unknown tokens)
cat("Calling end-to-end model on test string...\n")
Calling end-to-end model on test string...
test_data <- tf$constant(matrix("The one the other will absorb"))
test_output <- end_to_end_model(test_data)
cat("Model output:", as.array(test_output), "\n")
Model output: 0.1588376
You can see the layer_text_vectorization() layer in action, combined with an Embedding layer, in the example text classification from scratch.
Note that when training such a model, for best performance, you should always use the layer_text_vectorization() layer as part of the input pipeline.
Encoding text as a dense matrix of ngrams with multi-hot encoding
This is how you can preprocess text to be passed to a Dense layer.
# Define some text data to adapt the layer
adapt_data <- as_tensor(c(
  "The Brain is wider than the Sky",
  "For put them side by side",
  "The one the other will contain",
  "With ease and You beside"
))

# Instantiate layer_text_vectorization() with "multi_hot" output_mode
# and ngrams=2 (index all bigrams)
text_vectorizer <- layer_text_vectorization(output_mode = "multi_hot", ngrams = 2)
# Index the bigrams via `adapt()`
text_vectorizer %>% adapt(adapt_data)

# Try out the layer
cat("Encoded text:\n",
    as.array(text_vectorizer("The Brain is deeper than the sea")))
Encoded text:
1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 1 0 0 0
# Create a simple model
input <- layer_input(shape = text_vectorizer$vocabulary_size(), dtype = "int64")

output <- input %>%
  layer_dense(1)
model <- keras_model(input, output)

# Create a labeled dataset (which includes unknown tokens)
train_dataset <- tensor_slices_dataset(list(
  c("The Brain is deeper than the sea", "for if they are held Blue to Blue"),
  c(1L, 0L)
))

# Preprocess the string inputs, turning them into multi-hot encoded vectors
train_dataset <- train_dataset %>%
  dataset_batch(2) %>%
  dataset_map(~list(text_vectorizer(.x), .y))
# Train the model on the multi-hot encoded data
cat("Training model...\n")
Training model...
model %>%
  compile(optimizer = "rmsprop", loss = "mse") %>%
  fit(train_dataset)
Epoch 1/10
1/1 - 0s - loss: 0.5533 - 268ms/epoch - 268ms/step
Epoch 2/10
1/1 - 0s - loss: 0.5238 - 5ms/epoch - 5ms/step
Epoch 3/10
1/1 - 0s - loss: 0.5032 - 4ms/epoch - 4ms/step
Epoch 4/10
1/1 - 0s - loss: 0.4864 - 4ms/epoch - 4ms/step
Epoch 5/10
1/1 - 0s - loss: 0.4718 - 4ms/epoch - 4ms/step
Epoch 6/10
1/1 - 0s - loss: 0.4587 - 3ms/epoch - 3ms/step
Epoch 7/10
1/1 - 0s - loss: 0.4466 - 4ms/epoch - 4ms/step
Epoch 8/10
1/1 - 0s - loss: 0.4354 - 3ms/epoch - 3ms/step
Epoch 9/10
1/1 - 0s - loss: 0.4249 - 3ms/epoch - 3ms/step
Epoch 10/10
1/1 - 0s - loss: 0.4149 - 3ms/epoch - 3ms/step
# For inference, you can export a model that accepts strings as input
input <- layer_input(shape = 1, dtype = "string")

output <- input %>%
  text_vectorizer() %>%
  model()
end_to_end_model <- keras_model(input, output)
# Call the end-to-end model on test data (which includes unknown tokens)
cat("Calling end-to-end model on test string...\n")
Calling end-to-end model on test string...
test_data <- tf$constant(matrix("The one the other will absorb"))
test_output <- end_to_end_model(test_data)
cat("Model output: "); print(test_output); cat("\n")
Model output:
tf.Tensor([[0.67599183]], shape=(1, 1), dtype=float32)
Encoding text as a dense matrix of ngrams with TF-IDF weighting
This is an alternative way of preprocessing text before passing it to a layer_dense() layer.
# Define some text data to adapt the layer
adapt_data <- as_tensor(c(
  "The Brain is wider than the Sky",
  "For put them side by side",
  "The one the other will contain",
  "With ease and You beside"
))

# Instantiate layer_text_vectorization() with "tf-idf" output_mode
# (multi-hot with TF-IDF weighting) and ngrams=2 (index all bigrams)
text_vectorizer <- layer_text_vectorization(output_mode = "tf-idf", ngrams = 2)
# Index the bigrams and learn the TF-IDF weights via `adapt()`
with(tf$device("CPU"), {
  # A bug currently prevents this from running on GPU.
  text_vectorizer %>% adapt(adapt_data)
})

# Try out the layer
cat("Encoded text:\n",
    as.array(text_vectorizer("The Brain is deeper than the sea")))
Encoded text:
5.461647 1.694596 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1.098612 1.098612 1.098612 0 0 0 0 0 0 0 0 0 1.098612 0 0 0 0 0 0 0 1.098612 1.098612 0 0 0
# Create a simple model
input <- layer_input(shape = text_vectorizer$vocabulary_size(), dtype = "int64")
output <- input %>% layer_dense(1)
model <- keras_model(input, output)

# Create a labeled dataset (which includes unknown tokens)
train_dataset <- tensor_slices_dataset(list(
  c("The Brain is deeper than the sea", "for if they are held Blue to Blue"),
  c(1L, 0L)
))

# Preprocess the string inputs, turning them into TF-IDF encoded vectors
train_dataset <- train_dataset %>%
  dataset_batch(2) %>%
  dataset_map(~list(text_vectorizer(.x), .y))
# Train the model on the TF-IDF encoded data
cat("Training model...")
Training model...
model %>%
  compile(optimizer = "rmsprop", loss = "mse") %>%
  fit(train_dataset)
Epoch 1/10
1/1 - 0s - loss: 0.7798 - 251ms/epoch - 251ms/step
Epoch 2/10
1/1 - 0s - loss: 0.7140 - 5ms/epoch - 5ms/step
Epoch 3/10
1/1 - 0s - loss: 0.6691 - 3ms/epoch - 3ms/step
Epoch 4/10
1/1 - 0s - loss: 0.6330 - 3ms/epoch - 3ms/step
Epoch 5/10
1/1 - 0s - loss: 0.6022 - 3ms/epoch - 3ms/step
Epoch 6/10
1/1 - 0s - loss: 0.5748 - 3ms/epoch - 3ms/step
Epoch 7/10
1/1 - 0s - loss: 0.5500 - 3ms/epoch - 3ms/step
Epoch 8/10
1/1 - 0s - loss: 0.5272 - 4ms/epoch - 4ms/step
Epoch 9/10
1/1 - 0s - loss: 0.5059 - 4ms/epoch - 4ms/step
Epoch 10/10
1/1 - 0s - loss: 0.4859 - 4ms/epoch - 4ms/step
# For inference, you can export a model that accepts strings as input
input <- layer_input(shape = 1, dtype = "string")

output <- input %>%
  text_vectorizer() %>%
  model()
end_to_end_model <- keras_model(input, output)
# Call the end-to-end model on test data (which includes unknown tokens)
cat("Calling end-to-end model on test string...\n")
Calling end-to-end model on test string...
test_data <- tf$constant(matrix("The one the other will absorb"))
test_output <- end_to_end_model(test_data)
cat("Model output: "); print(test_output)
Model output:
tf.Tensor([[0.17933191]], shape=(1, 1), dtype=float32)
Important gotchas
Working with lookup layers with very large vocabularies
You may find yourself working with a very large vocabulary in a layer_text_vectorization(), a layer_string_lookup() layer, or a layer_integer_lookup() layer. Typically, a vocabulary larger than 500 MB would be considered “very large”.
In such cases, for best performance, you should avoid using adapt(). Instead, pre-compute your vocabulary (you could use Apache Beam or TF Transform for this) and store it in a file. Then load the vocabulary into the layer at construction time by passing the file path as the vocabulary argument.
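For instance, a lookup layer can be pointed at such a file directly. The following is a minimal sketch (the file name is hypothetical), assuming the file stores one vocabulary token per line:

# Assumes vocabulary.txt is a plain-text file with one token per line
lookup <- layer_string_lookup(vocabulary = "vocabulary.txt")
# No adapt() call is needed; the vocabulary is loaded from the file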
Environment Details
tensorflow::tf_config()
TensorFlow v2.13.0 (~/.virtualenvs/r-tensorflow-website/lib/python3.10/site-packages/tensorflow)
Python v3.10 (~/.virtualenvs/r-tensorflow-website/bin/python)
sessionInfo()
R version 4.3.1 (2023-06-16)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 22.04.3 LTS
Matrix products: default
BLAS: /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3
LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.20.so; LAPACK version 3.10.0
locale:
[1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
[3] LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8
[5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8
[7] LC_PAPER=en_US.UTF-8 LC_NAME=C
[9] LC_ADDRESS=C LC_TELEPHONE=C
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
time zone: America/New_York
tzcode source: system (glibc)
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] tfdatasets_2.9.0.9000 keras_2.13.0.9000 tensorflow_2.13.0.9000
loaded via a namespace (and not attached):
[1] vctrs_0.6.3 cli_3.6.1 knitr_1.43
[4] zeallot_0.1.0 rlang_1.1.1 xfun_0.40
[7] png_0.1-8 generics_0.1.3 jsonlite_1.8.7
[10] glue_1.6.2 htmltools_0.5.6 fansi_1.0.4
[13] rmarkdown_2.24 grid_4.3.1 tfruns_1.5.1
[16] evaluate_0.21 tibble_3.2.1 base64enc_0.1-3
[19] fastmap_1.1.1 yaml_2.3.7 lifecycle_1.0.3
[22] whisker_0.4.1 compiler_4.3.1 htmlwidgets_1.6.2
[25] Rcpp_1.0.11 pkgconfig_2.0.3 rstudioapi_0.15.0
[28] lattice_0.21-8 digest_0.6.33 R6_2.5.1
[31] tidyselect_1.2.0 reticulate_1.31.0.9000 utf8_1.2.3
[34] pillar_1.9.0 magrittr_2.0.3 Matrix_1.5-4.1
[37] tools_4.3.1
system2(reticulate::py_exe(), c("-m pip freeze"), stdout = TRUE) |> writeLines()
absl-py==1.4.0
array-record==0.4.1
asttokens==2.2.1
astunparse==1.6.3
backcall==0.2.0
bleach==6.0.0
cachetools==5.3.1
certifi==2023.7.22
charset-normalizer==3.2.0
click==8.1.7
decorator==5.1.1
dm-tree==0.1.8
etils==1.4.1
executing==1.2.0
flatbuffers==23.5.26
gast==0.4.0
google-auth==2.22.0
google-auth-oauthlib==1.0.0
google-pasta==0.2.0
googleapis-common-protos==1.60.0
grpcio==1.57.0
h5py==3.9.0
idna==3.4
importlib-resources==6.0.1
ipython==8.14.0
jedi==0.19.0
kaggle==1.5.16
keras==2.13.1
keras-tuner==1.3.5
kt-legacy==1.0.5
libclang==16.0.6
Markdown==3.4.4
MarkupSafe==2.1.3
matplotlib-inline==0.1.6
numpy==1.24.3
nvidia-cublas-cu11==11.11.3.6
nvidia-cudnn-cu11==8.6.0.163
oauthlib==3.2.2
opt-einsum==3.3.0
packaging==23.1
pandas==2.0.3
parso==0.8.3
pexpect==4.8.0
pickleshare==0.7.5
Pillow==10.0.0
promise==2.3
prompt-toolkit==3.0.39
protobuf==3.20.3
psutil==5.9.5
ptyprocess==0.7.0
pure-eval==0.2.2
pyasn1==0.5.0
pyasn1-modules==0.3.0
pydot==1.4.2
Pygments==2.16.1
pyparsing==3.1.1
python-dateutil==2.8.2
python-slugify==8.0.1
pytz==2023.3
requests==2.31.0
requests-oauthlib==1.3.1
rsa==4.9
scipy==1.11.2
six==1.16.0
stack-data==0.6.2
tensorboard==2.13.0
tensorboard-data-server==0.7.1
tensorflow==2.13.0
tensorflow-datasets==4.9.2
tensorflow-estimator==2.13.0
tensorflow-hub==0.14.0
tensorflow-io-gcs-filesystem==0.33.0
tensorflow-metadata==1.14.0
termcolor==2.3.0
text-unidecode==1.3
toml==0.10.2
tqdm==4.66.1
traitlets==5.9.0
typing_extensions==4.5.0
tzdata==2023.3
urllib3==1.26.16
wcwidth==0.2.6
webencodings==0.5.1
Werkzeug==2.3.7
wrapt==1.15.0
zipp==3.16.2
TF Devices:
- PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')
- PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')
CPU cores: 12
Date rendered: 2023-08-28
Page render time: 1 minutes and 0 seconds