Input Functions

Overview

TensorFlow estimators receive data through input functions. Input functions take an arbitrary data source (in-memory data sets, streaming data, custom data formats, and so on) and generate Tensors that can be supplied to TensorFlow models.

More concretely, input functions are used to:

  1. Turn raw data sources into Tensors, and
  2. Configure how data is drawn during training (shuffling, batch size, epochs, etc.)

You can also perform feature engineering within an input function. However, it’s better to use feature columns for this purpose whenever possible: in that case the transformations become part of the TensorFlow graph and can be executed without an R runtime (e.g. when the model is deployed onto a device or server).

The tfestimators package includes an input_fn() function that can create TensorFlow input functions from common R data sources (e.g. data frames and matrices). It’s also possible to write a fully custom input function. Both methods of creating input functions are covered below.

Data Frame Input

You can create an input function from an R data frame using the input_fn() function. You can specify feature and response variables either explicitly or using the R formula interface.

For example, to create an input function for the mtcars dataset with features “drat” and “cyl” and response “mpg” you could use this code:

model %>% train(
  input_fn(mtcars, 
           features = c(drat, cyl), 
           response = mpg,
           batch_size = 128,
           epochs = 3)
)

Or alternatively use the R formula interface like this:

model %>% train(
  input_fn(mpg ~ drat + cyl, 
           data = mtcars,
           batch_size = 128,
           epochs = 3)
)
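To make the correspondence between the two calls above concrete, the following base R sketch shows how a formula such as mpg ~ drat + cyl breaks down into a response and a set of features. It uses only all.vars() and terms() from base R and is an illustration of the mapping, not necessarily how input_fn() parses the formula internally.

```r
# Illustration only: decompose the formula with base R tools
f <- mpg ~ drat + cyl

# The left-hand side is the response variable
response <- all.vars(f)[1]

# The right-hand side terms are the features
features <- attr(terms(f), "term.labels")

print(response)
print(features)
```

The resulting names match the features and response arguments passed explicitly in the first example.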

The input_fn() function provides several parameters for controlling how data is drawn from the input source, including batch_size (defaults to 128), shuffle (defaults to "auto"), and epochs (defaults to 1). By default, shuffling is disabled during prediction.
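As a rough mental model of how these three parameters interact, the base R sketch below splits row indices into batches over several epochs. The make_batches() helper is hypothetical and written only for this illustration; the actual batching performed by input_fn() and TensorFlow may differ (for example, in how partial batches or epoch boundaries are handled).

```r
# Conceptual sketch only: make_batches() is a hypothetical helper,
# not part of tfestimators
make_batches <- function(n_rows, batch_size, epochs, shuffle = TRUE) {
  batches <- list()
  for (epoch in seq_len(epochs)) {
    # Each epoch is one full pass over the rows, optionally in random order
    idx <- if (shuffle) sample(n_rows) else seq_len(n_rows)
    # Split the row indices into consecutive chunks of at most batch_size
    chunks <- split(idx, ceiling(seq_along(idx) / batch_size))
    batches <- c(batches, unname(chunks))
  }
  batches
}

set.seed(42)
batches <- make_batches(nrow(mtcars), batch_size = 10, epochs = 3)

# mtcars has 32 rows, so each epoch yields batches of 10, 10, 10 and 2 rows
length(batches)
lengths(batches)[1:4]
```

With the default batch_size of 128, all 32 rows of mtcars would fit in a single batch per epoch under this model.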

Training vs. Evaluation

It’s often the case that you’ll want to use the same basic input function for training and evaluation, but need to provide a distinct dataset for each step. In that case you can create a wrapper function that returns the same input function with varying input data.

For example, imagine we have already split the mtcars dataset into training and test subsets. We could have an input function generator like this:

mtcars_input_fn <- function(data, ...) {
  input_fn(data,
           features = c("drat", "cyl"),
           response = "mpg",
           ...)
}

The ... parameter is used to forward additional options to input_fn().
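The forwarding behaviour of ... can be seen in isolation with a small base R sketch. Here describe_input() is a hypothetical stand-in for input_fn() that simply records the options it receives, so the example runs without TensorFlow; the parameter names batch_size and epochs are taken from the text above.

```r
# Hypothetical stand-in for input_fn(): returns the options it was given
describe_input <- function(data, features, response,
                           batch_size = 128, epochs = 1) {
  list(rows = nrow(data), features = features, response = response,
       batch_size = batch_size, epochs = epochs)
}

# Same shape as mtcars_input_fn(): fixes features and response,
# and forwards everything else through ...
mtcars_describe <- function(data, ...) {
  describe_input(data,
                 features = c("drat", "cyl"),
                 response = "mpg",
                 ...)
}

# Options not named in the wrapper reach the inner function unchanged
opts <- mtcars_describe(mtcars, batch_size = 32, epochs = 5)
str(opts)
```

Any option omitted by the caller keeps the default of the inner function, which is why the wrapper does not need to list them itself.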

This helper function could then be used during training and evaluation as follows:

# train the model
model %>% train(mtcars_input_fn(train_data))

# evaluate the model
model %>% evaluate(mtcars_input_fn(test_data))
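The code above assumes that train_data and test_data already exist. One possible way to create them is a random split in base R, sketched below; the 80/20 ratio and the seed are arbitrary choices made for this illustration.

```r
# Reproducible random split of mtcars (ratio and seed are arbitrary)
set.seed(123)
n <- nrow(mtcars)
train_idx <- sample(seq_len(n), size = floor(0.8 * n))

train_data <- mtcars[train_idx, ]   # rows used for training
test_data  <- mtcars[-train_idx, ]  # remaining rows, held out for evaluation

nrow(train_data)
nrow(test_data)
```

Because the two subsets share the same columns, the same mtcars_input_fn() wrapper can be applied to either one.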

Matrix Input

As with data frames, you can pass an R matrix to input_fn() to automatically create an input function for the matrix. Note, however, that to specify the features and response parameters, the matrix columns must be named. For example:

m <- matrix(c(1:12), nrow = 4, ncol = 3)
colnames(m) <- c("x1", "x2", "y")
input_fn(m, features = c("x1", "x2"), response = "y")
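To see which values the names in the example refer to, the base R sketch below selects the same columns by name directly from the matrix. It only illustrates why column names are required and does not reproduce what input_fn() does internally.

```r
# Same matrix as above; R fills matrices column by column
m <- matrix(1:12, nrow = 4, ncol = 3)
colnames(m) <- c("x1", "x2", "y")

# Selecting by name only works because the columns are named
x <- m[, c("x1", "x2"), drop = FALSE]  # feature columns, kept as a matrix
y <- m[, "y"]                          # response column, as a vector

print(x)
print(y)
```

Without the colnames() assignment, both selections by name would fail with a subscript error.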

List Input

input_fn() also accepts nested lists, for example:

input_fn(
  object = list(
    inputs = list(
      list(list(1), list(2), list(3)),
      list(list(4), list(5), list(6))),
    output = list(
      list(1, 2, 3), list(4, 5, 6))),
  features = "inputs",
  response = "output"
)

In the above example, the data is a named list with two elements, each of which can be seen as a column in a dataset. Here, the column named inputs is used as the features for the model and the column named output is used as the response variable.
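The shape of this nested list can be inspected with base R, as in the sketch below. Reading the two top-level entries of inputs as two sequences of three steps each is an assumption made for this illustration; the source does not spell out the layout.

```r
# The same nested list that is passed as `object` in the example above
dat <- list(
  inputs = list(
    list(list(1), list(2), list(3)),
    list(list(4), list(5), list(6))),
  output = list(
    list(1, 2, 3), list(4, 5, 6)))

names(dat)           # the two "columns": inputs and output
length(dat$inputs)   # number of top-level entries in inputs
lengths(dat$inputs)  # number of elements inside each entry

# Flatten inputs into a matrix with one row per top-level entry
inputs_mat <- matrix(unlist(dat$inputs),
                     nrow = length(dat$inputs), byrow = TRUE)
print(inputs_mat)
```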

This nested-list format is particularly useful when constructing sequence input for Recurrent Neural Networks (RNNs). Once the data is defined using input_fn(), it can be used directly in the RNN constructor.