TensorFlow Hub with Keras

    TensorFlow Hub is a way to share pretrained model components. See the TensorFlow Module Hub for a searchable listing of pretrained models. This tutorial demonstrates:

    1. How to use TensorFlow Hub with Keras.
    2. How to do image classification using TensorFlow Hub.
    3. How to do simple transfer learning.

    Setup

    library(keras)
    library(tfhub)
    library(magick)
    #> Linking to ImageMagick 6.9.9.39
    #> Enabled features: cairo, fontconfig, freetype, lcms, pango, rsvg, webp
    #> Disabled features: fftw, ghostscript, x11

    An ImageNet classifier

    Download the classifier

    Use layer_hub() to load a MobileNet classifier and wrap it as a Keras layer. Any TensorFlow 2 compatible image classifier URL from tfhub.dev will work here.

    classifier_url <- "https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/2" 
    mobilenet_layer <- layer_hub(handle = classifier_url)
    #> 
    #> Done!

    We can then create our Keras model:

    input <- layer_input(shape = c(224, 224, 3))
    output <- input %>% 
      mobilenet_layer()
    
    model <- keras_model(input, output)

    Run it on a single image

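    Download a single image to try the model on. The classifier returns 1001 scores per image, the first of which belongs to a "background" class. We drop that column so that mobilenet_decode_predictions() receives the 1000 ImageNet classes it expects.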

    img <- image_read('https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg') %>%
      image_resize(geometry = "224x224!") %>% # exactly 224 x 224, ignoring aspect ratio
      image_data() %>% 
      as.numeric() %>% 
      abind::abind(along = 0) # expand to batch dimension

    result <- predict(model, img)
    mobilenet_decode_predictions(result[,-1, drop = FALSE])
    #> [[1]]
    #>   class_name class_description    score
    #> 1  n03763968  military_uniform 9.355025
    #> 2  n03787032       mortarboard 5.400680
    #> 3  n02817516          bearskin 5.297816
    #> 4  n04350905              suit 5.200010
    #> 5  n09835506        ballplayer 4.792098
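    The scores above are unnormalized logits rather than probabilities (note that they exceed 1). If probabilities are wanted, a row-wise softmax can be applied before decoding. The sketch below uses softmax_rows(), a hypothetical base-R helper that is not part of keras or tfhub, demonstrated on a toy matrix.

```r
# Hypothetical helper (not part of keras or tfhub): row-wise softmax that
# turns a matrix of logits (one row per image) into probabilities.
softmax_rows <- function(logits) {
  # Subtract each row's maximum first so exp() cannot overflow;
  # the result is mathematically unchanged.
  shifted <- logits - apply(logits, 1, max)
  e <- exp(shifted)
  e / rowSums(e)
}

# Toy example: two images, three classes
logits <- rbind(c(1, 2, 3),
                c(0, 0, 0))
probs <- softmax_rows(logits)
rowSums(probs)  # each row sums to 1
```

    Because softmax preserves the ordering of scores, mobilenet_decode_predictions(softmax_rows(result[, -1, drop = FALSE])) should list the same top classes as above, with scores between 0 and 1.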

    Simple transfer learning

    With TF Hub, it is simple to retrain just the top layer of a model to recognize the classes in our own dataset.

    Dataset

    For this example you will use the TensorFlow flowers dataset:

    data_root <- pins::pin("https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz", "flower_photos")
    data_root <- fs::path_dir(fs::path_dir(data_root[100])) # go up two levels, to the folder holding one subfolder per class
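    flow_images_from_directory() treats each subdirectory of data_root as one class. A quick base-R count of the images in each class folder can catch download or path problems before training. count_images_per_class() is a hypothetical helper, and its file-extension pattern is an assumption about the image formats in the dataset.

```r
# Hypothetical helper: count image files in each class subfolder of a
# directory laid out as root/<class_name>/<image files>.
count_images_per_class <- function(root, pattern = "\\.(jpe?g|png)$") {
  # Only direct subdirectories count as classes; loose files at the top
  # level (for example a license text) are ignored.
  classes <- sort(list.dirs(root, recursive = FALSE, full.names = FALSE))
  # vapply names the result after the class folders
  vapply(classes, function(cls) {
    length(list.files(file.path(root, cls), pattern = pattern,
                      ignore.case = TRUE))
  }, integer(1))
}
```

    Running count_images_per_class(data_root) should report five classes whose counts add up to the training and validation totals that flow_images_from_directory() prints.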

    The simplest way to load this data into our model is with image_data_generator() and flow_images_from_directory().

    All of TensorFlow Hub’s image modules expect float inputs in the [0, 1] range. Use the image_data_generator’s rescale parameter to achieve this.
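    The rescale argument multiplies every pixel value by 1/255, mapping the 0-255 range of 8-bit images onto [0, 1]. For a single image outside a generator, that preprocessing plus the leading batch dimension that abind::abind(along = 0) adds in the Grace Hopper example can be sketched in base R. The pixel array here is simulated rather than read from a file.

```r
# Simulated 8-bit image: height x width x channels, integer values 0-255
set.seed(1)
pixels <- array(sample(0:255, 224 * 224 * 3, replace = TRUE),
                dim = c(224, 224, 3))

# Same effect as image_data_generator(rescale = 1/255)
scaled <- pixels / 255

# Prepend a batch dimension of size 1. R arrays are column-major, so a
# leading dimension of length 1 keeps every pixel at the same [i, j, k].
batch <- array(scaled, dim = c(1, dim(scaled)))

dim(batch)    # 1 224 224 3
range(batch)  # within [0, 1]
```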

    image_generator <- image_data_generator(rescale = 1/255, validation_split = 0.2)
    training_data <- flow_images_from_directory(
      directory = data_root, 
      generator = image_generator,
      target_size = c(224, 224), 
      subset = "training"
    )
    #> Found 2939 images belonging to 5 classes.
    
    validation_data <- flow_images_from_directory(
      directory = data_root, 
      generator = image_generator,
      target_size = c(224, 224), 
      subset = "validation"
    )
    #> Found 731 images belonging to 5 classes.

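    Each resulting object is an iterator that yields (image batch, label batch) pairs. With the defaults class_mode = "categorical" and batch_size = 32, each image batch holds up to 32 images of shape 224 x 224 x 3, and labels are one-hot encoded with one column per class.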

    Download the headless model

    TensorFlow Hub also distributes models without the top classification layer. These headless models are a convenient starting point for transfer learning.

    Any TensorFlow 2 compatible image feature vector URL from tfhub.dev will work here.

    feature_extractor_url <- "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/2"
    feature_extractor_layer <- layer_hub(handle = feature_extractor_url)

    Attach a classification head

    Now we can create our classification model by attaching a classification head on top of the feature extractor layer. The head is a single dense layer with one unit per class and a softmax activation:

    input <- layer_input(shape = c(224, 224, 3))
    output <- input %>% 
      feature_extractor_layer() %>% 
      layer_dense(units = training_data$num_classes, activation = "softmax")
    
    model <- keras_model(input, output)
    summary(model)
    #> Model: "model_1"
    #> ________________________________________________________________________________
    #> Layer (type)                        Output Shape                    Param #     
    #> ================================================================================
    #> input_2 (InputLayer)                [(None, 224, 224, 3)]           0           
    #> ________________________________________________________________________________
    #> keras_layer_1 (KerasLayer)          (None, 1280)                    2257984     
    #> ________________________________________________________________________________
    #> dense (Dense)                       (None, 5)                       6405        
    #> ================================================================================
    #> Total params: 2,264,389
    #> Trainable params: 6,405
    #> Non-trainable params: 2,257,984
    #> ________________________________________________________________________________
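    The dense head maps the 1280-dimensional feature vector to one score per class and applies a softmax. Its parameter count follows directly: 1280 weights plus one bias per class, so 1280 * 5 + 5 = 6405, which matches the trainable parameters in the summary, while the feature extractor's 2,257,984 weights are reported as non-trainable. The base-R sketch below reproduces the head's computation with random stand-in features and weights, not values from the trained model.

```r
# Sketch of the classification head's forward pass for one image.
# Features and weights are random stand-ins, not trained values.
set.seed(42)
n_features <- 1280  # output size of the feature extractor
n_classes  <- 5     # number of flower classes

features <- matrix(rnorm(n_features), nrow = 1)            # 1 x 1280
W <- matrix(rnorm(n_features * n_classes, sd = 0.01),
            nrow = n_features)                              # 1280 x 5
b <- rep(0, n_classes)                                      # biases

logits <- features %*% W + b           # 1 x 5 class scores
probs  <- exp(logits - max(logits))    # softmax activation ...
probs  <- probs / sum(probs)           # ... normalized to sum to 1

n_params <- length(W) + length(b)      # trainable parameters of the head
n_params                               # 6405
n_params + 2257984                     # total reported by summary()
```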

    Train the model

    We can now train our model in the same way we would train any other Keras model. We first use compile to configure the training process:

    model %>% 
      compile(
        loss = "categorical_crossentropy",
        optimizer = "adam",
        metrics = "acc"
      )
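    The categorical_crossentropy loss compares each one-hot label with the predicted class probabilities. For a single image it is -sum(y_true * log(y_pred)), which reduces to minus the log of the probability assigned to the correct class. A small base-R sketch (a hypothetical helper, not the Keras implementation itself) gives a useful reference point: a model that spreads probability evenly over 5 classes scores log(5), about 1.609.

```r
# Hypothetical helper: categorical cross-entropy for one example.
# y_true is a one-hot vector, y_pred a vector of class probabilities.
categorical_crossentropy <- function(y_true, y_pred, eps = 1e-7) {
  # Clip predictions away from 0 and 1 so log() stays finite
  y_pred <- pmin(pmax(y_pred, eps), 1 - eps)
  -sum(y_true * log(y_pred))
}

y_true    <- c(0, 0, 1, 0, 0)                # the third class is correct
uniform   <- rep(1 / 5, 5)                   # uninformed prediction
confident <- c(0.02, 0.02, 0.90, 0.03, 0.03) # confident, correct prediction

categorical_crossentropy(y_true, uniform)    # log(5), about 1.609
categorical_crossentropy(y_true, confident)  # -log(0.9), about 0.105
```

    In the training log, the loss starts somewhat above this uniform-guess value and falls well below it as the head learns.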

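    We can then train the model with fit_generator(), which draws batches from the training iterator and evaluates on the validation iterator at the end of each epoch. Since epochs is not set, training runs for a single epoch.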

    model %>% 
      fit_generator(
        training_data, 
        steps_per_epoch = as.integer(training_data$n / training_data$batch_size),
        validation_data = validation_data
      )
    #> 91/91 [==============================] - 239s 3s/step - loss: 0.7272 - acc: 0.7303 - val_loss: 0.4682 - val_acc: 0.8372
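    steps_per_epoch is the number of batches drawn per epoch. With 2939 training images and the default batch size of 32 used by flow_images_from_directory(), the ratio is not a whole number; the 91 steps in the log correspond to dropping the fractional part, as the prediction check later in this tutorial does explicitly with as.integer(). The arithmetic in base R:

```r
n_train    <- 2939  # training images found by flow_images_from_directory()
n_val      <- 731   # validation images
batch_size <- 32    # default batch_size of flow_images_from_directory()

n_train / batch_size              # 91.84375 batches
as.integer(n_train / batch_size)  # 91: the step count shown in the log
ceiling(n_train / batch_size)     # 92: enough batches to cover every image once
as.integer(n_val / batch_size)    # 22: the `steps` used for prediction
```

    Passing ceiling(n_train / batch_size) instead is a common choice when every training image should be seen once per epoch.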

    You can then export the model in the TensorFlow SavedModel format with:

    save_model_tf(model, "model")

    You can reload it with load_model_tf(). In this example no custom_objects argument is needed, even though the hub KerasLayer is not a built-in Keras layer.

    reloaded_model <- load_model_tf("model")

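    We can verify that the trained model and the reloaded model produce the same predictions. For the comparison to be meaningful, both calls must see the validation images in the same order; creating the validation iterator with shuffle = FALSE is a simple way to ensure this.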

    steps <- as.integer(validation_data$n/validation_data$batch_size)
    all.equal(
      predict_generator(model, validation_data, steps = steps),
      predict_generator(reloaded_model, validation_data, steps = steps)
    )
    #> [1] TRUE

    The saved model can also be loaded later for inference, or converted to TensorFlow Lite or TensorFlow.js.