Getting Started with tfautograph

Learn how you can use R control flow expressions like if, while and for with TensorFlow, in both eager and graph modes.

The R function tfautograph::autograph() helps you write TensorFlow code using R control flow. It allows you to use tensors in R control flow expressions like if, while, and for, which can automatically be translated to build a TensorFlow graph (hence the name, auto graph).

This guide walks through the main features and then briefly explains how it works.

Before we get started, a clarification: The R package {tfautograph} is inspired by the functionality (and name) of the tf.autograph submodule in the Python interface to TensorFlow, but it has little relation to that code base. If you’re interested in writing TensorFlow code in R using native R syntax, then keep reading because you’re in the right place. If you’re looking for an interface to the tf.autograph Python submodule from R, this is probably not what you’re looking for.

Setup

library(magrittr)
library(tensorflow)
library(tfdatasets)

library(tfautograph)

Compatibility

tfautograph works with TensorFlow version 1.15 and versions >= 2.0. This guide is rendered using version 2.11.

tf_config()
## TensorFlow v2.11.0 (~/.virtualenvs/r-tensorflow-website/lib/python3.10/site-packages/tensorflow)
## Python v3.10 (~/.virtualenvs/r-tensorflow-website/bin/python)

Usage

The primary workhorse that enables the use of control flow in tf_function() is tfautograph::autograph(). For most use cases you will not need to call autograph() directly; it is automatically invoked when calling tf_function(fn, autograph = TRUE) (the default). However, you can also invoke it directly. It can take either a function or an expression. The following two uses are equivalent:

# pass a function to autograph()
fn <- function(x) if(x > 0) x * x else x
square_if_positive <- autograph(fn)

# pass an expression to autograph()
square_if_positive <- function(x) autograph(if(x > 0) x * x else x)

Now square_if_positive() is a function that can accept a tensor as an argument.

x <- tf$convert_to_tensor(5)
y <- tf$convert_to_tensor(-5)
square_if_positive(x)
## tf.Tensor(25.0, shape=(), dtype=float32)
square_if_positive(y)
## tf.Tensor(-5.0, shape=(), dtype=float32)

Note that if you’re in a context where TensorFlow is executing eagerly, autograph() doesn’t change that: square_if_positive() still executes eagerly. You can test this by inserting some R message() calls to see when a branch is evaluated.

square_if_positive_verbose <- autograph(function(x) {
  if (x > 0) {
    message("Evaluating true branch")
    x * x
  } else {
    message("Evaluating false branch")
    x
  }
})

square_if_positive_verbose(x)
## Evaluating true branch
## tf.Tensor(25.0, shape=(), dtype=float32)
square_if_positive_verbose(x)
## Evaluating true branch
## tf.Tensor(25.0, shape=(), dtype=float32)
square_if_positive_verbose(x)
## Evaluating true branch
## tf.Tensor(25.0, shape=(), dtype=float32)

square_if_positive_verbose(y)
## Evaluating false branch
## tf.Tensor(-5.0, shape=(), dtype=float32)
square_if_positive_verbose(y)
## Evaluating false branch
## tf.Tensor(-5.0, shape=(), dtype=float32)
square_if_positive_verbose(y)
## Evaluating false branch
## tf.Tensor(-5.0, shape=(), dtype=float32)

As you can see, we’re in eager mode: the R code of the function body is evaluated every time the function is called.

The easiest way to enter a context where TensorFlow is no longer executing eagerly, and instead builds a graph, is to call tf_function().

graph_fn <- tf_function(square_if_positive_verbose)

graph_fn(x)
## Evaluating true branch
## Evaluating false branch
## tf.Tensor(25.0, shape=(), dtype=float32)
graph_fn(x)
## tf.Tensor(25.0, shape=(), dtype=float32)
graph_fn(x)
## tf.Tensor(25.0, shape=(), dtype=float32)

graph_fn(y)
## tf.Tensor(-5.0, shape=(), dtype=float32)
graph_fn(y)
## tf.Tensor(-5.0, shape=(), dtype=float32)
graph_fn(y)
## tf.Tensor(-5.0, shape=(), dtype=float32)

In graph mode, both branches of the if expression are traced into a TensorFlow graph the first time the function is called, and the resulting graph is cached by the TensorFlow Function object returned from tf_function(). On subsequent calls of the Function, only the cached graph is evaluated.

The key takeaways are that autograph() helps you write natural R code and use tensors in expressions where R wouldn’t otherwise accept them. And that autograph() is smart enough to do the right thing in both eager mode and graph mode.

Control Flow

tfautograph can translate the R control flow statements if, while, for, break, and next to TensorFlow. Here is a summary table of the translation endpoints (note that these summary code snippets are meant to concisely convey the spirit of the translation, not the actual implementation):

R expression          Graph Mode Translation    Eager Mode Translation
if(x)                 tf$cond(x, ...)           if(x$`__bool__`())
while(x)              tf$while_loop(...)        while(as.logical(x))
for(x in tensor)      tf$while_loop(...)        while(!is.null(x <- iter_next(tensor_iterator)))
for(x in tfdataset)   dataset_reduce()          while(!is.null(x <- iter_next(dataset_iterator)))

Let’s go through them one at a time.

if

In eager mode, if(eager_tensor) is translated to if(eager_tensor$`__bool__`()). (This is equivalent to calling reticulate::py_bool() on the tensor, and is in essence a slightly more robust way to call as.logical(eager_tensor).)
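
To make that concrete, here is a purely illustrative sketch of what the eager translation amounts to, written with as.logical() for readability (this is not the package’s actual implementation):

z <- tf$convert_to_tensor(3)
# roughly what autograph(if (z > 0) "positive" else "non-positive") does eagerly:
if (as.logical(z > 0)) "positive" else "non-positive"  # evaluates to "positive"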

In graph mode, if statements written in R automatically get translated to a tf.cond(). tf.cond() requires that both branches of the conditional are balanced (meaning, both branches return the same output structure). autograph() tries to capture all locally modified variables, newly created variables, as well as the return value of the overall expression in the translated tf.cond(), while satisfying the requirement for balanced branches.

tf_sign <- tf_function(function(x) {
  if (x > 0)
    1
  else if (x < 0)
    -1
  else
    0
})
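
As a quick check, calling the traced function behaves as you’d expect (the exact dtype of the result follows TensorFlow’s conversion defaults for the R literals above):

tf_sign(tf$constant(3))   # expect a tensor equal to 1
tf_sign(tf$constant(-2))  # expect a tensor equal to -1
tf_sign(tf$constant(0))   # expect a tensor equal to 0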

Any variables that can’t be balanced between the branches are exported as undefined objects (S3 class: undef). Undefined objects throw an informative error if you attempt to access them. The error message indicates which expression the undef originated from and suggests how to prevent the symbol from being an undef.

undef_example <- tf_function(function(x) {
  if (x > 0) {
    branch_local_tmp <- x + 1
    x <- branch_local_tmp + 1
    x
  }
  branch_local_tmp
})
undef_example(tf$constant(1))
## Error in py_call_impl(callable, dots$args, dots$keywords): RuntimeError: Evaluation error: Symbol `branch_local_tmp` is *undefined* after autographing the expression:
##      if (x > 0) {
##      branch_local_tmp <- x + 1
##      x <- branch_local_tmp + 1
##      x
##  }
## To access this symbol, Tensorflow requires that an object with the same dtype and shape be assigned to the symbol either before the `if` statement, or in all branches of the `if` statement.
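
Following that suggestion, one way to avoid the undef is to assign branch_local_tmp an object of the same dtype and shape before the if statement, so both branches of the generated tf$cond() are balanced (a sketch; the function name here is just for illustration):

defined_example <- tf_function(function(x) {
  branch_local_tmp <- tf$zeros_like(x)  # defined before the `if`, so both branches agree
  if (x > 0) {
    branch_local_tmp <- x + 1
    x <- branch_local_tmp + 1
  }
  branch_local_tmp
})
defined_example(tf$constant(1))  # expect a tensor equal to 2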

while

In eager mode, while(eager_tensor) is translated to while(as.logical(eager_tensor)). The tensorflow R package provides tensor methods for many S3 generics, including as.logical(), which coerces an EagerTensor to an R logical, so this works as you would expect.

Here is an example of an autographed while expression being evaluated eagerly. Remember, autograph() is not just for functions!

total <- 1
x
## tf.Tensor(5.0, shape=(), dtype=float32)
autograph({
  while (x != 0) {
    message("Evaluating while body R expression")
    total %<>% multiply_by(x)
    x %<>% subtract(tf_sign(x))
  }
})
## Evaluating while body R expression
## Evaluating while body R expression
## Evaluating while body R expression
## Evaluating while body R expression
## Evaluating while body R expression
x
## tf.Tensor(0.0, shape=(), dtype=float32)
total
## tf.Tensor(120.0, shape=(), dtype=float32)

In graph mode, while expressions are translated to a tf$while_loop() call.

tf_factorial <- tf_function(autograph(function(x) {
  total <- 1
  while (x != 0) {
    message("Evaluating while body R expression")
    total %<>% multiply_by(x)
    x %<>% subtract(tf_sign(x))
  }
  total
}))

tf_factorial(tf$constant(5))
## Evaluating while body R expression
## tf.Tensor(120.0, shape=(), dtype=float32)
tf_factorial(tf$constant(-5))
## tf.Tensor(-120.0, shape=(), dtype=float32)

tf$while_loop() has many options. To pass them through to the call, precede the while expression with ag_while_opts().

tf_factorial_v2 <- tf_function(function(x) {
  total <- as_tensor(1, dtype = x$dtype)
  ag_while_opts(
    shape_invariants = list(total = shape(), x = shape()),
    parallel_iterations = 1
  )
  while (x != 0) {
    message("Evaluating while body R expression")
    total %<>% multiply_by(x)
    x %<>% subtract(tf_sign(x))
  }
  total
})
tf_factorial_v2(tf$constant(5))
## Evaluating while body R expression
## tf.Tensor(120.0, shape=(), dtype=float32)

for

Autographed for loops build on top of while loops. autograph() adds support for three new types of values passed to for:

  • TF Datasets
  • Tensors
  • Python iterators (eager mode only).

In eager mode, both Datasets and Tensors are coerced to iterators (via reticulate::as_iterator()), which are then iterated over until exhausted. Essentially, a call like

for(elem in iterable) {...}

gets translated to

iterator <- as_iterator(iterable)
while (!is.null(elem <- iter_next(iterator))) { ... }

Note that Tensors are iterated along their first dimension.

m <- tf$convert_to_tensor(matrix(1:12, nrow = 3, byrow = TRUE))
m
## tf.Tensor(
## [[ 1  2  3  4]
##  [ 5  6  7  8]
##  [ 9 10 11 12]], shape=(3, 4), dtype=int32)
autograph({
  for (row in m)
    print(row)
})
## tf.Tensor([1 2 3 4], shape=(4), dtype=int32)
## tf.Tensor([5 6 7 8], shape=(4), dtype=int32)
## tf.Tensor([ 9 10 11 12], shape=(4), dtype=int32)

In graph mode, for can accept a Tensor or a Dataset.

naive_reduce_sum <- tf_function(function(x, dtype = "int64") {
  running_total <- tf$zeros(list(), dtype)
  for (elem in x)
    running_total %<>% add(elem)

  running_total
})

Works with a Tensor:

naive_reduce_sum(tf$range(10L, dtype = "int64"))
## tf.Tensor(45, shape=(), dtype=int64)

and with a Dataset:

naive_reduce_sum(tf$data$Dataset$range(10L))
## tf.Tensor(45, shape=(), dtype=int64)

Since for(var in tensor) loops are powered by tf$while_loop(), you can pass additional options via ag_while_opts() just as you would to an autographed while() expression.

naive_reduce_sum_with_opts <- tf_function(autograph(function(x) {
  running_total <- tf$zeros(list(), x$dtype)

  ag_while_opts(parallel_iterations = 1)
  for (elem in x)
    running_total %<>% add(elem)

  running_total
}))

naive_reduce_sum_with_opts(tf$range(10))
## tf.Tensor(45.0, shape=(), dtype=float32)

break / next

Loop control flow statements break and next are handled automatically by autograph(), in both eager mode and graph mode. Use break and/or next anywhere you would naturally use them in while and for loops.
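
For example, here is a small autographed while loop that exits early with break (a minimal sketch evaluated eagerly; the same pattern carries over to graph mode):

i <- tf$constant(0L)
autograph({
  while (i < 10L) {
    i <- i + 1L
    if (i >= 3L)
      break
  }
})
i  # expect a tensor equal to 3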

FizzBuzz!

Let’s tie some concepts together and write fizzbuzz! Before we do that, we’ll write a helper tf_print() that writes to a temporary file by default. This helps knitr capture the output in this quarto document. (If we didn’t redirect output to a file, it would show up in the rendering console rather than in this vignette.)

TEMPFILE <- tempfile("tf-print-out", fileext = ".txt")

print_tempfile <-
  function(clear_after_read = TRUE, rewrap_lines = TRUE) {
    output <- readLines(TEMPFILE, warn = FALSE)
    if (clear_after_read) unlink(TEMPFILE)
    if (rewrap_lines) output <- strwrap(paste0(output, collapse = " "))
    writeLines(output)
  }

tf_print <- function(...)
  tf$print(..., output_stream = sprintf("file://%s", TEMPFILE))

fizzbuzz <- autograph(function(n) {
  for (i in range_dataset(from = 1L, to = n)) {
    if (i %% 15L == 0L)
      tf_print("FizzBuzz")
    else if (i %% 3L == 0L)
      tf_print("Fizz")
    else if (i %% 5L == 0L)
     tf_print("Buzz")
    else
      tf_print(i)
  }
})

First, let’s run it in eager mode.

fizzbuzz(tf$constant(25L))
print_tempfile()
## 1 2 Fizz 4 Buzz Fizz 7 8 Fizz Buzz 11 Fizz 13 14 FizzBuzz 16 17
## Fizz 19 Buzz Fizz 22 23 Fizz

And now in graph mode.

tf_fizzbuzz <- tf_function(fizzbuzz)
tf_fizzbuzz(tf$constant(25L))
print_tempfile()
## 1 2 Fizz 4 Buzz Fizz 7 8 Fizz Buzz 11 Fizz 13 14 FizzBuzz 16 17
## Fizz 19 Buzz Fizz 22 23 Fizz

Visualize Function Graphs

As you write tf_function()s, it’s helpful to visualize the graph produced by a particular autographed function. Use tfautograph::view_function_graph() to launch a TensorBoard instance showing the function graph, written out to a temporary directory. Note that view_function_graph() only succeeds when the function is being traced for the first time.

view_function_graph(tf_function(fizzbuzz), list(tf$constant(25L)))

(Figure: FizzBuzz graph in TensorBoard)

Control Dependencies

Side-effecting ops like tf$print() and tf$Assert() created within a tf_function() are executed in-line and in the correct order when the function is evaluated. (If you’re used to working with TensorFlow version 1, you read that right!)

TensorFlow 2 has drastically changed (for the better!) how control dependencies work. For the most part, as long as you only enter graph mode from within a tf_function(), there is no longer any need to enter a control dependency context. The days of wrapping code in with(tf$control_dependencies(...), ...) are mostly over.
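
As a small illustration, side-effecting ops written in straight-line R code inside a tf_function() run in the order they appear, with no explicit dependency management (this sketch reuses the tf_print() helper defined above; the function name is just for illustration):

log_and_double <- tf_function(function(x) {
  tf_print("input:", x)
  y <- x * 2
  tf_print("output:", y)
  y
})
log_and_double(tf$constant(3))
print_tempfile()  # the two messages should appear in source order, e.g. "input: 3 output: 6"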

Growing Objects / TensorArrays

The package provides a [[<- method for TensorArrays. That’s the recommended way to grow objects on the graph. See ?`[[<-.tensorflow.python.ops.tensor_array_ops.TensorArray` for usage examples.
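
For orientation, here is a rough eager-mode sketch of the underlying TensorArray workflow, written against TensorFlow’s own API (ta$write() and ta$stack()) rather than the package’s [[<- sugar; note that TensorArray indices are 0-based:

ta <- tf$TensorArray(tf$float32, size = 5L)
autograph({
  for (i in tf$range(5L)) {
    ta <- ta$write(i, tf$cast(i, tf$float32)^2)  # write one element per iteration
  }
})
ta$stack()  # expect a tensor containing 0, 1, 4, 9, 16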

Other helpers

The tfautograph package provides a small collection of additional helpers when working with tensorflow from R. They are:

  • tf_assert(): A thin wrapper around tf.Assert() that automatically generates a useful error message that includes the R call stack and the values of the tensors involved in the assert expression.
  • tf_cond(), tf_case(), tf_switch(): Thin wrappers around the control flow primitives that accept compact (rlang style ~) lambda syntax for the callables.

How it works

autograph() works by evaluating expressions in an environment where primitives like if and for are masked by autographing versions of them. The complete list of symbols masked by autograph() is:

## [1] "if"        "while"     "for"       "break"     "next"      "stopifnot"
## [7] "on.exit"
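
To get a feel for how such masking can work, here is a toy illustration in plain base R (only a sketch of the general idea, not tfautograph’s actual implementation):

# an environment whose `if` logs a message before doing what base `if` does
mask <- new.env(parent = globalenv())
mask$`if` <- function(cond, true, false = NULL) {
  message("masked `if` called")
  if (isTRUE(cond)) true else false
}
mask$x <- 5

eval(quote(if (x > 0) "positive" else "non-positive"), mask)
# message: masked `if` called
# value:   "positive"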