TensorFlow Hub with Keras

    TensorFlow Hub is a way to share pretrained model components. See the TensorFlow Module Hub for a searchable listing of pretrained models. This tutorial demonstrates:

    1. How to use TensorFlow Hub with Keras.
    2. How to do image classification using TensorFlow Hub.
    3. How to do simple transfer learning.

    An ImageNet classifier

    Download the classifier

    Use layer_hub to load MobileNet and transform it into a Keras layer. Any TensorFlow 2-compatible image classifier URL from tfhub.dev will work here.
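
    For example, a minimal sketch assuming the tfhub package and the TF2-preview MobileNet handle (any compatible handle can be substituted):

    library(keras)
    library(tfhub)

    classifier_url <- "https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/2"

    # layer_hub downloads the module and wraps it as a regular Keras layer
    mobilenet_layer <- layer_hub(handle = classifier_url)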

    We can then create our Keras model.
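
    A minimal sketch with the functional API; the 224x224x3 input shape is what this MobileNet handle expects:

    input <- layer_input(shape = c(224, 224, 3))
    output <- input %>% mobilenet_layer()
    model <- keras_model(input, output)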

    Run it on a single image

    Download a single image to try the model on.

    library(magick) # for image reading and resizing

    img <- image_read('https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg') %>%
      image_resize(geometry = "224x224x3!") %>%
      image_data() %>%
      as.numeric() %>%
      abind::abind(along = 0) # expand to batch dimension
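
    We can then run the classifier on the image batch. A sketch, assuming this handle returns 1001 ImageNet classes with a leading background class (as the TF2-preview MobileNet does):

    result <- predict(model, img)

    # drop the background class, then map the remaining 1000 scores to labels
    imagenet_decode_predictions(result[, -1, drop = FALSE], top = 3)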

    Simple transfer learning

    Using TF Hub, it is simple to retrain the top layer of the model to recognize the classes in our own dataset.

    Dataset

    For this example you will use the TensorFlow flowers dataset:

    data_root <- pins::pin("https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz", "flower_photos")
    data_root <- fs::path_dir(fs::path_dir(data_root[100])) # pin returns the extracted file paths; go up two levels to the dataset root

    The simplest way to load this data into our model is using image_data_generator.

    All of TensorFlow Hub’s image modules expect float inputs in the [0, 1] range. Use the image_data_generator’s rescale parameter to achieve this.
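
    A sketch of the loading step, assuming a 20% validation split and the 224x224 input size used above:

    image_generator <- image_data_generator(rescale = 1/255, validation_split = 0.2)

    training_data <- flow_images_from_directory(
      directory = data_root,
      generator = image_generator,
      target_size = c(224, 224),
      subset = "training"
    )

    validation_data <- flow_images_from_directory(
      directory = data_root,
      generator = image_generator,
      target_size = c(224, 224),
      subset = "validation"
    )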

    The resulting object is an iterator that returns image_batch, label_batch pairs.
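
    You can advance it with generator_next to inspect one batch, for example:

    batch <- generator_next(training_data)
    str(batch[[1]]) # image batch: (batch_size, 224, 224, 3), values scaled to [0, 1]
    str(batch[[2]]) # label batch: one-hot encoded, (batch_size, number of classes)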

    Download the headless model

    TensorFlow Hub also distributes models without the top classification layer. These make it easy to do transfer learning.

    Any TensorFlow 2-compatible image feature vector URL from tfhub.dev will work here.
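
    A sketch using the feature-vector counterpart of the MobileNet handle above, with a new softmax head sized to the number of classes in the flowers dataset:

    feature_extractor_url <- "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/2"
    feature_extractor_layer <- layer_hub(handle = feature_extractor_url)

    input <- layer_input(shape = c(224, 224, 3))
    output <- input %>%
      feature_extractor_layer() %>%
      layer_dense(units = training_data$num_classes, activation = "softmax")

    model <- keras_model(input, output)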

    Train the model

    We can now train our model in the same way we would train any other Keras model. We first use compile to configure the training process.
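
    A sketch with common defaults; Adam and categorical cross-entropy match the one-hot labels the generator produces:

    model %>% compile(
      optimizer = "adam",
      loss = "categorical_crossentropy",
      metrics = "acc"
    )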

    We can then use fit_generator to fit our model, since the data comes from a generator.

    model %>% 
      fit_generator(
        training_data, 
        steps_per_epoch = training_data$n/training_data$batch_size,
        validation_data = validation_data
      )
    #> 91/91 [==============================] - 229s 3s/step - loss: 0.6974 - acc: 0.7485 - val_loss: 0.4758 - val_acc: 0.8263

    You can then export your model with:

    save_model_tf(model, "model")

    You can then reload the model with the load_model_tf function. Note that you may need to pass custom_objects with the definition of the KerasLayer, since it is not a default Keras layer.
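
    A sketch of the reload step, assuming the SavedModel written above:

    # if loading fails on the KerasLayer, supply it via the custom_objects argument
    reloaded_model <- load_model_tf("model")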

    We can verify that the predictions of both the trained model and the reloaded model are equal:

    steps <- as.integer(validation_data$n / validation_data$batch_size)
    all.equal(
      predict_generator(model, validation_data, steps = steps),
      predict_generator(reloaded_model, validation_data, steps = steps)
    )
    #> [1] TRUE

    The saved model can also be loaded for inference later, or converted to TensorFlow Lite or TensorFlow.js.