创建 Android 应用以检测图片中的对象

1. 准备工作

在此 Codelab 中,您将学习如何将 TensorFlow Serving 与 REST 和 gRPC 搭配使用,通过 Android 应用运行对象检测推断。

前提条件

  • 具备 Java 开发方面的基础知识
  • 具备使用 TensorFlow 进行机器学习的基础知识,例如训练和部署
  • 具备终端和 Docker 方面的基础知识

学习内容

  • 如何在 TensorFlow Hub 中查找预训练的对象检测模型。
  • 如何构建一个简单的 Android 应用,并通过 TensorFlow Serving(REST 和 gRPC)使用下载的对象检测模型进行预测。
  • 如何在界面中呈现检测结果。

所需物品

2. 进行设置

如需下载此 Codelab 的代码,请执行以下操作:

  1. 转到此 Codelab 的 GitHub 代码库
  2. 依次点击 Code > Download zip,下载此 Codelab 的所有代码。

a72f2bb4caa9a96.png

  1. 解压缩下载的 ZIP 文件,将包含您需要的所有资源的 codelabs 根文件夹解压缩。

在本 Codelab 中,您只需要代码库 TFServing/ObjectDetectionAndroid 子目录中的文件,其中包含两个文件夹:

  • starter 文件夹包含您基于此 Codelab 构建的起始代码。
  • finished 文件夹包含完成后的示例应用的完整代码。

3.将依赖项添加到项目中

将起始应用导入 Android Studio

  • 在 Android Studio 中,依次点击 File > New > Import project,然后从您之前下载的源代码中选择 starter 文件夹。

添加 OkHttp 和 gRPC 的依赖项

  • 在项目的 app/build.gradle 文件中,确认是否存在依赖项。
dependencies {
  // ...
    implementation 'com.squareup.okhttp3:okhttp:4.9.0'
    implementation 'javax.annotation:javax.annotation-api:1.3.2'
    implementation 'io.grpc:grpc-okhttp:1.29.0'
    implementation 'io.grpc:grpc-protobuf-lite:1.29.0'
    implementation 'io.grpc:grpc-stub:1.29.0'
}

将项目与 Gradle 文件同步

  • 从导航菜单中选择 541e90b497a7fef7.png Sync Project with Gradle Files

4.运行入门级应用

运行和探索应用

应用应该会在 Android 设备上启动。界面非常简单:您可以使用一张猫的图片来检测对象,而用户可以通过 REST 或 gRPC 选择将数据发送到后端的方式。后端对图片执行对象检测,并将检测结果返回给客户端应用,然后客户端应用再次呈现界面。

24eab579530e9645.png

现在,如果您点击运行推断,将没有任何反应。这是因为它尚不能与后端通信。

5. 使用 TensorFlow Serving 部署对象检测模型

对象检测是一项常见的机器学习任务,其目标是检测图片中的对象,即预测对象可能的类别及其周围的边界框。以下是检测结果的示例:

a68f9308fb2fc17b.png

Google 在 TensorFlow Hub 上发布了许多预训练模型。要查看完整列表,请访问 object_detection 页面。在此 Codelab 中,您将使用相对轻量的 SSD MobileNet V2 FPNLite 320x320 模型,因此无需使用 GPU 即可运行该模型。

如需使用 TensorFlow Serving 部署对象检测模型,请执行以下操作:

  1. 下载模型文件。
  2. 使用解压缩工具(例如 7-Zip)解压缩下载的 .tar.gz 文件。
  3. 创建一个 ssd_mobilenet_v2_2_320 文件夹,然后在其中创建一个 123 子文件夹。
  4. 将提取的 variables 文件夹和 saved_model.pb 文件放入 123 子文件夹。

您可以将 ssd_mobilenet_v2_2_320 文件夹引用为 SavedModel 文件夹。123 是版本号示例。如有需要,您可以选择其他号码。

文件夹结构应如下图所示:

42c8150a42033767.png

启动 TensorFlow Serving

  • 在您的终端中,启动带 Docker 的 TensorFlow Serving,但将 PATH/TO/SAVEDMODEL 占位符替换为计算机上的 ssd_mobilenet_v2_2_320 文件夹的绝对路径。
docker pull tensorflow/serving

docker run -it --rm -p 8500:8500 -p 8501:8501 -v "PATH/TO/SAVEDMODEL:/models/ssd_mobilenet_v2_2" -e MODEL_NAME=ssd_mobilenet_v2_2 tensorflow/serving

Docker 会先自动下载 TensorFlow Serving 映像,此过程需要一分钟时间。之后,TensorFlow Serving 应该就可以了。日志应类似于以下代码段:

2022-02-25 06:01:12.513231: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:206] Restoring SavedModel bundle.
2022-02-25 06:01:12.585012: I external/org_tensorflow/tensorflow/core/platform/profile_utils/cpu_utils.cc:114] CPU Frequency: 3000000000 Hz
2022-02-25 06:01:13.395083: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:190] Running initialization op on SavedModel bundle at path: /models/ssd_mobilenet_v2_2/123
2022-02-25 06:01:13.837562: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:277] SavedModel load for tags { serve }; Status: success: OK. Took 1928700 microseconds.
2022-02-25 06:01:13.877848: I tensorflow_serving/servables/tensorflow/saved_model_warmup_util.cc:59] No warmup data file found at /models/ssd_mobilenet_v2_2/123/assets.extra/tf_serving_warmup_requests
2022-02-25 06:01:13.929844: I tensorflow_serving/core/loader_harness.cc:87] Successfully loaded servable version {name: ssd_mobilenet_v2_2 version: 123}
2022-02-25 06:01:13.985848: I tensorflow_serving/model_servers/server_core.cc:486] Finished adding/updating models
2022-02-25 06:01:13.985987: I tensorflow_serving/model_servers/server.cc:367] Profiler service is enabled
2022-02-25 06:01:13.988994: I tensorflow_serving/model_servers/server.cc:393] Running gRPC ModelServer at 0.0.0.0:8500 ...
[warn] getaddrinfo: address family for nodename not supported
2022-02-25 06:01:14.033872: I tensorflow_serving/model_servers/server.cc:414] Exporting HTTP/REST API at:localhost:8501 ...
[evhttp_server.cc : 245] NET_LOG: Entering the event loop ...

6.通过 REST 将 Android 应用与 TensorFlow Serving 关联起来

现在后端已准备就绪,因此您可以向 TensorFlow Serving 发送客户端请求以检测图片中的对象。您可以通过以下两种方式向 TensorFlow Serving 发送请求:

  • REST
  • gRPC

通过 REST 发送请求和接收响应

有三个简单的步骤:

  • 创建 REST 请求。
  • 将 REST 请求发送到 TensorFlow Serving。
  • 从 REST 响应中提取预测结果,并呈现界面。

您将在 MainActivity.java. 年内实现这些目标

创建 REST 请求

现在,MainActivity.java 文件中有一个空的 createRESTRequest() 函数。您可以实现此函数来创建 REST 请求。

private Request createRESTRequest() {
}

TensorFlow Serving 需要一个 POST 请求,其中包含您使用的 SSD MobileNet 模型的图像张量,因此,您需要将图像的每个像素中的 RGB 值提取到一个数组中,然后将该数组封装在 JSON 中(载荷) 。

  • 将以下代码添加到 createRESTRequest() 函数中:
//Create the REST request.
int[] inputImg = new int[INPUT_IMG_HEIGHT * INPUT_IMG_WIDTH];
int[][][][] inputImgRGB = new int[1][INPUT_IMG_HEIGHT][INPUT_IMG_WIDTH][3];
inputImgBitmap.getPixels(inputImg, 0, INPUT_IMG_WIDTH, 0, 0, INPUT_IMG_WIDTH, INPUT_IMG_HEIGHT);
int pixel;
for (int i = 0; i < INPUT_IMG_HEIGHT; i++) {
    for (int j = 0; j < INPUT_IMG_WIDTH; j++) {
    // Extract RBG values from each pixel; alpha is ignored
    pixel = inputImg[i * INPUT_IMG_WIDTH + j];
    inputImgRGB[0][i][j][0] = ((pixel >> 16) & 0xff);
    inputImgRGB[0][i][j][1] = ((pixel >> 8) & 0xff);
    inputImgRGB[0][i][j][2] = ((pixel) & 0xff);
    }
}

RequestBody requestBody =
    RequestBody.create("{\"instances\": " + Arrays.deepToString(inputImgRGB) + "}", JSON);

Request request =
    new Request.Builder()
        .url("http://" + SERVER + ":" + REST_PORT + "/v1/models/" + MODEL_NAME + ":predict")
        .post(requestBody)
        .build();

return request;

向 REST Serving 发送 REST 请求

该应用允许用户选择 REST 或 gRPC 与 TensorFlow Serving 通信,因此 onClick(View view) 监听器中有两个分支。

predictButton.setOnClickListener(
    new View.OnClickListener() {
        @Override
        public void onClick(View view) {
            if (requestRadioGroup.getCheckedRadioButtonId() == R.id.rest) {
                // TODO: REST request
            }
            else {

            }
        }
    }
)
  • 将以下代码添加到 onClick(View view) 监听器的 REST 分支中,以使用 OkHttp 向 TensorFlow Serving 发送请求:
// Send the REST request.
Request request = createRESTRequest();
try {
    client =
        new OkHttpClient.Builder()
            .connectTimeout(20, TimeUnit.SECONDS)
            .writeTimeout(20, TimeUnit.SECONDS)
            .readTimeout(20, TimeUnit.SECONDS)
            .callTimeout(20, TimeUnit.SECONDS)
            .build();
    Response response = client.newCall(request).execute();
    JSONObject responseObject = new JSONObject(response.body().string());
    postprocessRESTResponse(responseObject);
} catch (IOException | JSONException e) {
    Log.e(TAG, e.getMessage());
    responseTextView.setText(e.getMessage());
    return;
}

处理来自 TensorFlow Serving 的 REST 响应

SSD MobileNet 模型会返回多个结果,其中包括:

  • num_detections:检测次数
  • detection_scores:检测分数
  • detection_classes:检测类索引
  • detection_boxes:边界框坐标

您实现 postprocessRESTResponse() 函数来处理响应。

private void postprocessRESTResponse(Predict.PredictResponse response) {

}
  • 将以下代码添加到 postprocessRESTResponse() 函数中:
// Process the REST response.
JSONArray predictionsArray = responseObject.getJSONArray("predictions");
//You only send one image, so you directly extract the first element.
JSONObject predictions = predictionsArray.getJSONObject(0);
// Argmax
int maxIndex = 0;
JSONArray detectionScores = predictions.getJSONArray("detection_scores");
for (int j = 0; j < predictions.getInt("num_detections"); j++) {
    maxIndex =
        detectionScores.getDouble(j) > detectionScores.getDouble(maxIndex + 1) ? j : maxIndex;
}
int detectionClass = predictions.getJSONArray("detection_classes").getInt(maxIndex);
JSONArray boundingBox = predictions.getJSONArray("detection_boxes").getJSONArray(maxIndex);
double ymin = boundingBox.getDouble(0);
double xmin = boundingBox.getDouble(1);
double ymax = boundingBox.getDouble(2);
double xmax = boundingBox.getDouble(3);
displayResult(detectionClass, (float) ymin, (float) xmin, (float) ymax, (float) xmax);

现在,后处理函数会从响应中提取预测值,识别对象最有可能的类别以及边界框顶点的坐标,最后在界面中呈现检测边界框。

运行应用

  1. 点击导航菜单中的 execute.png 运行“应用”,然后等待应用加载。
  2. 依次选择 REST > 运行推断

应用需要过几秒钟才能呈现猫的边界框,并且将 17 显示为该对象的类别,而该对象将映射到 COCO 数据集中的 cat 对象。

5a1a32768dc516d6.png

7. 通过 gRPC 将 Android 应用与 TensorFlow Serving 连接

除了 REST 之外,TensorFlow Serving 还支持 gRPC

b6f4449c2c850b0e.png

gRPC 是一种可在任何环境中运行的高性能开源远程过程调用 (RPC) 框架。它通过可插入式支持负载均衡、跟踪、健康检查和身份验证,高效地连接数据中心内和各个数据中心内的服务。我们发现,在实践中,gRPC 的性能比 REST 更高。

使用 gRPC 发送请求和接收响应

有四个简单的步骤:

  • [可选] 生成 gRPC 客户端桩代码。
  • 创建 gRPC 请求。
  • 将 gRPC 请求发送到 TensorFlow Serving。
  • 从 gRPC 响应中提取预测结果,并呈现界面。

您将在 MainActivity.java. 年内实现这些目标

可选:生成 gRPC 客户端桩代码

如需将 gRPC 与 TensorFlow Serving 搭配使用,您需要遵循 gRPC 工作流。如需了解详情,请参阅 gRPC 文档

a9d0e5cb543467b4.png

TensorFlow Serving 和 TensorFlow 会为您定义 .proto 文件。从 TensorFlow 和 TensorFlow Serving 2.8 开始,需要用到以下 .proto 文件:

tensorflow/core/example/example.proto
tensorflow/core/example/feature.proto
tensorflow/core/protobuf/struct.proto
tensorflow/core/protobuf/saved_object_graph.proto
tensorflow/core/protobuf/saver.proto
tensorflow/core/protobuf/trackable_object_graph.proto
tensorflow/core/protobuf/meta_graph.proto
tensorflow/core/framework/node_def.proto
tensorflow/core/framework/attr_value.proto
tensorflow/core/framework/function.proto
tensorflow/core/framework/types.proto
tensorflow/core/framework/tensor_shape.proto
tensorflow/core/framework/full_type.proto
tensorflow/core/framework/versions.proto
tensorflow/core/framework/op_def.proto
tensorflow/core/framework/graph.proto
tensorflow/core/framework/tensor.proto
tensorflow/core/framework/resource_handle.proto
tensorflow/core/framework/variable.proto

tensorflow_serving/apis/inference.proto
tensorflow_serving/apis/classification.proto
tensorflow_serving/apis/predict.proto
tensorflow_serving/apis/regression.proto
tensorflow_serving/apis/get_model_metadata.proto
tensorflow_serving/apis/input.proto
tensorflow_serving/apis/prediction_service.proto
tensorflow_serving/apis/model.proto
  • 如需生成桩,请将以下代码添加到 app/build.gradle 文件。
apply plugin: 'com.google.protobuf'

protobuf {
    protoc { artifact = 'com.google.protobuf:protoc:3.11.0' }
    plugins {
        grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.29.0'
        }
    }
    generateProtoTasks {
        all().each { task ->
            task.builtins {
                java { option 'lite' }
            }
            task.plugins {
                grpc { option 'lite' }
            }
        }
    }
}

创建 gRPC 请求

与 REST 请求类似,您可以在 createGRPCRequest() 函数中创建 gRPC 请求。

private Request createGRPCRequest() {

}
  • 将以下代码添加到 createGRPCRequest() 函数中:
if (stub == null) {
  channel = ManagedChannelBuilder.forAddress(SERVER, GRPC_PORT).usePlaintext().build();
  stub = PredictionServiceGrpc.newBlockingStub(channel);
}

Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
modelSpecBuilder.setName(MODEL_NAME);
modelSpecBuilder.setVersion(Int64Value.of(MODEL_VERSION));
modelSpecBuilder.setSignatureName(SIGNATURE_NAME);

Predict.PredictRequest.Builder builder = Predict.PredictRequest.newBuilder();
builder.setModelSpec(modelSpecBuilder);

TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder();
tensorProtoBuilder.setDtype(DataType.DT_UINT8);
TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(INPUT_IMG_HEIGHT));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(INPUT_IMG_WIDTH));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(3));
tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build());
int[] inputImg = new int[INPUT_IMG_HEIGHT * INPUT_IMG_WIDTH];
inputImgBitmap.getPixels(inputImg, 0, INPUT_IMG_WIDTH, 0, 0, INPUT_IMG_WIDTH, INPUT_IMG_HEIGHT);
int pixel;
for (int i = 0; i < INPUT_IMG_HEIGHT; i++) {
    for (int j = 0; j < INPUT_IMG_WIDTH; j++) {
    // Extract RBG values from each pixel; alpha is ignored.
    pixel = inputImg[i * INPUT_IMG_WIDTH + j];
    tensorProtoBuilder.addIntVal((pixel >> 16) & 0xff);
    tensorProtoBuilder.addIntVal((pixel >> 8) & 0xff);
    tensorProtoBuilder.addIntVal((pixel) & 0xff);
    }
}
TensorProto tensorProto = tensorProtoBuilder.build();

builder.putInputs("input_tensor", tensorProto);

builder.addOutputFilter("num_detections");
builder.addOutputFilter("detection_boxes");
builder.addOutputFilter("detection_classes");
builder.addOutputFilter("detection_scores");

return builder.build();

向 TensorFlow Serving 发送 gRPC 请求

现在,您可以完成 onClick(View view) 监听器。

predictButton.setOnClickListener(
    new View.OnClickListener() {
        @Override
        public void onClick(View view) {
            if (requestRadioGroup.getCheckedRadioButtonId() == R.id.rest) {

            }
            else {
                // TODO: gRPC request
            }
        }
    }
)
  • 将以下代码添加到 gRPC 分支:
try {
    Predict.PredictRequest request = createGRPCRequest();
    Predict.PredictResponse response = stub.predict(request);
    postprocessGRPCResponse(response);
} catch (Exception e) {
    Log.e(TAG, e.getMessage());
    responseTextView.setText(e.getMessage());
    return;
}

处理来自 TensorFlow Serving 的 gRPC 响应

与 gRPC 类似,您可以实现 postprocessGRPCResponse() 函数来处理响应。

private void postprocessGRPCResponse(Predict.PredictResponse response) {

}
  • 将以下代码添加到 postprocessGRPCResponse() 函数中:
// Process the response.
float numDetections = response.getOutputsMap().get("num_detections").getFloatValList().get(0);
List<Float> detectionScores =    response.getOutputsMap().get("detection_scores").getFloatValList();
int maxIndex = 0;
for (int j = 0; j < numDetections; j++) {
    maxIndex = detectionScores.get(j) > detectionScores.get(maxIndex + 1) ? j : maxIndex;
}
Float detectionClass =    response.getOutputsMap().get("detection_classes").getFloatValList().get(maxIndex);
List<Float> boundingBoxValues =    response.getOutputsMap().get("detection_boxes").getFloatValList();
float ymin = boundingBoxValues.get(maxIndex * 4);
float xmin = boundingBoxValues.get(maxIndex * 4 + 1);
float ymax = boundingBoxValues.get(maxIndex * 4 + 2);
float xmax = boundingBoxValues.get(maxIndex * 4 + 3);
displayResult(detectionClass.intValue(), ymin, xmin, ymax, xmax);

现在,后处理函数可以从响应中提取预测值,并在界面中呈现检测边界框。

运行应用

  1. 点击导航菜单中的 execute.png 运行“应用”,然后等待应用加载。
  2. 依次选择 gRPC > 运行推断

应用需要过几秒钟才能呈现猫的边界框,并且将 17 显示为该对象的类别,而该对象将映射到 COCO 数据集中的 cat 类别。

8. 恭喜

您已使用 TensorFlow Serving 将对象检测功能添加到应用中!

了解详情