R Interface to TensorFlow Estimators

The tfestimators package is an R interface to TensorFlow Estimators, a high-level API that provides:

  • Implementations of many different model types, including linear models and deep neural networks. More models are coming soon, such as state-saving recurrent neural networks, dynamic recurrent neural networks, support vector machines, random forests, and k-means clustering.

  • A flexible framework for defining arbitrary new model types as custom estimators.

For more details on the architecture and design of TensorFlow Estimators, please see the white paper: TensorFlow Estimators: Managing Simplicity vs. Flexibility in High-Level Machine Learning Frameworks.

Quick Start


To use tfestimators, you need to install both the R package and TensorFlow itself.

First, install the tfestimators R package as follows:
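One way to do this is shown below. It assumes you want the released version from CRAN. The GitHub repository name in the comment is also an assumption, so check the package documentation for current instructions.

```r
# Assumed: the released package is available from CRAN. `repos` is set
# explicitly so the call also works in a non-interactive session.
install.packages("tfestimators", repos = "https://cloud.r-project.org")

# Alternatively (also an assumption), install from the GitHub repository:
# remotes::install_github("rstudio/tfestimators")
```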


Then, use the install_tensorflow() function to install TensorFlow. Note that the tfestimators package requires TensorFlow 1.3 or higher, so even if you already have TensorFlow installed, you should update it if you are running an earlier version:
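A minimal sketch of this step, using default installation options. It calls install_tensorflow() and tf_version() from the tensorflow R package, which tfestimators builds on.

```r
library(tensorflow)

# Install TensorFlow with the default options
install_tensorflow()

# After installation (and, if prompted, restarting R), check that the
# installed version meets the 1.3+ requirement
tf_version()
```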


This will provide you with a default installation of TensorFlow suitable for getting started. See the article on installation to learn about more advanced options, including installing a version of TensorFlow that takes advantage of NVIDIA GPUs if you have the correct CUDA libraries installed.

Simple Example

Let’s create a simple linear regression model with the mtcars dataset to demonstrate the use of estimators. We’ll illustrate how ‘input functions’ can be constructed and used to feed data to an estimator, how ‘feature columns’ can be used to specify a set of transformations to apply to input data, and how these pieces come together in the Estimator interface.

Input Function

Estimators can accept data from arbitrary data sources through an ‘input function’. The tfestimators package provides the input_fn() helper function for generating input functions from common R data structures, e.g. R matrices and data frames.

Here, we define a helper function that will return an input function for a subset of our mtcars data set.


library(tfestimators)

# return an input_fn for a given subset of data
mtcars_input_fn <- function(data) {
  input_fn(data,
           features = c("disp", "cyl"),
           response = "mpg")
}
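As written, mtcars_input_fn() relies on input_fn()'s default batch size and number of epochs. The sketch below is a variant that exposes these settings. It assumes input_fn() accepts batch_size and num_epochs arguments for data frames (check ?input_fn for your installed version). The values 32 and 10 are illustrative, not recommendations.

```r
library(tfestimators)

# Variant of mtcars_input_fn() that controls the batch size and the number
# of passes (epochs) over the data; num_epochs defaults to a single pass
mtcars_input_fn <- function(data, num_epochs = 1) {
  input_fn(data,
           features = c("disp", "cyl"),
           response = "mpg",
           batch_size = 32,
           num_epochs = num_epochs)
}

# For example, an input function that cycles through the data 10 times
train_fn <- mtcars_input_fn(mtcars, num_epochs = 10)
```

Because num_epochs defaults to 1, existing calls such as mtcars_input_fn(train) keep working with this variant.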

Feature Columns

Next, we define the feature columns for our model. Feature columns are mappings of raw input data to the data that we’ll actually feed into our training, evaluation, and prediction steps. Here, we create a list of feature columns containing the disp and cyl variables:

cols <- feature_columns(
  column_numeric("disp"),
  column_numeric("cyl")
)

You can also define multiple feature columns at once:

cols <- feature_columns(
  column_numeric("disp", "cyl")
)

Using the family of feature column functions, we can define various transformations to apply to the data before it is used for modeling.
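For example, a numeric column can be discretized into ranges with column_bucketized(). This sketch assumes that function's source_column and boundaries arguments. The boundaries 5 and 7 are chosen here only to separate the 4-, 6- and 8-cylinder cars in mtcars.

```r
library(tfestimators)

# Keep disp as a raw numeric column, but split cyl into three buckets:
# (-Inf, 5), [5, 7) and [7, Inf)
cols <- feature_columns(
  column_numeric("disp"),
  column_bucketized(column_numeric("cyl"), boundaries = c(5, 7))
)
```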

Estimator

Next, we create the estimator by calling the linear_regressor() function and passing it a set of feature columns:

model <- linear_regressor(feature_columns = cols)
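Other canned estimators accept the same feature columns. As an illustration, a deep neural network regressor could be swapped in with dnn_regressor(). The hidden_units value below is an arbitrary assumption, not a tuned setting.

```r
library(tfestimators)

# Same feature columns as above (repeated so this block runs on its own)
cols <- feature_columns(
  column_numeric("disp", "cyl")
)

# A deep neural network regressor with two hidden layers of 10 units each
dnn_model <- dnn_regressor(hidden_units = c(10, 10), feature_columns = cols)
```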

Training

We’re now ready to train our model using the train() function. We’ll partition the mtcars data set into separate training and test sets, holding 20% of the data aside for validation, and feed the training set into train().

indices <- sample(1:nrow(mtcars), size = 0.80 * nrow(mtcars))
train <- mtcars[indices, ]
test  <- mtcars[-indices, ]

# train the model
model %>% train(mtcars_input_fn(train))
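Because sample() draws a different subset each time, the split above (and therefore the trained model and its metrics) changes between runs. A reproducible version of the same 80/20 split is sketched below. The seed value 42 is an arbitrary assumption.

```r
# Fix the random seed so the split is identical on every run
set.seed(42)

n <- nrow(mtcars)
# floor() makes the training-set size an explicit whole number (25 of 32 rows)
indices <- sample(seq_len(n), size = floor(0.80 * n))
train <- mtcars[indices, ]
test  <- mtcars[-indices, ]
```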

Evaluation

We can evaluate the model’s performance on the held-out ‘test’ data set using the evaluate() function:

model %>% evaluate(mtcars_input_fn(test))
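For a point of reference, the same linear relationship (mpg as a function of disp and cyl) can be fit with base R's lm() and scored on a held-out set. This is only a sketch. The seed and the choice of mean squared error are assumptions. Assuming the estimator uses its default squared-error loss, this number can be compared loosely with the average loss reported by evaluate(). The TensorFlow model is trained with a different optimizer for a limited number of steps, so the two need not match.

```r
# 80/20 split, seeded so the result is reproducible
# (repeated here so this block runs on its own)
set.seed(42)
indices <- sample(seq_len(nrow(mtcars)), size = floor(0.80 * nrow(mtcars)))
train <- mtcars[indices, ]
test  <- mtcars[-indices, ]

# Ordinary least squares on the same features and response
baseline <- lm(mpg ~ disp + cyl, data = train)

# Mean squared error and root mean squared error on the held-out rows
pred <- predict(baseline, newdata = test)
mse  <- mean((test$mpg - pred)^2)
rmse <- sqrt(mse)
```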

Prediction

After we’ve finished training our model, we can use it to generate predictions. For illustration, we predict on the first three rows of mtcars:

obs <- mtcars[1:3, ]
model %>% predict(mtcars_input_fn(obs))
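The rows passed to predict() above come from mtcars itself, so they already contain mpg. For genuinely new observations, which have no response value, the sketch below builds an input function without the response argument. This assumes input_fn() allows response to be omitted at prediction time. The new_cars values are hypothetical, and the block retrains a small model so it runs on its own.

```r
library(tfestimators)

# Train a small model so this block runs on its own (mirrors the steps above)
cols <- feature_columns(column_numeric("disp", "cyl"))
model <- linear_regressor(feature_columns = cols)
model %>% train(input_fn(mtcars, features = c("disp", "cyl"), response = "mpg"))

# Hypothetical new cars: only the feature columns are supplied, no mpg
new_cars <- data.frame(disp = c(120, 250, 350), cyl = c(4, 6, 8))

# Omit `response`, since there is no label to feed at prediction time
preds <- model %>% predict(input_fn(new_cars, features = c("disp", "cyl")))
```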

Learning More

These articles cover the basics of using TensorFlow Estimators:

These articles describe more advanced topics/usage:

One of the best ways to learn is by reviewing and experimenting with examples. See the Examples page for a variety of examples to help you get started.