建立 Android 應用程式來偵測圖片中的物件

1. 事前準備

在本程式碼研究室中,您將瞭解如何透過 TensorFlow Serving 搭配 REST 和 gRPC 執行 Android 應用程式執行物件偵測推論。

必要條件

  • Java 的 Android 開發基本知識
  • TensorFlow 機器學習基本知識,例如訓練和部署
  • 終端機和 Docker 的基本知識

課程內容

  • 如何在 TensorFlow Hub 中尋找預先訓練的物件偵測模型。
  • 瞭解如何透過 TensorFlow Serving (REST 和 gRPC) 使用下載的物件偵測模型建構簡易型 Android 應用程式,並產生預測資料。
  • 如何在 UI 中呈現偵測結果。

軟硬體需求

2. 做好準備

若要下載此程式碼研究室的程式碼:

  1. 前往這項程式碼研究室的 GitHub 存放區
  2. 按一下 [程式碼 >下載 zip],即可下載這個程式碼研究室的所有程式碼。

a72f2bb4caa9a96.png

  1. 將下載的 ZIP 檔案解壓縮,解壓縮您需要的所有 codelabs 根資料夾。

在這個程式碼研究室中,您只需要存放區的 TFServing/ObjectDetectionAndroid 子目錄中的檔案,其中包含兩個資料夾:

  • starter 資料夾包含您為這個程式碼研究室建立的範例程式碼。
  • finished 資料夾包含已完成範例應用程式的範例程式碼。

3. 將依附元件加入專案

將入門應用程式匯入 Android Studio

  • 在 Android Studio 中,按一下 [File > New > Import project],然後從先前下載的原始碼中選擇 starter 資料夾。

新增 OkHttp 和 gRPC 的依附元件

  • 在專案的 app/build.gradle 檔案中,確認依附元件是否存在。
dependencies {
  // ...
    implementation 'com.squareup.okhttp3:okhttp:4.9.0'
    implementation 'javax.annotation:javax.annotation-api:1.3.2'
    implementation 'io.grpc:grpc-okhttp:1.29.0'
    implementation 'io.grpc:grpc-protobuf-lite:1.29.0'
    implementation 'io.grpc:grpc-stub:1.29.0'
}

使用 Gradle 檔案同步處理專案

  • 從導覽選單中選取 541e90b497a7fef7.png [將專案與 Gradle 檔案同步處理]

4. 執行啟動應用程式

執行並探索應用程式

應用程式應該會在 Android 裝置上啟動。UI 非常簡單:你要用來偵測物件的貓咪圖片,使用者可以選擇使用 REST 或 gRPC 將資料傳送到後端的方式。後端會對圖片執行物件偵測,並將偵測結果傳回用戶端應用程式,進而再次顯示 UI。

24eab579530e9645.png

現在,如果您按一下 [執行推論],就不會發生任何動作。這是因為該執行個體尚未與後端通訊。

5. 使用 TensorFlow Serving 部署物件偵測模型

物件偵測是常見的機器學習工作,目標是要偵測圖片內的物件,也就是預測物件可能的類別和定界框。以下是偵測結果的範例:

a68f9308fb2fc17b.png

Google 已在 TensorFlow Hub 上發布了一些預先訓練模型。如要查看完整清單,請前往 object_detection 頁面。在這個程式碼研究室中,您可以使用相對輕量的 SSD MobileNet V2 FPNLite 320x320 模型,這樣就不需要使用 GPU 來執行模型。

如何使用 TensorFlow Serving 部署物件偵測模型:

  1. 下載模型檔案。
  2. 使用 7-Zip 等壓縮工具解壓縮已解壓縮的 .tar.gz 檔案。
  3. 建立 ssd_mobilenet_v2_2_320 資料夾,然後在其中建立 123 子資料夾。
  4. 將解壓縮的 variables 資料夾和 saved_model.pb 檔案放入 123 子資料夾。

ssd_mobilenet_v2_2_320 資料夾可視為 SavedModel 資料夾。123 是範例版本號碼。如有需要,您可以挑選其他號碼。

資料夾結構應如下列所示:

42c8150a42033767.png

啟動 TensorFlow Serving

  • 在終端機中,以 Docker 啟動 TensorFlow Serving,但將 PATH/TO/SAVEDMODEL 預留位置改成您電腦上 ssd_mobilenet_v2_2_320 資料夾的絕對路徑。
docker pull tensorflow/serving

docker run -it --rm -p 8500:8500 -p 8501:8501 -v "PATH/TO/SAVEDMODEL:/models/ssd_mobilenet_v2_2" -e MODEL_NAME=ssd_mobilenet_v2_2 tensorflow/serving

Docker 會先自動下載 TensorFlow Serving 映像檔,這可能需要一分鐘的時間。之後,TensorFlow Serving 會隨即啟動。記錄應如下所示:

2022-02-25 06:01:12.513231: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:206] Restoring SavedModel bundle.
2022-02-25 06:01:12.585012: I external/org_tensorflow/tensorflow/core/platform/profile_utils/cpu_utils.cc:114] CPU Frequency: 3000000000 Hz
2022-02-25 06:01:13.395083: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:190] Running initialization op on SavedModel bundle at path: /models/ssd_mobilenet_v2_2/123
2022-02-25 06:01:13.837562: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:277] SavedModel load for tags { serve }; Status: success: OK. Took 1928700 microseconds.
2022-02-25 06:01:13.877848: I tensorflow_serving/servables/tensorflow/saved_model_warmup_util.cc:59] No warmup data file found at /models/ssd_mobilenet_v2_2/123/assets.extra/tf_serving_warmup_requests
2022-02-25 06:01:13.929844: I tensorflow_serving/core/loader_harness.cc:87] Successfully loaded servable version {name: ssd_mobilenet_v2_2 version: 123}
2022-02-25 06:01:13.985848: I tensorflow_serving/model_servers/server_core.cc:486] Finished adding/updating models
2022-02-25 06:01:13.985987: I tensorflow_serving/model_servers/server.cc:367] Profiler service is enabled
2022-02-25 06:01:13.988994: I tensorflow_serving/model_servers/server.cc:393] Running gRPC ModelServer at 0.0.0.0:8500 ...
[warn] getaddrinfo: address family for nodename not supported
2022-02-25 06:01:14.033872: I tensorflow_serving/model_servers/server.cc:414] Exporting HTTP/REST API at:localhost:8501 ...
[evhttp_server.cc : 245] NET_LOG: Entering the event loop ...

6. 透過 REST 將 Android 應用程式與 TensorFlow Serving 連結

後端已就緒,您可以傳送用戶端要求至 TensorFlow Serving,以偵測圖片中的物件。向 TensorFlow Serving 傳送要求的方法有兩種:

  • REST
  • gRPC

透過 REST 傳送要求及接收回應

這裡有三個簡單的步驟:

  • 建立 REST 要求。
  • 將 REST 要求傳送至 TensorFlow Serving。
  • 從 REST 回應中擷取預測結果並顯示 UI。

您將在 MainActivity.java.內達成這些目標

建立 REST 要求

目前,MainActivity.java 檔案中有空白的 createRESTRequest() 函式。您可以實作這個函式來建立 REST 要求。

private Request createRESTRequest() {
}

TensorFlow Serving 的 POST 要求包含您使用的 SSD

  • 將這段程式碼加進 createRESTRequest() 函式:
//Create the REST request.
int[] inputImg = new int[INPUT_IMG_HEIGHT * INPUT_IMG_WIDTH];
int[][][][] inputImgRGB = new int[1][INPUT_IMG_HEIGHT][INPUT_IMG_WIDTH][3];
inputImgBitmap.getPixels(inputImg, 0, INPUT_IMG_WIDTH, 0, 0, INPUT_IMG_WIDTH, INPUT_IMG_HEIGHT);
int pixel;
for (int i = 0; i < INPUT_IMG_HEIGHT; i++) {
    for (int j = 0; j < INPUT_IMG_WIDTH; j++) {
    // Extract RBG values from each pixel; alpha is ignored
    pixel = inputImg[i * INPUT_IMG_WIDTH + j];
    inputImgRGB[0][i][j][0] = ((pixel >> 16) & 0xff);
    inputImgRGB[0][i][j][1] = ((pixel >> 8) & 0xff);
    inputImgRGB[0][i][j][2] = ((pixel) & 0xff);
    }
}

RequestBody requestBody =
    RequestBody.create("{\"instances\": " + Arrays.deepToString(inputImgRGB) + "}", JSON);

Request request =
    new Request.Builder()
        .url("http://" + SERVER + ":" + REST_PORT + "/v1/models/" + MODEL_NAME + ":predict")
        .post(requestBody)
        .build();

return request;    

將 REST 要求傳送至 TensorFlow Serving

應用程式可讓使用者選擇 REST 或 gRPC 與 TensorFlow Serving 通訊,因此 onClick(View view) 事件中會有兩個分支版本。

predictButton.setOnClickListener(
    new View.OnClickListener() {
        @Override
        public void onClick(View view) {
            if (requestRadioGroup.getCheckedRadioButtonId() == R.id.rest) {
                // TODO: REST request
            }
            else {

            }
        }
    }
)
  • 將下列程式碼加進 onClick(View view) 監聽器的 REST 分支,以便使用 OkHttp 傳送要求至 TensorFlow Serving:
// Send the REST request.
Request request = createRESTRequest();
try {
    client =
        new OkHttpClient.Builder()
            .connectTimeout(20, TimeUnit.SECONDS)
            .writeTimeout(20, TimeUnit.SECONDS)
            .readTimeout(20, TimeUnit.SECONDS)
            .callTimeout(20, TimeUnit.SECONDS)
            .build();
    Response response = client.newCall(request).execute();
    JSONObject responseObject = new JSONObject(response.body().string());
    postprocessRESTResponse(responseObject);
} catch (IOException | JSONException e) {
    Log.e(TAG, e.getMessage());
    responseTextView.setText(e.getMessage());
    return;
}

處理 TensorFlow Serving 中的 REST 回應

SSD MobileNet 模型會傳回多項結果,包括:

  • num_detections:偵測次數
  • detection_scores:偵測分數
  • detection_classes:偵測類別索引
  • detection_boxes:定界框座標

實作 postprocessRESTResponse() 函式來處理回應。

private void postprocessRESTResponse(Predict.PredictResponse response) {

}
  • 將這段程式碼加進 postprocessRESTResponse() 函式:
// Process the REST response.
JSONArray predictionsArray = responseObject.getJSONArray("predictions");
//You only send one image, so you directly extract the first element.
JSONObject predictions = predictionsArray.getJSONObject(0);
// Argmax
int maxIndex = 0;
JSONArray detectionScores = predictions.getJSONArray("detection_scores");
for (int j = 0; j < predictions.getInt("num_detections"); j++) {
    maxIndex =
        detectionScores.getDouble(j) > detectionScores.getDouble(maxIndex + 1) ? j : maxIndex;
}
int detectionClass = predictions.getJSONArray("detection_classes").getInt(maxIndex);
JSONArray boundingBox = predictions.getJSONArray("detection_boxes").getJSONArray(maxIndex);
double ymin = boundingBox.getDouble(0);
double xmin = boundingBox.getDouble(1);
double ymax = boundingBox.getDouble(2);
double xmax = boundingBox.getDouble(3);
displayResult(detectionClass, (float) ymin, (float) xmin, (float) ymax, (float) xmax);

現在,後續處理函式會從回應中擷取預測值,找出物件最可能的類別和定界框頂點的座標,最後在 UI 中呈現偵測定界框。

執行

  1. 在導覽選單中,按一下 執行.png [執行應用程式],然後等待應用程式載入。
  2. 選取 [REST > Run 推測]

應用程式需要幾秒鐘的時間,才能顯示貓咪的定界框,並將 17 顯示為物件類別,而該物件會對應至 COCO 資料集中的 cat 物件。

5a1a32768dc516d6.png

7. 透過 gRPC 將 Android 應用程式與 TensorFlow Serving 連結

除了 REST 以外,TensorFlow Serving 也支援 gRPC

b6f4449c2c850b0e.png

gRPC 是現代化、開放原始碼的高效能遠端程序呼叫 (RPC) 架構,可在任何環境中執行。這項服務具備可連接的負載平衡、追蹤、健康狀態檢查和驗證等功能,可透過高效率的方式連結資料中心內外的服務。發現在實驗中,gRPC 的成效比 REST 更好。

使用 gRPC 傳送要求及接收回應

這裡有四個簡單的步驟:

  • [選用] 產生 gRPC 用戶端 stub 程式碼。
  • 建立 gRPC 要求。
  • 將 gRPC 要求傳送至 TensorFlow Serving。
  • 從 gRPC 回應中擷取預測結果,然後顯示 UI。

您將在 MainActivity.java.內達成這些目標

選用:產生 gRPC 用戶端 stub 程式碼

如要搭配 TensorFlow Serving 使用 gRPC,您需要遵守 gRPC 工作流程。如要進一步瞭解相關細節,請參閱 gRPC 說明文件

a9d0e5cb543467b4.png

TensorFlow Serving 和 TensorFlow 會為您定義 .proto 檔案。自 TensorFlow 和 TensorFlow Serving 2.8 起,下列 .proto 檔案為必要檔案:

tensorflow/core/example/example.proto
tensorflow/core/example/feature.proto
tensorflow/core/protobuf/struct.proto
tensorflow/core/protobuf/saved_object_graph.proto
tensorflow/core/protobuf/saver.proto
tensorflow/core/protobuf/trackable_object_graph.proto
tensorflow/core/protobuf/meta_graph.proto
tensorflow/core/framework/node_def.proto
tensorflow/core/framework/attr_value.proto
tensorflow/core/framework/function.proto
tensorflow/core/framework/types.proto
tensorflow/core/framework/tensor_shape.proto
tensorflow/core/framework/full_type.proto
tensorflow/core/framework/versions.proto
tensorflow/core/framework/op_def.proto
tensorflow/core/framework/graph.proto
tensorflow/core/framework/tensor.proto
tensorflow/core/framework/resource_handle.proto
tensorflow/core/framework/variable.proto

tensorflow_serving/apis/inference.proto
tensorflow_serving/apis/classification.proto
tensorflow_serving/apis/predict.proto
tensorflow_serving/apis/regression.proto
tensorflow_serving/apis/get_model_metadata.proto
tensorflow_serving/apis/input.proto
tensorflow_serving/apis/prediction_service.proto
tensorflow_serving/apis/model.proto
  • 如要產生 stub,請將這段程式碼加進 app/build.gradle 檔案中。
apply plugin: 'com.google.protobuf'

protobuf {
    protoc { artifact = 'com.google.protobuf:protoc:3.11.0' }
    plugins {
        grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.29.0'
        }
    }
    generateProtoTasks {
        all().each { task ->
            task.builtins {
                java { option 'lite' }
            }
            task.plugins {
                grpc { option 'lite' }
            }
        }
    }
}

建立 gRPC 要求

與 REST 要求類似,您可以在 createGRPCRequest() 函式中建立 gRPC 要求。

private Request createGRPCRequest() {

}
  • 將這段程式碼加進 createGRPCRequest() 函式:
if (stub == null) {
  channel = ManagedChannelBuilder.forAddress(SERVER, GRPC_PORT).usePlaintext().build();
  stub = PredictionServiceGrpc.newBlockingStub(channel);
}

Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
modelSpecBuilder.setName(MODEL_NAME);
modelSpecBuilder.setVersion(Int64Value.of(MODEL_VERSION));
modelSpecBuilder.setSignatureName(SIGNATURE_NAME);

Predict.PredictRequest.Builder builder = Predict.PredictRequest.newBuilder();
builder.setModelSpec(modelSpecBuilder);

TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder();
tensorProtoBuilder.setDtype(DataType.DT_UINT8);
TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(INPUT_IMG_HEIGHT));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(INPUT_IMG_WIDTH));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(3));
tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build());
int[] inputImg = new int[INPUT_IMG_HEIGHT * INPUT_IMG_WIDTH];
inputImgBitmap.getPixels(inputImg, 0, INPUT_IMG_WIDTH, 0, 0, INPUT_IMG_WIDTH, INPUT_IMG_HEIGHT);
int pixel;
for (int i = 0; i < INPUT_IMG_HEIGHT; i++) {
    for (int j = 0; j < INPUT_IMG_WIDTH; j++) {
    // Extract RBG values from each pixel; alpha is ignored.
    pixel = inputImg[i * INPUT_IMG_WIDTH + j];
    tensorProtoBuilder.addIntVal((pixel >> 16) & 0xff);
    tensorProtoBuilder.addIntVal((pixel >> 8) & 0xff);
    tensorProtoBuilder.addIntVal((pixel) & 0xff);
    }
}
TensorProto tensorProto = tensorProtoBuilder.build();

builder.putInputs("input_tensor", tensorProto);

builder.addOutputFilter("num_detections");
builder.addOutputFilter("detection_boxes");
builder.addOutputFilter("detection_classes");
builder.addOutputFilter("detection_scores");

return builder.build();

將 gRPC 要求傳送至 TensorFlow Serving

您現在可以完成 onClick(View view) 事件監聽器。

predictButton.setOnClickListener(
    new View.OnClickListener() {
        @Override
        public void onClick(View view) {
            if (requestRadioGroup.getCheckedRadioButtonId() == R.id.rest) {

            }
            else {
                // TODO: gRPC request
            }
        }
    }
)
  • 將這段程式碼新增至 gRPC 分支版本:
try {
    Predict.PredictRequest request = createGRPCRequest();
    Predict.PredictResponse response = stub.predict(request);
    postprocessGRPCResponse(response);
} catch (Exception e) {
    Log.e(TAG, e.getMessage());
    responseTextView.setText(e.getMessage());
    return;
}

處理 TensorFlow Serving 中的 gRPC 回應

與 gRPC 類似,您必須導入 postprocessGRPCResponse() 函式來處理回應。

private void postprocessGRPCResponse(Predict.PredictResponse response) {

}
  • 將這段程式碼加進 postprocessGRPCResponse() 函式:
// Process the response.
float numDetections = response.getOutputsMap().get("num_detections").getFloatValList().get(0);
List<Float> detectionScores =    response.getOutputsMap().get("detection_scores").getFloatValList();
int maxIndex = 0;
for (int j = 0; j < numDetections; j++) {
    maxIndex = detectionScores.get(j) > detectionScores.get(maxIndex + 1) ? j : maxIndex;
}
Float detectionClass =    response.getOutputsMap().get("detection_classes").getFloatValList().get(maxIndex);
List<Float> boundingBoxValues =    response.getOutputsMap().get("detection_boxes").getFloatValList();
float ymin = boundingBoxValues.get(maxIndex * 4);
float xmin = boundingBoxValues.get(maxIndex * 4 + 1);
float ymax = boundingBoxValues.get(maxIndex * 4 + 2);
float xmax = boundingBoxValues.get(maxIndex * 4 + 3);
displayResult(detectionClass.intValue(), ymin, xmin, ymax, xmax);

現在,「後置處理」函式可以從回應中擷取預測值,然後在使用者介面中顯示偵測定界框。

執行

  1. 在導覽選單中,按一下 執行.png [執行應用程式],然後等待應用程式載入。
  2. 選取 [gRPC > Run 推測]

應用程式需要幾秒鐘的時間,才能顯示貓咪的定界框,並將 17 顯示為物件類別,而該類別會對應至 COCO 資料集中的 cat 類別。

8. 恭喜

您使用 TensorFlow Serving 在應用程式中新增物件偵測功能!

瞭解詳情