How to use a pre-trained model on an Android device

In the previous chapter, we discussed how to train an object classifier on our own images. At the end, we obtained a trained model and a labels file (retrained_graph.pb and retrained_labels.txt).

In this chapter, we are going to load the pre-trained classifier in our Android app. Unfortunately, we cannot use the trained model on Android directly: we first need to optimize it with a tool called “optimize_for_inference”, provided by TensorFlow, which removes the parts of the graph that are only needed for training.


1. Build the “optimize_for_inference” tool

  • Download the TensorFlow source code:

    git clone https://github.com/tensorflow/tensorflow.git
    
  • Install Bazel, which we need to build “optimize_for_inference”:

    echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" | sudo tee /etc/apt/sources.list.d/bazel.list
    curl https://bazel.build/bazel-release.pub.gpg | sudo apt-key add -
    sudo apt-get update && sudo apt-get install bazel
    sudo apt-get upgrade bazel
    
  • Build “optimize_for_inference”

    cd tensorflow
    ./configure # We can choose all default settings
    bazel build tensorflow/python/tools:optimize_for_inference # this process takes a while, be patient
    

2. Optimize the trained model

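Let’s assume that our trained model is <folder_path>/retrained_graph.pb. From the tensorflow directory, we can then run the following command to optimize the model and save the result as <folder_path>/retrained_graph_android.pb. The --input_names and --output_names flags give the names of the graph’s input node (Mul) and output node (final_result); we will need the same names again in the Android app: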

bazel-bin/tensorflow/python/tools/optimize_for_inference \
--input=<folder_path>/retrained_graph.pb \
--output=<folder_path>/retrained_graph_android.pb \
--input_names=Mul \
--output_names=final_result

3. Modify the TensorFlow Android demo

  • Download the Android demo:

    git clone https://github.com/Nilhcem/tensorflow-classifier-android.git
    

    If we import this project into Android Studio, then compile and run it, the demo loads a pre-trained ImageNet classifier that recognizes 1,000 classes.

  • Load our own model:

    1. Delete the previous ImageNet model from the assets/ folder.
    2. Copy our optimized model, retrained_graph_android.pb, and the labels file, retrained_labels.txt, into the assets/ folder.
    3. Open ClassifierActivity.java and set the following variables:
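      private static final int INPUT_SIZE = 299;                  // width and height, in pixels, of the image fed to the model
      private static final int IMAGE_MEAN = 128;                  // each pixel channel is normalized as
      private static final float IMAGE_STD = 128;                 //   (value - IMAGE_MEAN) / IMAGE_STD
      private static final String INPUT_NAME = "Mul";             // must match --input_names above
      private static final String OUTPUT_NAME = "final_result";   // must match --output_names above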
    
    4. Compile and run. The demo opens the camera and shows the confidence score of each class it recognizes.
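To see how the ClassifierActivity settings fit together, here is a small standalone Java sketch of the two steps around the model: turning camera pixels into the float input and turning the output scores into labelled results. It is not code from the demo repository. The class name ClassifierSketch and its methods are hypothetical, and we assume that each pixel channel is normalized as (value - IMAGE_MEAN) / IMAGE_STD and that retrained_labels.txt holds one label per line, in the same order as the scores produced by final_result. Actually running the graph would also require the TensorFlow Android inference library, which this sketch leaves out, so it runs on a plain desktop JVM.

```java
import java.io.BufferedReader;
import java.io.Reader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;

public class ClassifierSketch {
    // Values copied from the ClassifierActivity.java settings.
    static final int INPUT_SIZE = 299;
    static final int IMAGE_MEAN = 128;
    static final float IMAGE_STD = 128;

    /** Reads one label per line (assumed layout of retrained_labels.txt), skipping blank lines. */
    static List<String> readLabels(Reader source) {
        return new BufferedReader(source).lines()
                .map(String::trim)
                .filter(line -> !line.isEmpty())
                .collect(Collectors.toList());
    }

    /**
     * Converts packed ARGB pixels (e.g. from an INPUT_SIZE x INPUT_SIZE bitmap) into
     * interleaved RGB floats, mapping each channel to (value - IMAGE_MEAN) / IMAGE_STD.
     */
    static float[] normalize(int[] argbPixels) {
        float[] floatValues = new float[argbPixels.length * 3];
        for (int i = 0; i < argbPixels.length; ++i) {
            int val = argbPixels[i];
            floatValues[i * 3] = (((val >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD;     // red
            floatValues[i * 3 + 1] = (((val >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD; // green
            floatValues[i * 3 + 2] = ((val & 0xFF) - IMAGE_MEAN) / IMAGE_STD;        // blue
        }
        return floatValues;
    }

    /** Pairs each output score with the label at the same index and sorts by confidence, highest first. */
    static List<String> rank(List<String> labels, float[] scores) {
        if (labels.size() != scores.length) {
            throw new IllegalArgumentException("expected one score per label");
        }
        Integer[] order = new Integer[scores.length];
        for (int i = 0; i < order.length; ++i) {
            order[i] = i;
        }
        Arrays.sort(order, Comparator.comparingDouble((Integer idx) -> scores[idx]).reversed());
        List<String> ranked = new ArrayList<>();
        for (int idx : order) {
            ranked.add(String.format(Locale.ROOT, "%s: %.2f", labels.get(idx), scores[idx]));
        }
        return ranked;
    }

    public static void main(String[] args) {
        // A full-size input holds INPUT_SIZE * INPUT_SIZE * 3 floats.
        System.out.println(normalize(new int[INPUT_SIZE * INPUT_SIZE]).length); // 268203

        // One pure-red pixel and one mid-grey pixel.
        System.out.println(Arrays.toString(normalize(new int[] {0xFFFF0000, 0xFF808080})));
        // [0.9921875, -1.0, -1.0, 0.0, 0.0, 0.0]

        // Hypothetical labels and scores, standing in for retrained_labels.txt and final_result.
        List<String> labels = readLabels(new StringReader("daisy\nroses\ntulips\n"));
        float[] scores = {0.10f, 0.75f, 0.15f};
        System.out.println(rank(labels, scores)); // [roses: 0.75, tulips: 0.15, daisy: 0.10]
    }
}
```

With IMAGE_MEAN and IMAGE_STD both set to 128, every channel lands roughly in the range [-1, 1): a fully saturated channel becomes 0.9921875, an empty one becomes -1.0, and mid-grey becomes 0.0. Because the scores are matched to labels purely by position, the labels file copied into assets/ must be the one produced together with the retrained graph.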