Dataset API

    Overview

    We can access the TensorFlow Dataset API via the tfdatasets package, which lets us create scalable input pipelines for use with tfestimators. In this vignette, we demonstrate how to stream datasets stored on disk during training by building a classifier on the iris dataset.

    Dataset Preparation

    Let’s assume we’re given a dataset (which could be arbitrarily large) that is already split into training and validation files, along with a small sample of the data. To simulate this scenario, we’ll create a few CSV files as follows:

    set.seed(123)
    train_idx <- sample(nrow(iris), nrow(iris) * 2/3)
    
    iris_train <- iris[train_idx,]
    iris_validation <- iris[-train_idx,]
    # head() is called directly because no package providing %>% is attached yet
    iris_sample <- head(iris_train, 10)
    
    write.csv(iris_train, "iris_train.csv", row.names = FALSE)
    write.csv(iris_validation, "iris_validation.csv", row.names = FALSE)
    write.csv(iris_sample, "iris_sample.csv", row.names = FALSE)
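
    As a quick sanity check, the files can be read back to confirm the split. The sketch below is not part of the original workflow: it repeats the split in a scratch directory (the out_dir variable and the use of tempfile() are illustrative assumptions) and reports the row counts, along with the column types seen in the sample file.

```r
# Recreate the split in a scratch directory so the working directory is untouched
set.seed(123)
train_idx <- sample(nrow(iris), nrow(iris) * 2/3)
iris_train <- iris[train_idx, ]
iris_validation <- iris[-train_idx, ]
iris_sample <- head(iris_train, 10)

out_dir <- tempfile("iris_csv_")  # hypothetical scratch location
dir.create(out_dir)
write.csv(iris_train, file.path(out_dir, "iris_train.csv"), row.names = FALSE)
write.csv(iris_validation, file.path(out_dir, "iris_validation.csv"), row.names = FALSE)
write.csv(iris_sample, file.path(out_dir, "iris_sample.csv"), row.names = FALSE)

# Read the files back with base R and count the rows in each
train_back <- read.csv(file.path(out_dir, "iris_train.csv"), stringsAsFactors = FALSE)
validation_back <- read.csv(file.path(out_dir, "iris_validation.csv"), stringsAsFactors = FALSE)
sample_back <- read.csv(file.path(out_dir, "iris_sample.csv"), stringsAsFactors = FALSE)

row_counts <- c(
  train = nrow(train_back),
  validation = nrow(validation_back),
  sample = nrow(sample_back)
)
print(row_counts)

# Column types as base R infers them from the ten-row sample
print(sapply(sample_back, class))
```

    Because the sample file contains only ten rows, it is worth checking that each column in it shows the type expected for the full dataset.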

    Estimator Construction

    We construct the classifier as usual – see Estimator Basics for details on feature columns and creating estimators.

    library(tfestimators)
    response <- "Species"
    features <- setdiff(names(iris), response)
    feature_columns <- feature_columns(
      column_numeric(features)
    )
    
    classifier <- dnn_classifier(
      feature_columns = feature_columns,
      hidden_units = c(16, 32, 16),
      n_classes = 3,
      label_vocabulary = c("setosa", "virginica", "versicolor")
    )
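
    The classifier above is told about its classes through n_classes and label_vocabulary. A minimal base R sketch, assuming the vocabulary should contain every class string that appears in the response column, checks this against the data. The label_vocabulary and observed_classes variables are introduced here for illustration only.

```r
# Labels passed to label_vocabulary in the classifier definition
label_vocabulary <- c("setosa", "virginica", "versicolor")

# Classes actually present in the response column of iris
observed_classes <- levels(iris$Species)

# Classes found in the data but absent from the vocabulary (should be empty)
missing_labels <- setdiff(observed_classes, label_vocabulary)

stopifnot(length(missing_labels) == 0)
stopifnot(length(label_vocabulary) == 3)  # matches n_classes = 3
```

    The vocabulary is listed in a different order from levels(iris$Species); this check only tests membership, not order.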

    Input Function

    We create the input function much as in the in-memory case. However, instead of passing data frames or matrices to iris_input_fn(), we pass TensorFlow dataset objects, which iterate over the dataset files internally.

    iris_input_fn <- function(data) {
      input_fn(data, features = features, response = response)
    }
    
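    library(tfdatasets)
    
    # Build a record specification from the small sample file
    iris_spec <- csv_record_spec("iris_sample.csv")
    
    # Training data: batches of 10 records, 10 passes over the file
    iris_train <- text_line_dataset(
      "iris_train.csv", record_spec = iris_spec) %>%
      dataset_batch(10) %>%
      dataset_repeat(10)
    
    # Validation data: batches of 10 records, a single pass over the file
    iris_validation <- text_line_dataset(
      "iris_validation.csv", record_spec = iris_spec) %>%
      dataset_batch(10) %>%
      dataset_repeat(1)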

    The csv_record_spec() function is a helper that creates a specification from a sample file; text_line_dataset() requires this specification to parse the files. Many transformations are available for dataset objects. Here we demonstrate only dataset_batch() and dataset_repeat(), which control the batch size and the number of times we iterate through the dataset files, respectively.
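
    To make the effect of batching and repeating concrete, the arithmetic can be sketched in base R. The count_batches() helper below is hypothetical, not part of tfdatasets. It assumes the 100/50 row split created earlier, that the header line is not counted as a record, and that a final partial batch is kept rather than dropped.

```r
# Hypothetical helper: number of batches produced by batching and then repeating
count_batches <- function(n_rows, batch_size, n_repeats) {
  # Batches in one pass over the file; a trailing partial batch counts as one
  batches_per_pass <- ceiling(n_rows / batch_size)
  batches_per_pass * n_repeats
}

# Training: 100 rows, batches of 10, 10 passes
train_batches <- count_batches(n_rows = 100, batch_size = 10, n_repeats = 10)

# Validation: 50 rows, batches of 10, 1 pass
validation_batches <- count_batches(n_rows = 50, batch_size = 10, n_repeats = 1)

print(c(train = train_batches, validation = validation_batches))
```

    Under these assumptions, the training pipeline yields 100 batches and the validation pipeline yields 5.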

    Training and Evaluation

    Once the input functions and datasets are defined, the training and evaluation interface is exactly the same as in the in-memory case.

    history <- train(classifier, input_fn = iris_input_fn(iris_train))
    plot(history)
    predictions <- predict(classifier, input_fn = iris_input_fn(iris_validation))
    predictions
    evaluation <- evaluate(classifier, input_fn = iris_input_fn(iris_validation))
    evaluation

    Learning More

    See the documentation for the tfdatasets package for additional details on using TensorFlow datasets.