Tạo một ứng dụng Android để phát hiện các đối tượng trong hình ảnh

1. Trước khi bắt đầu

Trong lớp học lập trình này, bạn sẽ tìm hiểu cách chạy suy luận phát hiện đối tượng từ một ứng dụng Android bằng TensorFlow Serving với REST và gRPC.

Điều kiện tiên quyết

  • Kiến thức cơ bản về cách phát triển Android bằng Java
  • Kiến thức cơ bản về học máy bằng TensorFlow, chẳng hạn như huấn luyện và triển khai
  • Kiến thức cơ bản về thiết bị đầu cuối và Docker

Kiến thức bạn sẽ học được

  • Cách tìm các mô hình phát hiện vật thể được huấn luyện trước trên TensorFlow Hub.
  • Cách tạo một ứng dụng Android đơn giản và đưa ra dự đoán bằng mô hình phát hiện đối tượng đã tải xuống thông qua TensorFlow Serving (REST và gRPC).
  • Cách hiển thị kết quả phát hiện trong giao diện người dùng.

Bạn cần có

2. Bắt đầu thiết lập

Cách tải mã xuống cho lớp học lập trình này:

  1. Chuyển đến kho lưu trữ GitHub cho lớp học lập trình này.
  2. Nhấp vào Code > Download zip (Mã > Tải xuống tệp ZIP) để tải tất cả mã cho lớp học lập trình này xuống.

a72f2bb4caa9a96.png

  1. Giải nén tệp zip đã tải xuống để giải nén một thư mục gốc codelabs có tất cả tài nguyên bạn cần.

Đối với lớp học lập trình này, bạn chỉ cần các tệp trong thư mục con TFServing/ObjectDetectionAndroid trong kho lưu trữ. Thư mục này chứa 2 thư mục:

  • Thư mục starter chứa mã khởi đầu mà bạn sẽ dùng để xây dựng trong lớp học lập trình này.
  • Thư mục finished chứa mã hoàn chỉnh cho ứng dụng mẫu đã hoàn thành.

3. Thêm các phần phụ thuộc vào dự án

Nhập ứng dụng khởi đầu vào Android Studio

  • Trong Android Studio, hãy nhấp vào File > New > Import project (Tệp > Mới > Nhập dự án) rồi chọn thư mục starter trong mã nguồn mà bạn đã tải xuống trước đó.

Thêm các phần phụ thuộc cho OkHttp và gRPC

  • Trong tệp app/build.gradle của dự án, hãy xác nhận sự hiện diện của các phần phụ thuộc.
dependencies {
  // ...
    implementation 'com.squareup.okhttp3:okhttp:4.9.0'
    implementation 'javax.annotation:javax.annotation-api:1.3.2'
    implementation 'io.grpc:grpc-okhttp:1.29.0'
    implementation 'io.grpc:grpc-protobuf-lite:1.29.0'
    implementation 'io.grpc:grpc-stub:1.29.0'
}

Đồng bộ hoá dự án với các tệp Gradle

  • Chọn 541e90b497a7fef7.png Đồng bộ hoá dự án với tệp Gradle trong trình đơn điều hướng.

4. Chạy ứng dụng khởi đầu

Chạy và khám phá ứng dụng

Ứng dụng sẽ khởi chạy trên thiết bị Android của bạn. Giao diện người dùng khá đơn giản: có một hình ảnh con mèo mà bạn muốn phát hiện các đối tượng và người dùng có thể chọn cách gửi dữ liệu đến phần phụ trợ, bằng REST hoặc gRPC. Phần phụ trợ thực hiện quy trình phát hiện đối tượng trên hình ảnh và trả về kết quả phát hiện cho ứng dụng khách. Ứng dụng khách sẽ kết xuất lại giao diện người dùng.

24eab579530e9645.png

Hiện tại, nếu bạn nhấp vào Run inference (Chạy suy luận), thì sẽ không có gì xảy ra. Điều này là do ứng dụng chưa giao tiếp được với phần phụ trợ.

5. Triển khai mô hình phát hiện đối tượng bằng TensorFlow Serving

Phát hiện đối tượng là một nhiệm vụ học máy rất phổ biến và mục tiêu của nhiệm vụ này là phát hiện các đối tượng trong hình ảnh, cụ thể là dự đoán các danh mục có thể có của đối tượng và các khung hình chữ nhật xung quanh đối tượng. Dưới đây là ví dụ về kết quả phát hiện:

a68f9308fb2fc17b.png

Google đã xuất bản một số mô hình được huấn luyện trước trên TensorFlow Hub. Để xem danh sách đầy đủ, hãy truy cập vào trang object_detection. Bạn sử dụng mô hình SSD MobileNet V2 FPNLite 320x320 tương đối nhẹ cho lớp học lập trình này để không nhất thiết phải dùng GPU để chạy mô hình.

Cách triển khai mô hình phát hiện vật thể bằng TensorFlow Serving:

  1. Tải tệp mô hình xuống.
  2. Giải nén tệp .tar.gz đã tải xuống bằng một công cụ giải nén, chẳng hạn như 7-Zip.
  3. Tạo một thư mục ssd_mobilenet_v2_2_320 rồi tạo một thư mục con 123 bên trong thư mục đó.
  4. Đặt thư mục variables và tệp saved_model.pb đã giải nén vào thư mục con 123.

Bạn có thể tham khảo thư mục ssd_mobilenet_v2_2_320 dưới dạng thư mục SavedModel. 123 là một ví dụ về số phiên bản. Nếu muốn, bạn có thể chọn một số khác.

Cấu trúc thư mục sẽ có dạng như hình ảnh sau:

42c8150a42033767.png

Bắt đầu TensorFlow Serving

  • Trong thiết bị đầu cuối, hãy bắt đầu TensorFlow Serving bằng Docker, nhưng thay thế phần giữ chỗ PATH/TO/SAVEDMODEL bằng đường dẫn tuyệt đối của thư mục ssd_mobilenet_v2_2_320 trên máy tính.
docker pull tensorflow/serving

docker run -it --rm -p 8500:8500 -p 8501:8501 -v "PATH/TO/SAVEDMODEL:/models/ssd_mobilenet_v2_2" -e MODEL_NAME=ssd_mobilenet_v2_2 tensorflow/serving

Trước tiên, Docker sẽ tự động tải hình ảnh TensorFlow Serving xuống. Quá trình này mất một phút. Sau đó, TensorFlow Serving sẽ bắt đầu. Nhật ký sẽ có dạng như đoạn mã sau:

2022-02-25 06:01:12.513231: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:206] Restoring SavedModel bundle.
2022-02-25 06:01:12.585012: I external/org_tensorflow/tensorflow/core/platform/profile_utils/cpu_utils.cc:114] CPU Frequency: 3000000000 Hz
2022-02-25 06:01:13.395083: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:190] Running initialization op on SavedModel bundle at path: /models/ssd_mobilenet_v2_2/123
2022-02-25 06:01:13.837562: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:277] SavedModel load for tags { serve }; Status: success: OK. Took 1928700 microseconds.
2022-02-25 06:01:13.877848: I tensorflow_serving/servables/tensorflow/saved_model_warmup_util.cc:59] No warmup data file found at /models/ssd_mobilenet_v2_2/123/assets.extra/tf_serving_warmup_requests
2022-02-25 06:01:13.929844: I tensorflow_serving/core/loader_harness.cc:87] Successfully loaded servable version {name: ssd_mobilenet_v2_2 version: 123}
2022-02-25 06:01:13.985848: I tensorflow_serving/model_servers/server_core.cc:486] Finished adding/updating models
2022-02-25 06:01:13.985987: I tensorflow_serving/model_servers/server.cc:367] Profiler service is enabled
2022-02-25 06:01:13.988994: I tensorflow_serving/model_servers/server.cc:393] Running gRPC ModelServer at 0.0.0.0:8500 ...
[warn] getaddrinfo: address family for nodename not supported
2022-02-25 06:01:14.033872: I tensorflow_serving/model_servers/server.cc:414] Exporting HTTP/REST API at:localhost:8501 ...
[evhttp_server.cc : 245] NET_LOG: Entering the event loop ...

6. Kết nối ứng dụng Android với TensorFlow Serving thông qua REST

Giờ đây, phần phụ trợ đã sẵn sàng, vì vậy bạn có thể gửi yêu cầu của ứng dụng đến TensorFlow Serving để phát hiện các đối tượng trong hình ảnh. Có hai cách để gửi yêu cầu đến TensorFlow Serving:

  • REST
  • gRPC

Gửi yêu cầu và nhận phản hồi thông qua REST

Bạn chỉ cần thực hiện 3 bước đơn giản:

  • Tạo yêu cầu REST.
  • Gửi yêu cầu REST đến TensorFlow Serving.
  • Trích xuất kết quả dự đoán từ phản hồi REST và hiển thị giao diện người dùng.

Bạn sẽ đạt được những mục tiêu này trong MainActivity.java.

Tạo yêu cầu REST

Hiện tại, có một hàm createRESTRequest() trống trong tệp MainActivity.java. Bạn triển khai hàm này để tạo một yêu cầu REST.

private Request createRESTRequest() {
}

TensorFlow Serving dự kiến sẽ có một yêu cầu POST chứa tensor hình ảnh cho mô hình SSD MobileNet mà bạn sử dụng. Vì vậy, bạn cần trích xuất các giá trị RGB từ mỗi điểm ảnh của hình ảnh vào một mảng, rồi bao bọc mảng đó trong một JSON. Đây là tải trọng của yêu cầu.

  • Thêm mã này vào hàm createRESTRequest():
//Create the REST request.
int[] inputImg = new int[INPUT_IMG_HEIGHT * INPUT_IMG_WIDTH];
int[][][][] inputImgRGB = new int[1][INPUT_IMG_HEIGHT][INPUT_IMG_WIDTH][3];
inputImgBitmap.getPixels(inputImg, 0, INPUT_IMG_WIDTH, 0, 0, INPUT_IMG_WIDTH, INPUT_IMG_HEIGHT);
int pixel;
for (int i = 0; i < INPUT_IMG_HEIGHT; i++) {
    for (int j = 0; j < INPUT_IMG_WIDTH; j++) {
    // Extract RBG values from each pixel; alpha is ignored
    pixel = inputImg[i * INPUT_IMG_WIDTH + j];
    inputImgRGB[0][i][j][0] = ((pixel >> 16) & 0xff);
    inputImgRGB[0][i][j][1] = ((pixel >> 8) & 0xff);
    inputImgRGB[0][i][j][2] = ((pixel) & 0xff);
    }
}

RequestBody requestBody =
    RequestBody.create("{\"instances\": " + Arrays.deepToString(inputImgRGB) + "}", JSON);

Request request =
    new Request.Builder()
        .url("http://" + SERVER + ":" + REST_PORT + "/v1/models/" + MODEL_NAME + ":predict")
        .post(requestBody)
        .build();

return request;    

Gửi yêu cầu REST đến TensorFlow Serving

Ứng dụng này cho phép người dùng chọn REST hoặc gRPC để giao tiếp với TensorFlow Serving, vì vậy, có 2 nhánh trong trình nghe onClick(View view).

predictButton.setOnClickListener(
    new View.OnClickListener() {
        @Override
        public void onClick(View view) {
            if (requestRadioGroup.getCheckedRadioButtonId() == R.id.rest) {
                // TODO: REST request
            }
            else {

            }
        }
    }
)
  • Thêm mã này vào nhánh REST của trình nghe onClick(View view) để sử dụng OkHttp nhằm gửi yêu cầu đến TensorFlow Serving:
// Send the REST request.
Request request = createRESTRequest();
try {
    client =
        new OkHttpClient.Builder()
            .connectTimeout(20, TimeUnit.SECONDS)
            .writeTimeout(20, TimeUnit.SECONDS)
            .readTimeout(20, TimeUnit.SECONDS)
            .callTimeout(20, TimeUnit.SECONDS)
            .build();
    Response response = client.newCall(request).execute();
    JSONObject responseObject = new JSONObject(response.body().string());
    postprocessRESTResponse(responseObject);
} catch (IOException | JSONException e) {
    Log.e(TAG, e.getMessage());
    responseTextView.setText(e.getMessage());
    return;
}

Xử lý phản hồi REST từ TensorFlow Serving

Mô hình SSD MobileNet trả về một số kết quả, bao gồm:

  • num_detections: số lần phát hiện
  • detection_scores: điểm phát hiện
  • detection_classes: chỉ mục lớp phát hiện
  • detection_boxes: toạ độ của hộp giới hạn

Bạn triển khai hàm postprocessRESTResponse() để xử lý phản hồi.

private void postprocessRESTResponse(Predict.PredictResponse response) {

}
  • Thêm mã này vào hàm postprocessRESTResponse():
// Process the REST response.
JSONArray predictionsArray = responseObject.getJSONArray("predictions");
//You only send one image, so you directly extract the first element.
JSONObject predictions = predictionsArray.getJSONObject(0);
// Argmax
int maxIndex = 0;
JSONArray detectionScores = predictions.getJSONArray("detection_scores");
for (int j = 0; j < predictions.getInt("num_detections"); j++) {
    maxIndex =
        detectionScores.getDouble(j) > detectionScores.getDouble(maxIndex + 1) ? j : maxIndex;
}
int detectionClass = predictions.getJSONArray("detection_classes").getInt(maxIndex);
JSONArray boundingBox = predictions.getJSONArray("detection_boxes").getJSONArray(maxIndex);
double ymin = boundingBox.getDouble(0);
double xmin = boundingBox.getDouble(1);
double ymax = boundingBox.getDouble(2);
double xmax = boundingBox.getDouble(3);
displayResult(detectionClass, (float) ymin, (float) xmin, (float) ymax, (float) xmax);

Giờ đây, hàm xử lý hậu kỳ sẽ trích xuất các giá trị được dự đoán từ phản hồi, xác định danh mục có khả năng xảy ra nhất của đối tượng và toạ độ của các đỉnh hộp giới hạn, sau cùng là hiển thị hộp giới hạn phát hiện trên giao diện người dùng.

Chạy ứng dụng

  1. Nhấp vào execute.png Chạy "ứng dụng" trong trình đơn điều hướng rồi đợi ứng dụng tải.
  2. Chọn REST > Run inference (REST > Chạy suy luận).

Mất vài giây trước khi ứng dụng kết xuất khung hình chữ nhật của con mèo và cho thấy 17 là danh mục của đối tượng, ánh xạ đến đối tượng cat trong tập dữ liệu COCO.

5a1a32768dc516d6.png

7. Kết nối ứng dụng Android với TensorFlow Serving thông qua gRPC

Ngoài REST, TensorFlow Serving còn hỗ trợ gRPC.

b6f4449c2c850b0e.png

gRPC là một khung Lệnh gọi thủ tục từ xa (RPC) hiện đại, nguồn mở, hiệu suất cao, có thể chạy trong mọi môi trường. Nền tảng này có thể kết nối hiệu quả các dịch vụ trong và trên các trung tâm dữ liệu với khả năng hỗ trợ có thể cắm cho việc cân bằng tải, theo dõi, kiểm tra tình trạng và xác thực. Theo quan sát, gRPC hoạt động hiệu quả hơn REST trong thực tế.

Gửi yêu cầu và nhận phản hồi bằng gRPC

Có 4 bước đơn giản:

  • [Không bắt buộc] Tạo mã giả lập ứng dụng gRPC.
  • Tạo yêu cầu gRPC.
  • Gửi yêu cầu gRPC đến TensorFlow Serving.
  • Trích xuất kết quả dự đoán từ phản hồi gRPC và hiển thị giao diện người dùng.

Bạn sẽ đạt được những mục tiêu này trong MainActivity.java.

Không bắt buộc: Tạo mã giả lập ứng dụng gRPC

Để sử dụng gRPC với TensorFlow Serving, bạn cần làm theo quy trình gRPC. Để tìm hiểu thêm về thông tin chi tiết, hãy xem tài liệu về gRPC.

a9d0e5cb543467b4.png

TensorFlow Serving và TensorFlow xác định các tệp .proto cho bạn. Kể từ TensorFlow và TensorFlow Serving 2.8, đây là những tệp .proto cần thiết:

tensorflow/core/example/example.proto
tensorflow/core/example/feature.proto
tensorflow/core/protobuf/struct.proto
tensorflow/core/protobuf/saved_object_graph.proto
tensorflow/core/protobuf/saver.proto
tensorflow/core/protobuf/trackable_object_graph.proto
tensorflow/core/protobuf/meta_graph.proto
tensorflow/core/framework/node_def.proto
tensorflow/core/framework/attr_value.proto
tensorflow/core/framework/function.proto
tensorflow/core/framework/types.proto
tensorflow/core/framework/tensor_shape.proto
tensorflow/core/framework/full_type.proto
tensorflow/core/framework/versions.proto
tensorflow/core/framework/op_def.proto
tensorflow/core/framework/graph.proto
tensorflow/core/framework/tensor.proto
tensorflow/core/framework/resource_handle.proto
tensorflow/core/framework/variable.proto

tensorflow_serving/apis/inference.proto
tensorflow_serving/apis/classification.proto
tensorflow_serving/apis/predict.proto
tensorflow_serving/apis/regression.proto
tensorflow_serving/apis/get_model_metadata.proto
tensorflow_serving/apis/input.proto
tensorflow_serving/apis/prediction_service.proto
tensorflow_serving/apis/model.proto
  • Để tạo phần khai báo, hãy thêm mã này vào tệp app/build.gradle.
apply plugin: 'com.google.protobuf'

protobuf {
    protoc { artifact = 'com.google.protobuf:protoc:3.11.0' }
    plugins {
        grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.29.0'
        }
    }
    generateProtoTasks {
        all().each { task ->
            task.builtins {
                java { option 'lite' }
            }
            task.plugins {
                grpc { option 'lite' }
            }
        }
    }
}

Tạo yêu cầu gRPC

Tương tự như yêu cầu REST, bạn tạo yêu cầu gRPC trong hàm createGRPCRequest().

private Request createGRPCRequest() {

}
  • Thêm mã này vào hàm createGRPCRequest():
if (stub == null) {
  channel = ManagedChannelBuilder.forAddress(SERVER, GRPC_PORT).usePlaintext().build();
  stub = PredictionServiceGrpc.newBlockingStub(channel);
}

Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
modelSpecBuilder.setName(MODEL_NAME);
modelSpecBuilder.setVersion(Int64Value.of(MODEL_VERSION));
modelSpecBuilder.setSignatureName(SIGNATURE_NAME);

Predict.PredictRequest.Builder builder = Predict.PredictRequest.newBuilder();
builder.setModelSpec(modelSpecBuilder);

TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder();
tensorProtoBuilder.setDtype(DataType.DT_UINT8);
TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(INPUT_IMG_HEIGHT));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(INPUT_IMG_WIDTH));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(3));
tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build());
int[] inputImg = new int[INPUT_IMG_HEIGHT * INPUT_IMG_WIDTH];
inputImgBitmap.getPixels(inputImg, 0, INPUT_IMG_WIDTH, 0, 0, INPUT_IMG_WIDTH, INPUT_IMG_HEIGHT);
int pixel;
for (int i = 0; i < INPUT_IMG_HEIGHT; i++) {
    for (int j = 0; j < INPUT_IMG_WIDTH; j++) {
    // Extract RBG values from each pixel; alpha is ignored.
    pixel = inputImg[i * INPUT_IMG_WIDTH + j];
    tensorProtoBuilder.addIntVal((pixel >> 16) & 0xff);
    tensorProtoBuilder.addIntVal((pixel >> 8) & 0xff);
    tensorProtoBuilder.addIntVal((pixel) & 0xff);
    }
}
TensorProto tensorProto = tensorProtoBuilder.build();

builder.putInputs("input_tensor", tensorProto);

builder.addOutputFilter("num_detections");
builder.addOutputFilter("detection_boxes");
builder.addOutputFilter("detection_classes");
builder.addOutputFilter("detection_scores");

return builder.build();

Gửi yêu cầu gRPC đến TensorFlow Serving

Bây giờ, bạn có thể hoàn tất trình nghe onClick(View view).

predictButton.setOnClickListener(
    new View.OnClickListener() {
        @Override
        public void onClick(View view) {
            if (requestRadioGroup.getCheckedRadioButtonId() == R.id.rest) {

            }
            else {
                // TODO: gRPC request
            }
        }
    }
)
  • Thêm mã này vào nhánh gRPC:
try {
    Predict.PredictRequest request = createGRPCRequest();
    Predict.PredictResponse response = stub.predict(request);
    postprocessGRPCResponse(response);
} catch (Exception e) {
    Log.e(TAG, e.getMessage());
    responseTextView.setText(e.getMessage());
    return;
}

Xử lý phản hồi gRPC từ TensorFlow Serving

Tương tự như gRPC, bạn triển khai hàm postprocessGRPCResponse() để xử lý phản hồi.

private void postprocessGRPCResponse(Predict.PredictResponse response) {

}
  • Thêm mã này vào hàm postprocessGRPCResponse():
// Process the response.
float numDetections = response.getOutputsMap().get("num_detections").getFloatValList().get(0);
List<Float> detectionScores =    response.getOutputsMap().get("detection_scores").getFloatValList();
int maxIndex = 0;
for (int j = 0; j < numDetections; j++) {
    maxIndex = detectionScores.get(j) > detectionScores.get(maxIndex + 1) ? j : maxIndex;
}
Float detectionClass =    response.getOutputsMap().get("detection_classes").getFloatValList().get(maxIndex);
List<Float> boundingBoxValues =    response.getOutputsMap().get("detection_boxes").getFloatValList();
float ymin = boundingBoxValues.get(maxIndex * 4);
float xmin = boundingBoxValues.get(maxIndex * 4 + 1);
float ymax = boundingBoxValues.get(maxIndex * 4 + 2);
float xmax = boundingBoxValues.get(maxIndex * 4 + 3);
displayResult(detectionClass.intValue(), ymin, xmin, ymax, xmax);

Giờ đây, hàm xử lý hậu kỳ có thể trích xuất các giá trị dự đoán từ phản hồi và hiển thị khung phát hiện trong giao diện người dùng.

Chạy ứng dụng

  1. Nhấp vào execute.png Chạy "ứng dụng" trong trình đơn điều hướng rồi đợi ứng dụng tải.
  2. Chọn gRPC > Run inference (gRPC > Chạy suy luận).

Ứng dụng sẽ mất vài giây để kết xuất khung hình chữ nhật của con mèo và cho thấy 17 là danh mục của đối tượng, ánh xạ đến danh mục cat trong tập dữ liệu COCO.

8. Xin chúc mừng

Bạn đã sử dụng TensorFlow Serving để thêm khả năng phát hiện đối tượng vào ứng dụng của mình!

Tìm hiểu thêm