TensorFlow Hub with Keras

    TensorFlow Hub is a way to share pretrained model components. See the TensorFlow Module Hub at tfhub.dev for a searchable listing of pretrained models. This tutorial demonstrates:

    1. How to use TensorFlow Hub with Keras.
    2. How to do image classification using TensorFlow Hub.
    3. How to do simple transfer learning.

    An ImageNet classifier

    Download the classifier

    Use layer_hub to load a MobileNet model and transform it into a Keras layer. Any TensorFlow 2-compatible image classifier URL from tfhub.dev will work here.
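
    For example, a minimal sketch; this particular MobileNet v2 module URL is one choice among several on tfhub.dev:

    library(keras)
    library(tfhub)

    classifier_url <- "https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/2"
    mobilenet_layer <- layer_hub(handle = classifier_url)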

    We can then create our Keras model:
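
    A minimal sketch, assuming the mobilenet_layer created above and the module's expected 224x224 RGB input:

    input <- layer_input(shape = c(224, 224, 3))
    output <- input %>%
      mobilenet_layer()
    model <- keras_model(input, output)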

    Run it on a single image

    Download a single image to try the model on.

    library(magick)

    img <- image_read('https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg') %>%
      image_resize(geometry = "224x224x3!") %>% # force-resize to the module's 224x224 input size
      image_data() %>%                          # extract the bitmap
      as.numeric() %>%                          # convert to a numeric array
      abind::abind(along = 0)                   # expand to batch dimension
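
    We can then run the classifier on the image; a minimal sketch, assuming the model built above (the output is one score per ImageNet class, so which.max gives the top class ID):

    result <- predict(model, img)
    predicted_class <- which.max(result[1, ]) # index of the top-scoring ImageNet class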

    Simple transfer learning

    Using TF Hub, it is simple to retrain the top layer of the model to recognize the classes in our dataset.

    Dataset

    For this example you will use the TensorFlow flowers dataset:

    data_root <- pins::pin("https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz", "flower_photos")
    data_root <- fs::path_dir(fs::path_dir(data_root[100])) # go up two levels to the dataset root

    The simplest way to load this data into our model is with image_data_generator.

    All of TensorFlow Hub’s image modules expect float inputs in the [0, 1] range. Use the image_data_generator’s rescale parameter to achieve this.
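
    A minimal sketch, assuming the data_root directory from above; the 20% validation split and 224x224 target size are choices, not requirements:

    image_generator <- image_data_generator(rescale = 1/255, validation_split = 0.2)

    training_data <- flow_images_from_directory(
      directory = data_root,
      generator = image_generator,
      target_size = c(224, 224),
      subset = "training"
    )

    validation_data <- flow_images_from_directory(
      directory = data_root,
      generator = image_generator,
      target_size = c(224, 224),
      subset = "validation"
    )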

    The resulting object is an iterator that returns image_batch, label_batch pairs.
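
    For example, you can pull one batch with generator_next to inspect it:

    batch <- generator_next(training_data) # list of (images, one-hot labels)
    dim(batch[[1]])                        # e.g. 32 224 224 3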

    Download the headless model

    TensorFlow Hub also distributes models without the top classification layer. These make it easy to do transfer learning.

    Any TensorFlow 2-compatible image feature vector URL from tfhub.dev will work here.
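
    A minimal sketch, assuming the MobileNet v2 feature vector module and attaching a fresh softmax classification head sized to our dataset:

    feature_extractor_url <- "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/2"
    feature_extractor_layer <- layer_hub(handle = feature_extractor_url)

    input <- layer_input(shape = c(224, 224, 3))
    output <- input %>%
      feature_extractor_layer() %>%
      layer_dense(units = training_data$num_classes, activation = "softmax")
    model <- keras_model(input, output)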

    Train the model

    We can now train our model in the same way we would train any other Keras model. We first use compile to configure the training process:
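
    A minimal sketch; Adam and categorical cross-entropy are reasonable choices here, not the only ones:

    model %>% compile(
      optimizer = "adam",
      loss = "categorical_crossentropy",
      metrics = "acc"
    )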

    Since our data comes from a generator, we then use fit_generator to fit the model:

    model %>% 
      fit_generator(
        training_data, 
        steps_per_epoch = training_data$n/training_data$batch_size,
        validation_data = validation_data
      )
    #> 91/91 [==============================] - 211s 2s/step - loss: 0.7551 - acc: 0.7231 - val_loss: 0.5031 - val_acc: 0.8057

    You can then export your model with:

    save_model_tf(model, "model")

    You can also reload the model with the load_model_tf function. Note that you need to pass custom_objects with the definition of the KerasLayer, since it's not a default Keras layer.
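
    A minimal sketch; here we import the tensorflow_hub Python module via reticulate to supply the KerasLayer definition (an assumption about your setup):

    hub <- reticulate::import("tensorflow_hub")
    reloaded_model <- load_model_tf(
      "model",
      custom_objects = list(KerasLayer = hub$KerasLayer)
    )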

    We can verify that the predictions of both the trained model and the reloaded model are equal:

    steps <- as.integer(validation_data$n / validation_data$batch_size)
    all.equal(
      predict_generator(model, validation_data, steps = steps),
      predict_generator(reloaded_model, validation_data, steps = steps)
    )
    #> [1] TRUE

    The saved model can also be loaded for inference later, or converted to TensorFlow Lite or TensorFlow.js.
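
    For example, a sketch of a TensorFlow Lite conversion through the tf object exposed by the tensorflow R package (the converter API can vary across TensorFlow versions):

    library(tensorflow)

    # convert the SavedModel exported above to a TFLite flatbuffer
    converter <- tf$lite$TFLiteConverter$from_saved_model("model")
    tflite_model <- converter$convert()    # returns the model as raw bytes
    writeBin(tflite_model, "model.tflite")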