Overview

    This is a short introduction to some advanced Keras features. It uses:

    1. tfdatasets to manage input data.
    2. A custom model.
    3. tfautograph for building a custom training loop.

    Before running this quickstart, you need to have Keras installed. Please refer to the installation guide for instructions.

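    library(keras)        # models, layers, losses and optimizers
    library(tfdatasets)   # input pipelines
    library(tfautograph)  # autograph() for the training loop
    library(reticulate)   # Python helpers such as %as%
    library(purrr)        # list helpers such as modify_at()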

    Let’s start by loading and preparing the MNIST dataset. The pixel values are integers between 0 and 255; we convert them to floats between 0 and 1.

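    mnist <- dataset_mnist()
    # Rescale pixel values from [0, 255] to [0, 1]
    mnist$train$x <- mnist$train$x/255
    mnist$test$x <- mnist$test$x/255

    # Add a trailing channel axis: (n, 28, 28) -> (n, 28, 28, 1)
    dim(mnist$train$x) <- c(dim(mnist$train$x), 1)
    dim(mnist$test$x) <- c(dim(mnist$test$x), 1)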

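    Now let’s use tfdatasets to build the input pipelines. We take only the first 20,000 training and 2,000 test examples (out of the 60,000 and 10,000 available). We cast the images to float32 and the labels to int64, shuffle the training set with a buffer of 10,000 elements, and group both sets into batches of 32.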

    train_ds <- mnist$train %>% 
      tensor_slices_dataset() %>%
      dataset_take(20000) %>% 
      dataset_map(~modify_at(.x, "x", tf$cast, dtype = tf$float32)) %>% 
      dataset_map(~modify_at(.x, "y", tf$cast, dtype = tf$int64)) %>% 
      dataset_shuffle(10000) %>% 
      dataset_batch(32)
    
    test_ds <- mnist$test %>% 
      tensor_slices_dataset() %>% 
      dataset_take(2000) %>% 
      dataset_map(~modify_at(.x, "x", tf$cast, dtype = tf$float32)) %>%
      dataset_map(~modify_at(.x, "y", tf$cast, dtype = tf$int64)) %>% 
      dataset_batch(32)

    Next, we define a custom Keras model.
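    Here is a minimal sketch of such a model. It assumes the keras 2.x R package, where `keras_model_custom()` is available. The layers are created once and stored on the model, and the returned function defines the forward pass. The constructor name `simple_conv_nn()` and the layer sizes are illustrative assumptions.

```r
library(keras)
library(tensorflow)

# Hypothetical constructor: the name and layer sizes are illustrative.
simple_conv_nn <- function(filters = 32, kernel_size = 3) {
  keras_model_custom(name = "MyModel", function(self) {
    # Layers are created once and stored on the model object
    self$conv1 <- layer_conv_2d(
      filters = filters,
      kernel_size = c(kernel_size, kernel_size),
      activation = "relu"
    )
    self$flatten <- layer_flatten()
    self$d1 <- layer_dense(units = 128, activation = "relu")
    self$d2 <- layer_dense(units = 10, activation = "softmax")  # one unit per digit

    # The returned function is the forward pass
    function(inputs, mask = NULL) {
      inputs %>%
        self$conv1() %>%
        self$flatten() %>%
        self$d1() %>%
        self$d2()
    }
  })
}

model <- simple_conv_nn()
```

    Calling `model()` on a batch of shape (batch, 28, 28, 1) returns a (batch, 10) matrix of class probabilities.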

    We can then choose an optimizer and loss function for training:
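    Here is a minimal sketch of these two objects. The labels are integer class ids and the model outputs probabilities, so sparse categorical cross-entropy fits. Three choices are assumptions: the `SparseCategoricalCrossentropy` class (which averages over the batch and returns a scalar), Adam with its default settings, and the name `loss_fn`.

```r
library(keras)
library(tensorflow)

# Loss for integer labels vs. predicted probabilities (from_logits = FALSE
# by default); calling it returns the mean loss over the batch.
loss_fn <- tf$keras$losses$SparseCategoricalCrossentropy()

# Adam with its default hyperparameters
optimizer <- optimizer_adam()
```

    For example, if every sample assigns probability 0.5 to its true class, the loss is -log(0.5), about 0.693.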

    Next, select metrics to track the loss and the accuracy of the model. These metric objects accumulate values across batches and epochs until they are reset, so each printed result is an overall figure rather than a per-epoch one.
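    Here is a minimal sketch using the Keras metric classes through `tf$keras$metrics`; the variable names are assumptions. It creates a running mean for the loss and a running accuracy for integer labels, with one pair for training and one for testing. Each call adds a batch to the running totals.

```r
library(tensorflow)

# Running mean of the per-batch loss values
train_loss <- tf$keras$metrics$Mean(name = "train_loss")
# Running fraction of samples whose argmax prediction equals the integer label
train_accuracy <- tf$keras$metrics$SparseCategoricalAccuracy(name = "train_accuracy")

test_loss <- tf$keras$metrics$Mean(name = "test_loss")
test_accuracy <- tf$keras$metrics$SparseCategoricalAccuracy(name = "test_accuracy")
```

    For instance, calling `train_loss(1)` and then `train_loss(3)` makes `train_loss$result()` return 2.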

    We then define a function that performs one training step:
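    Here is a minimal sketch of a training step. It runs the forward pass under `tf$GradientTape()`, computes the gradients of the loss with respect to the trainable weights, applies them with the optimizer, and updates the training metrics. `make_train_step()` is a hypothetical helper that closes over the model, loss, optimizer and metrics instead of reading globals, so it can be checked in isolation. `%as%` comes from reticulate, and `purrr::transpose()` pairs each gradient with its variable.

```r
library(keras)
library(tensorflow)
library(reticulate)  # provides %as% for with()

# Hypothetical helper: builds train_step(images, labels) from the objects it
# needs, rather than relying on global variables.
make_train_step <- function(model, loss_fn, optimizer, loss_metric, acc_metric) {
  function(images, labels) {
    # Record the forward pass so its gradients can be computed
    with(tf$GradientTape() %as% tape, {
      predictions <- model(images)
      loss_value <- loss_fn(labels, predictions)
    })
    # Gradient of the loss with respect to each trainable weight
    gradients <- tape$gradient(loss_value, model$trainable_variables)
    # apply_gradients() takes (gradient, variable) pairs; transpose() builds them
    optimizer$apply_gradients(purrr::transpose(list(gradients, model$trainable_variables)))
    # Add this batch to the running training metrics
    loss_metric(loss_value)
    acc_metric(labels, predictions)
    invisible(loss_value)
  }
}
```

    Pass the model, loss, optimizer and training metrics to get a `train_step(images, labels)` function.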

    Next, we define a function that evaluates the model on a batch of test data:
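    Here is a minimal sketch of an evaluation step. It performs the forward pass and the metric updates, with no gradient tape and no optimizer, so the weights are left unchanged. `make_test_step()` is a hypothetical helper that closes over its inputs instead of using globals.

```r
library(keras)       # supplies the model passed in
library(tensorflow)  # supplies the loss and metric objects passed in

# Hypothetical helper: builds test_step(images, labels). There is no gradient
# tape and no optimizer, so evaluation never changes the model's weights.
make_test_step <- function(model, loss_fn, loss_metric, acc_metric) {
  function(images, labels) {
    predictions <- model(images)
    loss_value <- loss_fn(labels, predictions)
    # Add this batch to the running test metrics
    loss_metric(loss_value)
    acc_metric(labels, predictions)
    invisible(loss_value)
  }
}
```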

    We can then write our training loop function:
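    Here is a minimal sketch of one epoch as a compiled function. `autograph()` from tfautograph rewrites the R `for` loops over the datasets into TensorFlow control flow, and `tf_function()` compiles the whole epoch into a graph. `make_training_loop()` is a hypothetical helper. It assumes dataset elements are named lists with `x` and `y`, as produced by the `tensor_slices_dataset()` pipelines above. Results are printed with `tf$print()` because R side effects such as `cat()` run only while the function is being traced. Nothing here resets the metrics, so they keep accumulating across epochs.

```r
library(tensorflow)
library(tfautograph)

# Hypothetical helper: train_step/test_step take (images, labels);
# train_acc/test_acc are the accuracy metrics those steps update.
make_training_loop <- function(train_step, test_step, train_acc, test_acc) {
  tf_function(autograph(function(train_ds, test_ds) {
    # One pass over the training batches
    for (train_batch in train_ds) {
      train_step(train_batch$x, train_batch$y)
    }
    # One pass over the test batches
    for (test_batch in test_ds) {
      test_step(test_batch$x, test_batch$y)
    }
    # Printed from inside the graph on every call
    tf$print("Acc", train_acc$result(), "Test Acc", test_acc$result())
    invisible(NULL)
  }))
}
```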

    Finally, let’s run our training loop for 5 epochs:
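    Here is a minimal sketch of the outer loop. It calls the compiled epoch function once per epoch and prints a header in the same format as the log below. `run_epochs()` is a hypothetical helper: pass the training loop function as `epoch_fn`, together with the `train_ds` and `test_ds` pipelines defined earlier. The log that follows is the document's recorded output, not output produced by these sketches.

```r
# Hypothetical helper: run `epoch_fn(train_ds, test_ds)` once per epoch,
# printing a header such as "Epoch:  1  -----------" before each call.
run_epochs <- function(epoch_fn, train_ds, test_ds, epochs = 5) {
  for (epoch in seq_len(epochs)) {
    cat("Epoch: ", epoch, " -----------\n")
    epoch_fn(train_ds, test_ds)
  }
  invisible(epochs)
}
```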

    ## Epoch:  1  -----------
    Acc 0.93095 Test Acc 0.954
    ## Epoch:  2  -----------
    Acc 0.956525 Test Acc 0.95825
    ## Epoch:  3  -----------
    Acc 0.968066692 Test Acc 0.9575
    ## Epoch:  4  -----------
    Acc 0.9752 Test Acc 0.960125
    ## Epoch:  5  -----------
    Acc 0.9796 Test Acc 0.9617