Working with RNNs
Guide to using and customizing recurrent neural networks, a class of neural networks for modeling sequence data such as time series or natural language.
Introduction
Recurrent neural networks (RNN) are a class of neural networks that is powerful for modeling sequence data such as time series or natural language.
Schematically, a RNN layer uses a for
loop to iterate over the timesteps of a sequence, while maintaining an internal state that encodes information about the timesteps it has seen so far.
The Keras RNN API is designed with a focus on:
- Ease of use: the built-in layer_rnn(), layer_lstm(), and layer_gru() layers enable you to quickly build recurrent models without having to make difficult configuration choices.
- Ease of customization: You can also define your own RNN cell layer (the inner part of the for loop) with custom behavior, and use it with the generic layer_rnn() layer (the for loop itself). This allows you to quickly prototype different research ideas in a flexible way with minimal code.
Setup
library(tensorflow)
library(keras)
Built-in RNN layers: a simple example
There are three built-in RNN layers in Keras:
- layer_simple_rnn(), a fully-connected RNN where the output from the previous timestep is fed to the next timestep.
- layer_gru(), first proposed in Cho et al., 2014.
- layer_lstm(), first proposed in Hochreiter & Schmidhuber, 1997.
Here is a simple example of a sequential model that processes sequences of integers, embeds each integer into a 64-dimensional vector, then processes the sequence of vectors using a layer_lstm().
model <- keras_model_sequential() %>%
  # Add an Embedding layer expecting input vocab of size 1000, and
  # output embedding dimension of size 64.
  layer_embedding(input_dim = 1000, output_dim = 64) %>%
  # Add a LSTM layer with 128 internal units.
  layer_lstm(128) %>%
  # Add a Dense layer with 10 units.
  layer_dense(10)

model
Model: "sequential"
____________________________________________________________________________
Layer (type) Output Shape Param #
============================================================================
embedding (Embedding) (None, None, 64) 64000
lstm (LSTM) (None, 128) 98816
dense (Dense) (None, 10) 1290
============================================================================
Total params: 164,106
Trainable params: 164,106
Non-trainable params: 0
____________________________________________________________________________
Built-in RNNs support a number of useful features:
- Recurrent dropout, via the dropout and recurrent_dropout arguments
- Ability to process an input sequence in reverse, via the go_backwards argument
- Loop unrolling (which can lead to a large speedup when processing short sequences on CPU), via the unroll argument
- …and more.
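For instance, here is a minimal sketch (not from the original guide) of a built-in LSTM layer using these arguments; the specific values are only illustrative:

# Illustrative values: input dropout, recurrent dropout, reversed iteration,
# and loop unrolling left off (unroll = TRUE mainly helps for short sequences on CPU).
lstm <- layer_lstm(
  units = 64,
  dropout = 0.2,            # dropout applied to the inputs
  recurrent_dropout = 0.2,  # dropout applied to the recurrent state
  go_backwards = TRUE,      # process the input sequence in reverse
  unroll = FALSE
)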
For more information, see the RNN API documentation.
Outputs and states
By default, the output of a RNN layer contains a single vector per sample. This vector is the RNN cell output corresponding to the last timestep, containing information about the entire input sequence. The shape of this output is (batch_size, units)
where units
corresponds to the units
argument passed to the layer’s constructor.
A RNN layer can also return the entire sequence of outputs for each sample (one vector per timestep per sample), if you set return_sequences = TRUE
. The shape of this output is (batch_size, timesteps, units)
.
model <- keras_model_sequential() %>%
  layer_embedding(input_dim = 1000, output_dim = 64) %>%
  # The output of GRU will be a 3D tensor of shape (batch_size, timesteps, 256)
  layer_gru(256, return_sequences = TRUE) %>%
  # The output of SimpleRNN will be a 2D tensor of shape (batch_size, 128)
  layer_simple_rnn(128) %>%
  layer_dense(10)

model
Model: "sequential_1"
____________________________________________________________________________
Layer (type) Output Shape Param #
============================================================================
embedding_1 (Embedding) (None, None, 64) 64000
gru (GRU) (None, None, 256) 247296
simple_rnn (SimpleRNN) (None, 128) 49280
dense_1 (Dense) (None, 10) 1290
============================================================================
Total params: 361,866
Trainable params: 361,866
Non-trainable params: 0
____________________________________________________________________________
In addition, a RNN layer can return its final internal state(s). The returned states can be used to resume the RNN execution later, or to initialize another RNN. This setting is commonly used in the encoder-decoder sequence-to-sequence model, where the encoder final state is used as the initial state of the decoder.
To configure a RNN layer to return its internal state, set return_state = TRUE
when creating the layer. Note that LSTM
has 2 state tensors, but GRU
only has one.
To configure the initial state of the layer, call the layer instance with the additional named argument initial_state
. Note that the shape of the state needs to match the unit size of the layer, like in the example below.
encoder_vocab <- 1000
decoder_vocab <- 2000

encoder_input <- layer_input(shape(NULL))
encoder_embedded <- encoder_input %>%
  layer_embedding(input_dim = encoder_vocab, output_dim = 64)

# Return states in addition to output
c(output, state_h, state_c) %<-%
  layer_lstm(encoder_embedded, units = 64, return_state = TRUE, name = "encoder")
encoder_state <- list(state_h, state_c)

decoder_input <- layer_input(shape(NULL))
decoder_embedded <- decoder_input %>%
  layer_embedding(input_dim = decoder_vocab, output_dim = 64)

# Pass the 2 states to a new LSTM layer, as initial state
decoder_lstm_layer <- layer_lstm(units = 64, name = "decoder")
decoder_output <- decoder_lstm_layer(decoder_embedded, initial_state = encoder_state)

output <- decoder_output %>% layer_dense(10)

model <- keras_model(inputs = list(encoder_input, decoder_input),
                     outputs = output)
model
Model: "model"
____________________________________________________________________________
Layer (type) Output Shape Param # Connected to
============================================================================
input_1 (InputLayer) [(None, None)] 0 []
input_2 (InputLayer) [(None, None)] 0 []
embedding_2 (Embedding) (None, None, 6 64000 ['input_1[0][0]']
4)
embedding_3 (Embedding) (None, None, 6 128000 ['input_2[0][0]']
4)
encoder (LSTM) [(None, 64), 33024 ['embedding_2[0][0]']
(None, 64),
(None, 64)]
decoder (LSTM) (None, 64) 33024 ['embedding_3[0][0]',
'encoder[0][1]',
'encoder[0][2]']
dense_2 (Dense) (None, 10) 650 ['decoder[0][0]']
============================================================================
Total params: 258,698
Trainable params: 258,698
Non-trainable params: 0
____________________________________________________________________________
RNN layers and RNN cells
In addition to the built-in RNN layers, the RNN API also provides cell-level APIs. Unlike RNN layers, which process whole batches of input sequences, the RNN cell only processes a single timestep.
The cell is the inside of the for
loop of a RNN layer. Wrapping a cell inside a layer_rnn()
layer gives you a layer capable of processing a sequence, e.g. layer_rnn(layer_lstm_cell(10))
.
Mathematically, layer_rnn(layer_lstm_cell(10)) produces the same result as layer_lstm(10). In fact, the implementation of this layer in TF v1.x was just creating the corresponding RNN cell and wrapping it in a RNN layer. However, using the built-in layer_gru() and layer_lstm() layers enables the use of CuDNN, and you may see better performance.
There are three built-in RNN cells, each of them corresponding to the matching RNN layer.
- layer_simple_rnn_cell() corresponds to the layer_simple_rnn() layer.
- layer_gru_cell() corresponds to the layer_gru() layer.
- layer_lstm_cell() corresponds to the layer_lstm() layer.
The cell abstraction, together with the generic layer_rnn()
class, makes it very easy to implement custom RNN architectures for your research.
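As a quick sketch (not from the original guide), the snippet below wraps a built-in cell in layer_rnn() and builds the equivalent built-in layer side by side; shapes and unit counts are illustrative:

x <- k_random_uniform(c(32, 10, 8))  # (batch, timesteps, features)

# The two layers implement the same recurrence (their randomly initialized
# weights differ, so their outputs will not be numerically identical).
lstm_from_cell <- layer_rnn(cell = layer_lstm_cell(units = 16))
lstm_builtin <- layer_lstm(units = 16)  # may use the CuDNN kernel on a GPU

y1 <- lstm_from_cell(x)  # shape: (32, 16)
y2 <- lstm_builtin(x)    # shape: (32, 16)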
Cross-batch statefulness
When processing very long (possibly infinite) sequences, you may want to use the pattern of cross-batch statefulness.
Normally, the internal state of a RNN layer is reset every time it sees a new batch (i.e. every sample seen by the layer is assumed to be independent of the past). The layer will only maintain a state while processing a given sample.
If you have very long sequences though, it is useful to break them into shorter sequences, and to feed these shorter sequences sequentially into a RNN layer without resetting the layer’s state. That way, the layer can retain information about the entirety of the sequence, even though it’s only seeing one sub-sequence at a time.
You can do this by setting stateful = TRUE
in the constructor.
If you have a sequence s = c(t0, t1, ... t1546, t1547), you would split it into e.g.

s1 = c(t0, t1, ..., t100)
s2 = c(t101, ..., t201)
...
s16 = c(t1501, ..., t1547)
Then you would process it via:
lstm_layer <- layer_lstm(units = 64, stateful = TRUE)
for (s in sub_sequences)
  output <- lstm_layer(s)
When you want to clear the state, you can use layer$reset_states()
.
Note: In this setup, sample i in a given batch is assumed to be the continuation of sample i in the previous batch. This means that all batches should contain the same number of samples (batch size). E.g. if a batch contains [sequence_A_from_t0_to_t100, sequence_B_from_t0_to_t100], the next batch should contain [sequence_A_from_t101_to_t200, sequence_B_from_t101_to_t200].
Here is a complete example:
paragraph1 <- k_random_uniform(c(20, 10, 50), dtype = "float32")
paragraph2 <- k_random_uniform(c(20, 10, 50), dtype = "float32")
paragraph3 <- k_random_uniform(c(20, 10, 50), dtype = "float32")

lstm_layer <- layer_lstm(units = 64, stateful = TRUE)
output <- lstm_layer(paragraph1)
output <- lstm_layer(paragraph2)
output <- lstm_layer(paragraph3)

# reset_states() will reset the cached state to the original initial_state.
# If no initial_state was provided, zero-states will be used by default.
lstm_layer$reset_states()
RNN State Reuse
The recorded states of the RNN layer are not included in layer$weights(). If you would like to reuse the state from a RNN layer, you can retrieve the states value via layer$states and use it as the initial state of a new layer instance through the Keras functional API, like new_layer(inputs, initial_state = layer$states), or through model subclassing.

Please also note that a sequential model cannot be used in this case, since it only supports layers with a single input and output; the extra initial-state input makes it impossible to use here.
paragraph1 <- k_random_uniform(c(20, 10, 50), dtype = "float32")
paragraph2 <- k_random_uniform(c(20, 10, 50), dtype = "float32")
paragraph3 <- k_random_uniform(c(20, 10, 50), dtype = "float32")

lstm_layer <- layer_lstm(units = 64, stateful = TRUE)
output <- lstm_layer(paragraph1)
output <- lstm_layer(paragraph2)

existing_state <- lstm_layer$states

new_lstm_layer <- layer_lstm(units = 64)
new_output <- new_lstm_layer(paragraph3, initial_state = existing_state)
Bidirectional RNNs
For sequences other than time series (e.g. text), it is often the case that a RNN model can perform better if it not only processes a sequence from start to end, but also backwards. For example, to predict the next word in a sentence, it is often useful to have the context around the word, not just the words that come before it.
Keras provides an easy API for you to build such bidirectional RNNs: the bidirectional()
wrapper.
model <- keras_model_sequential(input_shape = shape(5, 10)) %>%
  bidirectional(layer_lstm(units = 64, return_sequences = TRUE)) %>%
  bidirectional(layer_lstm(units = 32)) %>%
  layer_dense(10)

model
Model: "sequential_2"
____________________________________________________________________________
Layer (type) Output Shape Param #
============================================================================
bidirectional_1 (Bidirectional) (None, 5, 128) 38400
bidirectional (Bidirectional) (None, 64) 41216
dense_3 (Dense) (None, 10) 650
============================================================================
Total params: 80,266
Trainable params: 80,266
Non-trainable params: 0
____________________________________________________________________________
Under the hood, bidirectional()
will copy the RNN layer passed in, and flip the go_backwards
field of the newly copied layer, so that it will process the inputs in reverse order.
The output of the bidirectional
RNN will be, by default, the concatenation of the forward layer output and the backward layer output. If you need a different merging behavior, e.g. averaging, change the merge_mode
parameter in the bidirectional
wrapper constructor. For more details about bidirectional
, please check the API docs.
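For instance, a minimal sketch (not from the original guide) that sums the two directions instead of concatenating them; the shapes match the example above:

# With merge_mode = "sum", the forward and backward outputs are added elementwise,
# so the layer's output size stays at 32 instead of doubling to 64.
model_sum <- keras_model_sequential(input_shape = shape(5, 10)) %>%
  bidirectional(layer_lstm(units = 32), merge_mode = "sum") %>%
  layer_dense(10)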
Performance optimization and CuDNN kernels
In TensorFlow 2.0, the built-in LSTM and GRU layers have been updated to leverage CuDNN kernels by default when a GPU is available. With this change, the prior layer_cudnn_gru/layer_cudnn_lstm
layers have been deprecated, and you can build your model without worrying about the hardware it will run on.
Since the CuDNN kernel is built with certain assumptions, this means the layer will not be able to use the CuDNN kernel if you change the defaults of the built-in LSTM or GRU layers. E.g.:
- Changing the activation function from "tanh" to something else.
- Changing the recurrent_activation function from "sigmoid" to something else.
- Using recurrent_dropout > 0.
- Setting unroll to TRUE, which forces LSTM/GRU to decompose the inner tf$while_loop into an unrolled for loop.
- Setting use_bias to FALSE.
- Using masking when the input data is not strictly right padded (if the mask corresponds to strictly right padded data, CuDNN can still be used. This is the most common case).
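For example, here is a sketch (not from the original guide) of two layers that look similar but select different kernels:

# Default arguments: eligible for the CuDNN kernel when a GPU is available.
lstm_cudnn <- layer_lstm(units = 64)

# recurrent_dropout > 0 violates one of the CuDNN assumptions above, so this
# layer falls back to the generic TensorFlow kernel.
lstm_generic <- layer_lstm(units = 64, recurrent_dropout = 0.2)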
For the detailed list of constraints, please see the documentation for the LSTM and GRU layers.
Using CuDNN kernels when available
Let’s build a simple LSTM model to demonstrate the performance difference.
We’ll use as input sequences the sequence of rows of MNIST digits (treating each row of pixels as a timestep), and we’ll predict the digit’s label.
batch_size <- 64
# Each MNIST image batch is a tensor of shape (batch_size, 28, 28).
# Each input sequence will be of size (28, 28) (height is treated like time).
input_dim <- 28

units <- 64
output_size <- 10  # labels are from 0 to 9

# Build the RNN model
build_model <- function(allow_cudnn_kernel = TRUE) {
  # CuDNN is only available at the layer level, and not at the cell level.
  # This means `layer_lstm(units = units)` will use the CuDNN kernel,
  # while layer_rnn(cell = layer_lstm_cell(units)) will run on non-CuDNN kernel.
  if (allow_cudnn_kernel)
    # The LSTM layer with default options uses CuDNN.
    lstm_layer <- layer_lstm(units = units)
  else
    # Wrapping a LSTMCell in a RNN layer will not use CuDNN.
    lstm_layer <- layer_rnn(cell = layer_lstm_cell(units = units))

  model <- keras_model_sequential(input_shape = shape(NULL, input_dim)) %>%
    lstm_layer() %>%
    layer_batch_normalization() %>%
    layer_dense(output_size)
  model
}
Let’s load the MNIST dataset:
mnist <- dataset_mnist()
mnist$train$x <- mnist$train$x / 255
mnist$test$x <- mnist$test$x / 255
c(sample, sample_label) %<-% with(mnist$train, list(x[1, , ], y[1]))
Let’s create a model instance and train it.
We choose loss_sparse_categorical_crossentropy() as the loss function for the model. The output of the model has a shape of (batch_size, 10). The target for the model is an integer vector, where each integer is in the range 0 to 9.
model <- build_model(allow_cudnn_kernel = TRUE) %>%
  compile(
    loss = loss_sparse_categorical_crossentropy(from_logits = TRUE),
    optimizer = "sgd",
    metrics = "accuracy"
  )

model %>% fit(
  mnist$train$x,
  mnist$train$y,
  validation_data = with(mnist$test, list(x, y)),
  batch_size = batch_size,
  epochs = 1
)
Now, let’s compare to a model that does not use the CuDNN kernel:
noncudnn_model <- build_model(allow_cudnn_kernel = FALSE)
noncudnn_model$set_weights(model$get_weights())
noncudnn_model %>% compile(
  loss = loss_sparse_categorical_crossentropy(from_logits = TRUE),
  optimizer = "sgd",
  metrics = "accuracy"
)

noncudnn_model %>% fit(
  mnist$train$x,
  mnist$train$y,
  validation_data = with(mnist$test, list(x, y)),
  batch_size = batch_size,
  epochs = 1
)
When running on a machine with a NVIDIA GPU and CuDNN installed, the model built with CuDNN is much faster to train compared to the model that uses the regular TensorFlow kernel.
The same CuDNN-enabled model can also be used to run inference in a CPU-only environment. The tf$device()
annotation below is just forcing the device placement. The model will run on CPU by default if no GPU is available.
You simply don’t have to worry about the hardware you’re running on anymore. Isn’t that pretty cool?
with(tf$device("CPU:0"), {
  cpu_model <- build_model(allow_cudnn_kernel = TRUE)
  cpu_model$set_weights(model$get_weights())

  result <- cpu_model %>%
    predict_on_batch(k_expand_dims(sample, 1)) %>%
    k_argmax(axis = 2)

  cat(sprintf(
    "Predicted result is: %s, target result is: %s\n",
    as.numeric(result), sample_label))

  # show mnist image
  sample %>%
    apply(2, rev) %>% # flip
    t() %>%           # rotate
    image(axes = FALSE, asp = 1, col = grey(seq(0, 1, length.out = 256)))
})
Predicted result is: 3, target result is: 5
RNNs with list/dict inputs, or nested inputs
Nested structures allow implementers to include more information within a single timestep. For example, a video frame could have audio and video input at the same time. The data shape in this case could be:
[batch, timestep, {"video": [height, width, channel], "audio": [frequency]}]
In another example, handwriting data could have both coordinates x and y for the current position of the pen, as well as pressure information. So the data representation could be:
[batch, timestep, {"location": [x, y], "pressure": [force]}]
The following code provides an example of how to build a custom RNN cell that accepts such structured inputs.
Define a custom cell that supports nested input/output
See Making new Layers & Models via subclassing for details on writing your own layers.
NestedCell(keras$layers$Layer) %py_class% {

  initialize <- function(unit_1, unit_2, unit_3, ...) {
    self$unit_1 <- unit_1
    self$unit_2 <- unit_2
    self$unit_3 <- unit_3
    self$state_size <- list(shape(unit_1), shape(unit_2, unit_3))
    self$output_size <- list(shape(unit_1), shape(unit_2, unit_3))
    super$initialize(...)
  }

  build <- function(self, input_shapes) {
    # expect input_shape to contain 2 items, [(batch, i1), (batch, i2, i3)]
    # dput(input_shapes) gives: list(list(NULL, 32L), list(NULL, 64L, 32L))
    i1 <- input_shapes[[c(1, 2)]] # 32
    i2 <- input_shapes[[c(2, 2)]] # 64
    i3 <- input_shapes[[c(2, 3)]] # 32

    self$kernel_1 <- self$add_weight(
      shape = shape(i1, self$unit_1),
      initializer = "uniform",
      name = "kernel_1"
    )
    self$kernel_2_3 <- self$add_weight(
      shape = shape(i2, i3, self$unit_2, self$unit_3),
      initializer = "uniform",
      name = "kernel_2_3"
    )
  }

  call <- function(inputs, states) {
    # inputs should be in [(batch, input_1), (batch, input_2, input_3)]
    # state should be in shape [(batch, unit_1), (batch, unit_2, unit_3)]
    # Don't forget you can call `browser()` here while the layer is being traced!
    c(input_1, input_2) %<-% tf$nest$flatten(inputs)
    c(s1, s2) %<-% states

    output_1 <- tf$matmul(input_1, self$kernel_1)
    output_2_3 <- tf$einsum("bij,ijkl->bkl", input_2, self$kernel_2_3)
    state_1 <- s1 + output_1
    state_2_3 <- s2 + output_2_3

    output <- tuple(output_1, output_2_3)
    new_states <- tuple(state_1, state_2_3)

    tuple(output, new_states)
  }

  get_config <- function() {
    list("unit_1" = self$unit_1,
         "unit_2" = self$unit_2,
         "unit_3" = self$unit_3)
  }
}
Build a RNN model with nested input/output
Let’s build a Keras model that uses a layer_rnn
layer and the custom cell we just defined.
unit_1 <- 10
unit_2 <- 20
unit_3 <- 30

i1 <- 32
i2 <- 64
i3 <- 32
batch_size <- 64
num_batches <- 10
timestep <- 50

cell <- NestedCell(unit_1, unit_2, unit_3)
rnn <- layer_rnn(cell = cell)

input_1 <- layer_input(shape(NULL, i1))
input_2 <- layer_input(shape(NULL, i2, i3))

outputs <- rnn(tuple(input_1, input_2))

model <- keras_model(list(input_1, input_2), outputs)

model %>% compile(optimizer = "adam", loss = "mse", metrics = "accuracy")
Train the model with randomly generated data
Since there isn’t a good candidate dataset for this model, we use random data for demonstration.
input_1_data <- k_random_uniform(c(batch_size * num_batches, timestep, i1))
input_2_data <- k_random_uniform(c(batch_size * num_batches, timestep, i2, i3))
target_1_data <- k_random_uniform(c(batch_size * num_batches, unit_1))
target_2_data <- k_random_uniform(c(batch_size * num_batches, unit_2, unit_3))
input_data <- list(input_1_data, input_2_data)
target_data <- list(target_1_data, target_2_data)

model %>% fit(input_data, target_data, batch_size = batch_size)
With keras::layer_rnn(), you are only expected to define the math logic for an individual step within the sequence, and layer_rnn() will handle the sequence iteration for you. It's an incredibly powerful way to quickly prototype new kinds of RNNs (e.g. a LSTM variant).
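As an illustration of that idea, here is a minimal sketch (not from the original guide) of a custom cell written in the same %py_class% style as NestedCell above: it applies one dense transformation per timestep and keeps a running sum as its state. The class name MinimalRNNCell and its arguments are hypothetical.

MinimalRNNCell(keras$layers$Layer) %py_class% {

  initialize <- function(units, ...) {
    super$initialize(...)
    self$units <- as.integer(units)
    self$state_size <- self$units
  }

  build <- function(input_shape) {
    # input_shape describes a single timestep: (batch, features)
    self$kernel <- self$add_weight(
      shape = shape(input_shape[[2]], self$units),
      initializer = "uniform",
      name = "kernel"
    )
  }

  call <- function(inputs, states) {
    prev_state <- states[[1]]
    h <- tf$matmul(inputs, self$kernel)
    new_state <- prev_state + h       # running sum as the carried state
    tuple(h, list(new_state))
  }
}

# layer_rnn() supplies the loop over timesteps.
layer <- layer_rnn(cell = MinimalRNNCell(units = 32))
output <- layer(k_random_uniform(c(4, 10, 8)))  # shape: (4, 32)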
For more details, please visit the API docs.
Environment Details
tensorflow::tf_config()
TensorFlow v2.11.0 (~/.virtualenvs/r-tensorflow-website/lib/python3.10/site-packages/tensorflow)
Python v3.10 (~/.virtualenvs/r-tensorflow-website/bin/python)
sessionInfo()
R version 4.2.1 (2022-06-23)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 20.04.5 LTS
Matrix products: default
BLAS: /home/tomasz/opt/R-4.2.1/lib/R/lib/libRblas.so
LAPACK: /usr/lib/x86_64-linux-gnu/libmkl_intel_lp64.so
locale:
[1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
[3] LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8
[5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8
[7] LC_PAPER=en_US.UTF-8 LC_NAME=C
[9] LC_ADDRESS=C LC_TELEPHONE=C
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] keras_2.9.0.9000 tensorflow_2.9.0.9000
loaded via a namespace (and not attached):
[1] Rcpp_1.0.9 pillar_1.8.1 compiler_4.2.1
[4] base64enc_0.1-3 tools_4.2.1 zeallot_0.1.0
[7] digest_0.6.31 jsonlite_1.8.4 evaluate_0.18
[10] lifecycle_1.0.3 tibble_3.1.8 lattice_0.20-45
[13] pkgconfig_2.0.3 png_0.1-8 rlang_1.0.6
[16] Matrix_1.5-3 cli_3.4.1 yaml_2.3.6
[19] xfun_0.35 fastmap_1.1.0 stringr_1.5.0
[22] knitr_1.41 generics_0.1.3 vctrs_0.5.1
[25] htmlwidgets_1.5.4 rprojroot_2.0.3 grid_4.2.1
[28] reticulate_1.26-9000 glue_1.6.2 here_1.0.1
[31] R6_2.5.1 fansi_1.0.3 rmarkdown_2.18
[34] magrittr_2.0.3 whisker_0.4.1 htmltools_0.5.4
[37] tfruns_1.5.1 utf8_1.2.2 stringi_1.7.8
system2(reticulate::py_exe(), c("-m pip freeze"), stdout = TRUE) |> writeLines()
absl-py==1.3.0
asttokens==2.2.1
astunparse==1.6.3
backcall==0.2.0
cachetools==5.2.0
certifi==2022.12.7
charset-normalizer==2.1.1
decorator==5.1.1
dill==0.3.6
etils==0.9.0
executing==1.2.0
flatbuffers==22.12.6
gast==0.4.0
google-auth==2.15.0
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
googleapis-common-protos==1.57.0
grpcio==1.51.1
h5py==3.7.0
idna==3.4
importlib-resources==5.10.1
ipython==8.7.0
jedi==0.18.2
kaggle==1.5.12
keras==2.11.0
keras-tuner==1.1.3
kt-legacy==1.0.4
libclang==14.0.6
Markdown==3.4.1
MarkupSafe==2.1.1
matplotlib-inline==0.1.6
numpy==1.23.5
oauthlib==3.2.2
opt-einsum==3.3.0
packaging==22.0
pandas==1.5.2
parso==0.8.3
pexpect==4.8.0
pickleshare==0.7.5
Pillow==9.3.0
promise==2.3
prompt-toolkit==3.0.36
protobuf==3.19.6
ptyprocess==0.7.0
pure-eval==0.2.2
pyasn1==0.4.8
pyasn1-modules==0.2.8
pydot==1.4.2
Pygments==2.13.0
pyparsing==3.0.9
python-dateutil==2.8.2
python-slugify==7.0.0
pytz==2022.6
PyYAML==6.0
requests==2.28.1
requests-oauthlib==1.3.1
rsa==4.9
scipy==1.9.3
six==1.16.0
stack-data==0.6.2
tensorboard==2.11.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.11.0
tensorflow-datasets==4.7.0
tensorflow-estimator==2.11.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.28.0
tensorflow-metadata==1.12.0
termcolor==2.1.1
text-unidecode==1.3
toml==0.10.2
tqdm==4.64.1
traitlets==5.7.1
typing_extensions==4.4.0
urllib3==1.26.13
wcwidth==0.2.5
Werkzeug==2.2.2
wrapt==1.14.1
zipp==3.11.0
TF Devices:
- PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')
- PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')
CPU cores: 12
Date rendered: 2022-12-16
Page render time: 39 seconds