library(tfdatasets)
A <- range_dataset(1, 5, dtype = tf$int32) %>%
  dataset_map(function(x) tf$fill(list(x), x))

# Pad to the smallest per-batch size that fits all elements.
B <- A %>% dataset_padded_batch(2)
B %>% as_array_iterator() %>% iterate(print)

# Pad to a fixed size.
C <- A %>% dataset_padded_batch(2, padded_shapes = 5)
C %>% as_array_iterator() %>% iterate(print)

# Pad with a custom value.
D <- A %>% dataset_padded_batch(2, padded_shapes = 5, padding_values = -1L)
D %>% as_array_iterator() %>% iterate(print)

# Pad with a single value and multiple components.
E <- zip_datasets(A, A) %>% dataset_padded_batch(2, padding_values = -1L)
E %>% as_array_iterator() %>% iterate(print)
dataset_padded_batch
Combines consecutive elements of this dataset into padded batches.
Description
Combines consecutive elements of this dataset into padded batches.
Usage
dataset_padded_batch(
  dataset,
  batch_size,
  padded_shapes = NULL,
  padding_values = NULL,
  drop_remainder = FALSE,
  name = NULL
)
Arguments
Arguments | Description |
---|---|
dataset | A dataset |
batch_size | An integer, representing the number of consecutive elements of this dataset to combine in a single batch. |
padded_shapes | (Optional.) A (nested) structure of tf.TensorShape (returned by tensorflow::shape()) or tf$int64 vector tensor-like objects representing the shape to which the respective component of each input element should be padded prior to batching. Any unknown dimensions will be padded to the maximum size of that dimension in each batch. If unset, all dimensions of all components are padded to the maximum size in the batch. padded_shapes must be set if any component has an unknown rank. |
padding_values | (Optional.) A (nested) structure of scalar-shaped tf.Tensor, representing the padding values to use for the respective components. NULL represents that the (nested) structure should be padded with default values. Defaults are 0 for numeric types and the empty string "" for string types. The padding_values should have the same (nested) structure as the input dataset (see the sketch after this table). If padding_values is a single element and the input dataset has multiple components, then the same padding_values will be used to pad every component of the dataset. If padding_values is a scalar, then its value will be broadcasted to match the shape of each component. |
drop_remainder | (Optional.) A boolean scalar, representing whether the last batch should be dropped in the case it has fewer than batch_size elements; the default behavior is not to drop the smaller batch. |
name | (Optional.) A name for the tf.data operation. Requires tensorflow version >= 2.7. |
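As a brief illustration of the per-component case described for padding_values above: the sketch below zips two copies of the ragged dataset from the examples at the top of this page and pads each component with a different value. It assumes an unnamed R list of scalars is accepted as the nested structure matching the two zipped components; depending on the installed version, scalar tensors (e.g. tf$constant(-1L)) may be needed instead.

library(tfdatasets)

A <- range_dataset(1, 5, dtype = tf$int32) %>%
  dataset_map(function(x) tf$fill(list(x), x))

# Pad the first zipped component with 0L and the second with -1L.
zip_datasets(A, A) %>%
  dataset_padded_batch(2, padding_values = list(0L, -1L)) %>%
  as_array_iterator() %>%
  iterate(print)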
Details
This transformation combines multiple consecutive elements of the input dataset into a single element. Like dataset_batch(), the components of the resulting element will have an additional outer dimension, which will be batch_size (or N %% batch_size for the last element if batch_size does not divide the number of input elements N evenly and drop_remainder is FALSE). If your program depends on the batches having the same outer dimension, you should set the drop_remainder argument to TRUE to prevent the smaller batch from being produced.

Unlike dataset_batch(), the input elements to be batched may have different shapes, and this transformation will pad each component to the respective shape in padded_shapes. The padded_shapes argument determines the resulting shape for each dimension of each component in an output element (a sketch combining both rules follows this list):

- If the dimension is a constant, the component will be padded out to that length in that dimension.
- If the dimension is unknown, the component will be padded out to the maximum length of all elements in that dimension.
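A minimal sketch combining the two rules: each element below is an x-by-2 matrix; the first dimension is left unknown (NA) and so pads to the longest element in each batch, while the second dimension is the constant 3, so every element is padded out to three columns. Passing the dims of tf$fill() as a mixed list of a tensor and a constant, and NA marking an unknown dimension in tensorflow::shape(), are assumptions here rather than something this page documents.

library(tfdatasets)

# Each element is an x-by-2 integer matrix, so shapes vary across elements.
M <- range_dataset(1, 5, dtype = tf$int32) %>%
  dataset_map(function(x) tf$fill(list(x, 2L), x))

# Dimension 1 (unknown, NA): padded to the longest element in the batch.
# Dimension 2 (constant 3): every element is padded out to 3 columns.
M %>%
  dataset_padded_batch(2, padded_shapes = tensorflow::shape(NA, 3)) %>%
  as_array_iterator() %>%
  iterate(print)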
See also
tf$data$experimental$dense_to_sparse_batch, which combines elements that may have different shapes into a tf$sparse$SparseTensor.
Value
A tf_dataset
Examples
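A minimal sketch of the drop_remainder behaviour described under Details, using the same ragged dataset as the examples at the top of this page: with four input elements and a batch size of 3, the final short batch is dropped, so every emitted batch keeps the same outer dimension.

library(tfdatasets)

A <- range_dataset(1, 5, dtype = tf$int32) %>%
  dataset_map(function(x) tf$fill(list(x), x))

# Four elements batched in threes: the leftover single-element batch is
# discarded because drop_remainder = TRUE.
A %>%
  dataset_padded_batch(3, drop_remainder = TRUE) %>%
  as_array_iterator() %>%
  iterate(print)

With these inputs, a single 3 x 3 batch should be produced; the fourth element is dropped rather than emitted as a smaller batch.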
See Also
https://www.tensorflow.org/api_docs/python/tf/data/Dataset#padded_batch
Other dataset methods: dataset_batch(), dataset_cache(), dataset_collect(), dataset_concatenate(), dataset_decode_delim(), dataset_filter(), dataset_interleave(), dataset_map_and_batch(), dataset_map(), dataset_prefetch_to_device(), dataset_prefetch(), dataset_reduce(), dataset_repeat(), dataset_shuffle_and_repeat(), dataset_shuffle(), dataset_skip(), dataset_take_while(), dataset_take(), dataset_window()