Tạo một ứng dụng Android để phát hiện đối tượng trong hình ảnh

1. Trước khi bắt đầu

Trong lớp học lập trình này, bạn tìm hiểu cách chạy suy luận phát hiện đối tượng từ một ứng dụng Android bằng cách sử dụng việc phân phát TensorFlow với REST và gRPC.

Điều kiện tiên quyết

  • Kiến thức cơ bản về việc phát triển Android thông qua Java
  • Kiến thức cơ bản về công nghệ máy học với TensorFlow, chẳng hạn như chương trình đào tạo và triển khai
  • Kiến thức cơ bản về thiết bị đầu cuối và Docker

Kiến thức bạn sẽ học được

  • Cách tìm các mô hình phát hiện đối tượng được đào tạo trước trên TensorFlow Hub.
  • Cách xây dựng một ứng dụng Android đơn giản và đưa ra dự đoán bằng mô hình phát hiện đối tượng được tải xuống thông qua việc phân phát TensorFlow (REST và gRPC).
  • Cách hiển thị kết quả phát hiện trong giao diện người dùng.

Bạn cần có

2. Bắt đầu thiết lập

Cách tải mã xuống cho lớp học lập trình này:

  1. Chuyển đến kho lưu trữ GitHub cho lớp học lập trình này.
  2. Nhấp vào Code > Download zip để tải tất cả mã xuống cho lớp học lập trình này.

a72f2bb4caa9a96.png

  1. Giải nén tệp zip đã tải xuống để giải nén thư mục gốc codelabs bằng tất cả tài nguyên mà bạn cần.

Đối với lớp học lập trình này, bạn chỉ cần các tệp trong thư mục con TFServing/ObjectDetectionAndroid trong kho lưu trữ, nơi chứa hai thư mục:

  • Thư mục starter chứa mã dành cho người mới bắt đầu mà bạn xây dựng cho lớp học lập trình này.
  • Thư mục finished chứa mã đã hoàn tất cho ứng dụng mẫu đã hoàn thành.

3. Thêm các phần phụ thuộc vào dự án

Nhập ứng dụng dành cho người mới bắt đầu vào Android Studio

  • Trong Android Studio, hãy nhấp vào Tệp > Mới > Nhập dự án, rồi chọn thư mục starter từ mã nguồn mà bạn đã tải xuống trước đó.

Thêm các phần phụ thuộc cho OkHttp và gRPC

  • Trong tệp app/build.gradle của dự án, hãy xác nhận sự hiện diện của các phần phụ thuộc.
dependencies {
  // ...
    implementation 'com.squareup.okhttp3:okhttp:4.9.0'
    implementation 'javax.annotation:javax.annotation-api:1.3.2'
    implementation 'io.grpc:grpc-okhttp:1.29.0'
    implementation 'io.grpc:grpc-protobuf-lite:1.29.0'
    implementation 'io.grpc:grpc-stub:1.29.0'
}

Đồng bộ hóa dự án của bạn với các tệp cho Gradle

  • Chọn 541e90b497a7fef7.png Đồng bộ hóa dự án với tệp Gradle từ trình đơn điều hướng.

4. Chạy ứng dụng khởi động

Chạy và khám phá ứng dụng

Ứng dụng sẽ khởi chạy trên thiết bị Android của bạn. Giao diện người dùng khá đơn giản: có một hình ảnh mèo mà bạn muốn phát hiện các đối tượng và người dùng có thể chọn cách gửi dữ liệu tới phần phụ trợ, với REST hoặc gRPC. Phụ trợ này thực hiện chức năng phát hiện đối tượng trên hình ảnh và trả lại kết quả phát hiện cho ứng dụng khách, ứng dụng này sẽ hiển thị lại giao diện người dùng.

24eab579530e9645.png

Ngay bây giờ, nếu bạn nhấp vào Chạy dự đoán, thì sẽ không có gì xảy ra. Điều này là do nó không thể giao tiếp với chương trình phụ trợ.

5. Triển khai mô hình phát hiện đối tượng bằng tính năng Phân phát TensorFlow

Phát hiện đối tượng là một tác vụ máy học rất phổ biến và mục tiêu của nó là phát hiện các đối tượng trong hình ảnh, cụ thể là dự đoán các danh mục có thể có của các đối tượng và hộp giới hạn xung quanh chúng. Sau đây là ví dụ về kết quả phát hiện:

a68f9308fb2fc17b.png

Google đã phát hành một số mô hình được đào tạo trước trên TensorFlow Hub. Để xem danh sách đầy đủ, hãy truy cập vào trang object_detection. Bạn sử dụng mô hình SSD MobileNet V2 FPNLite 320x320 tương đối nhẹ cho lớp học lập trình này để bạn không nhất thiết phải sử dụng GPU để chạy mô hình đó.

Để triển khai mô hình phát hiện đối tượng bằng tính năng Phân phát TensorFlow:

  1. Tải tệp mô hình xuống.
  2. Giải nén tệp .tar.gz đã tải xuống bằng một công cụ giải nén, chẳng hạn như 7-zip.
  3. Tạo thư mục ssd_mobilenet_v2_2_320, sau đó tạo một thư mục con 123 bên trong thư mục đó.
  4. Đặt thư mục variables đã trích xuất và tệp saved_model.pb vào thư mục con 123.

Bạn có thể tham chiếu thư mục ssd_mobilenet_v2_2_320 dưới dạng thư mục SavedModel. 123 là số phiên bản mẫu. Nếu muốn, bạn có thể chọn một số khác.

Cấu trúc thư mục sẽ có dạng như sau:

42c8150a42033767.png

Bắt đầu phân phát TensorFlow

  • Trên thiết bị đầu cuối, hãy bắt đầu phân phát TensorFlow bằng Docker, nhưng thay thế trình giữ chỗ PATH/TO/SAVEDMODEL bằng đường dẫn tuyệt đối của thư mục ssd_mobilenet_v2_2_320 trên máy tính.
docker pull tensorflow/serving

docker run -it --rm -p 8500:8500 -p 8501:8501 -v "PATH/TO/SAVEDMODEL:/models/ssd_mobilenet_v2_2" -e MODEL_NAME=ssd_mobilenet_v2_2 tensorflow/serving

Docker, hệ thống sẽ tự động tải hình ảnh Phân phát TensorFlow xuống trước. Quá trình này sẽ mất một phút. Sau đó, Quá trình phân phát TensorFlow sẽ bắt đầu. Nhật ký sẽ trông giống như đoạn mã này:

2022-02-25 06:01:12.513231: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:206] Restoring SavedModel bundle.
2022-02-25 06:01:12.585012: I external/org_tensorflow/tensorflow/core/platform/profile_utils/cpu_utils.cc:114] CPU Frequency: 3000000000 Hz
2022-02-25 06:01:13.395083: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:190] Running initialization op on SavedModel bundle at path: /models/ssd_mobilenet_v2_2/123
2022-02-25 06:01:13.837562: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:277] SavedModel load for tags { serve }; Status: success: OK. Took 1928700 microseconds.
2022-02-25 06:01:13.877848: I tensorflow_serving/servables/tensorflow/saved_model_warmup_util.cc:59] No warmup data file found at /models/ssd_mobilenet_v2_2/123/assets.extra/tf_serving_warmup_requests
2022-02-25 06:01:13.929844: I tensorflow_serving/core/loader_harness.cc:87] Successfully loaded servable version {name: ssd_mobilenet_v2_2 version: 123}
2022-02-25 06:01:13.985848: I tensorflow_serving/model_servers/server_core.cc:486] Finished adding/updating models
2022-02-25 06:01:13.985987: I tensorflow_serving/model_servers/server.cc:367] Profiler service is enabled
2022-02-25 06:01:13.988994: I tensorflow_serving/model_servers/server.cc:393] Running gRPC ModelServer at 0.0.0.0:8500 ...
[warn] getaddrinfo: address family for nodename not supported
2022-02-25 06:01:14.033872: I tensorflow_serving/model_servers/server.cc:414] Exporting HTTP/REST API at:localhost:8501 ...
[evhttp_server.cc : 245] NET_LOG: Entering the event loop ...

6. Kết nối ứng dụng Android với TensorFlow Delivery thông qua REST

Phần phụ trợ này hiện đã sẵn sàng, vì vậy bạn có thể gửi các yêu cầu ứng dụng đến TensorFlow Delivery để phát hiện các đối tượng trong hình ảnh. Có hai cách để gửi yêu cầu đến Phục vụ TensorFlow:

  • Kiến trúc chuyển trạng thái đại diện (REST)
  • gRPC

Gửi yêu cầu và nhận phản hồi qua REST

Có 3 bước đơn giản:

  • Tạo yêu cầu REST.
  • Gửi yêu cầu REST tới TensorFlowServe.
  • Trích xuất kết quả dự đoán từ phản hồi REST và hiển thị giao diện người dùng.

Bạn sẽ đạt được các mục này trong MainActivity.java.

Tạo yêu cầu REST

Hiện tại, có một hàm createRESTRequest() trống trong tệp MainActivity.java. Bạn triển khai hàm này để tạo một yêu cầu REST.

private Request createRESTRequest() {
}

TensorFlow Phục vụ yêu cầu một yêu cầu POST có chứa áp lực hình ảnh cho mô hình SSD MobileNet mà bạn sử dụng, vì vậy, bạn cần trích xuất các giá trị RGB từ mỗi pixel của hình ảnh vào một mảng, sau đó gói mảng vào một JSON, là tải trọng của yêu cầu.

  • Thêm mã này vào hàm createRESTRequest():
//Create the REST request.
int[] inputImg = new int[INPUT_IMG_HEIGHT * INPUT_IMG_WIDTH];
int[][][][] inputImgRGB = new int[1][INPUT_IMG_HEIGHT][INPUT_IMG_WIDTH][3];
inputImgBitmap.getPixels(inputImg, 0, INPUT_IMG_WIDTH, 0, 0, INPUT_IMG_WIDTH, INPUT_IMG_HEIGHT);
int pixel;
for (int i = 0; i < INPUT_IMG_HEIGHT; i++) {
    for (int j = 0; j < INPUT_IMG_WIDTH; j++) {
    // Extract RBG values from each pixel; alpha is ignored
    pixel = inputImg[i * INPUT_IMG_WIDTH + j];
    inputImgRGB[0][i][j][0] = ((pixel >> 16) & 0xff);
    inputImgRGB[0][i][j][1] = ((pixel >> 8) & 0xff);
    inputImgRGB[0][i][j][2] = ((pixel) & 0xff);
    }
}

RequestBody requestBody =
    RequestBody.create("{\"instances\": " + Arrays.deepToString(inputImgRGB) + "}", JSON);

Request request =
    new Request.Builder()
        .url("http://" + SERVER + ":" + REST_PORT + "/v1/models/" + MODEL_NAME + ":predict")
        .post(requestBody)
        .build();

return request;    

Gửi yêu cầu REST đến Phục vụ TensorFlow

Ứng dụng này cho phép người dùng chọn REST hoặc gRPC để giao tiếp với TensorFlow Delivery, vì vậy có hai nhánh trong trình nghe onClick(View view).

predictButton.setOnClickListener(
    new View.OnClickListener() {
        @Override
        public void onClick(View view) {
            if (requestRadioGroup.getCheckedRadioButtonId() == R.id.rest) {
                // TODO: REST request
            }
            else {

            }
        }
    }
)
  • Thêm mã này vào nhánh REST của trình nghe onClick(View view) để sử dụng OkHttp để gửi yêu cầu đến TensorFlowServe:
// Send the REST request.
Request request = createRESTRequest();
try {
    client =
        new OkHttpClient.Builder()
            .connectTimeout(20, TimeUnit.SECONDS)
            .writeTimeout(20, TimeUnit.SECONDS)
            .readTimeout(20, TimeUnit.SECONDS)
            .callTimeout(20, TimeUnit.SECONDS)
            .build();
    Response response = client.newCall(request).execute();
    JSONObject responseObject = new JSONObject(response.body().string());
    postprocessRESTResponse(responseObject);
} catch (IOException | JSONException e) {
    Log.e(TAG, e.getMessage());
    responseTextView.setText(e.getMessage());
    return;
}

Xử lý phản hồi REST (Phân phát) của TensorFlow

Mô hình SSD MobileNet trả về một số kết quả, bao gồm:

  • num_detections: số lượt phát hiện
  • detection_scores: điểm phát hiện
  • detection_classes: chỉ mục của lớp phát hiện
  • detection_boxes: tọa độ hộp giới hạn

Bạn triển khai hàm postprocessRESTResponse() để xử lý phản hồi.

private void postprocessRESTResponse(Predict.PredictResponse response) {

}
  • Thêm mã này vào hàm postprocessRESTResponse():
// Process the REST response.
JSONArray predictionsArray = responseObject.getJSONArray("predictions");
//You only send one image, so you directly extract the first element.
JSONObject predictions = predictionsArray.getJSONObject(0);
// Argmax
int maxIndex = 0;
JSONArray detectionScores = predictions.getJSONArray("detection_scores");
for (int j = 0; j < predictions.getInt("num_detections"); j++) {
    maxIndex =
        detectionScores.getDouble(j) > detectionScores.getDouble(maxIndex + 1) ? j : maxIndex;
}
int detectionClass = predictions.getJSONArray("detection_classes").getInt(maxIndex);
JSONArray boundingBox = predictions.getJSONArray("detection_boxes").getJSONArray(maxIndex);
double ymin = boundingBox.getDouble(0);
double xmin = boundingBox.getDouble(1);
double ymax = boundingBox.getDouble(2);
double xmax = boundingBox.getDouble(3);
displayResult(detectionClass, (float) ymin, (float) xmin, (float) ymax, (float) xmax);

Giờ đây, hàm xử lý hậu kỳ trích xuất các giá trị dự đoán từ phản hồi, xác định danh mục có thể xảy ra nhất của đối tượng và tọa độ của các đỉnh hộp giới hạn, cuối cùng hiển thị hộp giới hạn phát hiện trên giao diện người dùng.

Chạy

  1. Nhấp vào tức thì.png Chạy "ứng dụng\39"; trong trình đơn điều hướng rồi chờ ứng dụng tải.
  2. Chọn REST > Chạy dự đoán.

Ứng dụng này sẽ mất vài giây trước khi ứng dụng hiển thị hộp giới hạn của mèo và cho thấy 17 làm danh mục của đối tượng, liên kết tới đối tượng cat trong tập dữ liệu COCO.

5a1a32768dc516d6.png

7. Kết nối ứng dụng Android với TensorFlow Delivery thông qua gRPC

Ngoài REST, việc phân phát TensorFlow cũng hỗ trợ gRPC.

b6f4449c2c850b0e.png

gRPC là một khung lệnh gọi từ xa, nguồn mở hiện đại, hiệu suất cao, có thể chạy trong mọi môi trường. API này có thể kết nối các dịch vụ trong và trên các trung tâm dữ liệu một cách hiệu quả với khả năng hỗ trợ dễ dàng để cân bằng tải, theo dõi, kiểm tra tình trạng và xác thực. Theo quan sát, gRPC có hiệu suất cao hơn REST trong thực tế.

Gửi yêu cầu và nhận phản hồi bằng gRPC

Có 4 bước đơn giản:

  • [Không bắt buộc] Tạo mã mã ứng dụng gRPC.
  • Tạo yêu cầu gRPC.
  • Gửi yêu cầu gRPC đến TensorFlowServe.
  • Trích xuất kết quả dự đoán từ phản hồi gRPC và hiển thị giao diện người dùng.

Bạn sẽ đạt được các mục này trong MainActivity.java.

Không bắt buộc: Tạo mã mã ứng dụng gRPC

Để sử dụng gRPC với TensorFlowServe, bạn cần phải tuân theo quy trình làm việc của gRPC. Để tìm hiểu thêm về các chi tiết, hãy xem tài liệu gRPC.

a9d0e5cb543467b4.png

TensorFlow và TensorFlow xác định các tệp .proto cho bạn. Kể từ TensorFlow và TensorFlow phân phát 2.8, .proto tệp này là những tệp cần thiết:

tensorflow/core/example/example.proto
tensorflow/core/example/feature.proto
tensorflow/core/protobuf/struct.proto
tensorflow/core/protobuf/saved_object_graph.proto
tensorflow/core/protobuf/saver.proto
tensorflow/core/protobuf/trackable_object_graph.proto
tensorflow/core/protobuf/meta_graph.proto
tensorflow/core/framework/node_def.proto
tensorflow/core/framework/attr_value.proto
tensorflow/core/framework/function.proto
tensorflow/core/framework/types.proto
tensorflow/core/framework/tensor_shape.proto
tensorflow/core/framework/full_type.proto
tensorflow/core/framework/versions.proto
tensorflow/core/framework/op_def.proto
tensorflow/core/framework/graph.proto
tensorflow/core/framework/tensor.proto
tensorflow/core/framework/resource_handle.proto
tensorflow/core/framework/variable.proto

tensorflow_serving/apis/inference.proto
tensorflow_serving/apis/classification.proto
tensorflow_serving/apis/predict.proto
tensorflow_serving/apis/regression.proto
tensorflow_serving/apis/get_model_metadata.proto
tensorflow_serving/apis/input.proto
tensorflow_serving/apis/prediction_service.proto
tensorflow_serving/apis/model.proto
  • Để tạo mã giả, hãy thêm mã này vào tệp app/build.gradle.
apply plugin: 'com.google.protobuf'

protobuf {
    protoc { artifact = 'com.google.protobuf:protoc:3.11.0' }
    plugins {
        grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.29.0'
        }
    }
    generateProtoTasks {
        all().each { task ->
            task.builtins {
                java { option 'lite' }
            }
            task.plugins {
                grpc { option 'lite' }
            }
        }
    }
}

Tạo yêu cầu gRPC

Tương tự như yêu cầu REST, bạn tạo yêu cầu gRPC trong hàm createGRPCRequest().

private Request createGRPCRequest() {

}
  • Thêm mã này vào hàm createGRPCRequest():
if (stub == null) {
  channel = ManagedChannelBuilder.forAddress(SERVER, GRPC_PORT).usePlaintext().build();
  stub = PredictionServiceGrpc.newBlockingStub(channel);
}

Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
modelSpecBuilder.setName(MODEL_NAME);
modelSpecBuilder.setVersion(Int64Value.of(MODEL_VERSION));
modelSpecBuilder.setSignatureName(SIGNATURE_NAME);

Predict.PredictRequest.Builder builder = Predict.PredictRequest.newBuilder();
builder.setModelSpec(modelSpecBuilder);

TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder();
tensorProtoBuilder.setDtype(DataType.DT_UINT8);
TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(INPUT_IMG_HEIGHT));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(INPUT_IMG_WIDTH));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(3));
tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build());
int[] inputImg = new int[INPUT_IMG_HEIGHT * INPUT_IMG_WIDTH];
inputImgBitmap.getPixels(inputImg, 0, INPUT_IMG_WIDTH, 0, 0, INPUT_IMG_WIDTH, INPUT_IMG_HEIGHT);
int pixel;
for (int i = 0; i < INPUT_IMG_HEIGHT; i++) {
    for (int j = 0; j < INPUT_IMG_WIDTH; j++) {
    // Extract RBG values from each pixel; alpha is ignored.
    pixel = inputImg[i * INPUT_IMG_WIDTH + j];
    tensorProtoBuilder.addIntVal((pixel >> 16) & 0xff);
    tensorProtoBuilder.addIntVal((pixel >> 8) & 0xff);
    tensorProtoBuilder.addIntVal((pixel) & 0xff);
    }
}
TensorProto tensorProto = tensorProtoBuilder.build();

builder.putInputs("input_tensor", tensorProto);

builder.addOutputFilter("num_detections");
builder.addOutputFilter("detection_boxes");
builder.addOutputFilter("detection_classes");
builder.addOutputFilter("detection_scores");

return builder.build();

Gửi yêu cầu gRPC đến TensorFlowServe

Giờ đây, bạn đã có thể nghe xong trình nghe onClick(View view).

predictButton.setOnClickListener(
    new View.OnClickListener() {
        @Override
        public void onClick(View view) {
            if (requestRadioGroup.getCheckedRadioButtonId() == R.id.rest) {

            }
            else {
                // TODO: gRPC request
            }
        }
    }
)
  • Thêm mã này vào chi nhánh gRPC:
try {
    Predict.PredictRequest request = createGRPCRequest();
    Predict.PredictResponse response = stub.predict(request);
    postprocessGRPCResponse(response);
} catch (Exception e) {
    Log.e(TAG, e.getMessage());
    responseTextView.setText(e.getMessage());
    return;
}

Xử lý phản hồi gRPC từ TensorFlowServe

Tương tự như gRPC, bạn triển khai hàm postprocessGRPCResponse() để xử lý phản hồi.

private void postprocessGRPCResponse(Predict.PredictResponse response) {

}
  • Thêm mã này vào hàm postprocessGRPCResponse():
// Process the response.
float numDetections = response.getOutputsMap().get("num_detections").getFloatValList().get(0);
List<Float> detectionScores =    response.getOutputsMap().get("detection_scores").getFloatValList();
int maxIndex = 0;
for (int j = 0; j < numDetections; j++) {
    maxIndex = detectionScores.get(j) > detectionScores.get(maxIndex + 1) ? j : maxIndex;
}
Float detectionClass =    response.getOutputsMap().get("detection_classes").getFloatValList().get(maxIndex);
List<Float> boundingBoxValues =    response.getOutputsMap().get("detection_boxes").getFloatValList();
float ymin = boundingBoxValues.get(maxIndex * 4);
float xmin = boundingBoxValues.get(maxIndex * 4 + 1);
float ymax = boundingBoxValues.get(maxIndex * 4 + 2);
float xmax = boundingBoxValues.get(maxIndex * 4 + 3);
displayResult(detectionClass.intValue(), ymin, xmin, ymax, xmax);

Giờ đây, hàm xử lý hậu kỳ có thể trích xuất các giá trị dự đoán từ phản hồi và hiển thị hộp giới hạn phát hiện trong giao diện người dùng.

Chạy

  1. Nhấp vào tức thì.png Chạy "ứng dụng\39"; trong trình đơn điều hướng rồi chờ ứng dụng tải.
  2. Chọn gRPC > Chạy dự đoán.

Sẽ mất vài giây trước khi ứng dụng hiển thị hộp giới hạn của mèo và hiển thị 17 dưới dạng danh mục của đối tượng, liên kết tới danh mục cat trong tập dữ liệu COCO.

8. Xin chúc mừng

Bạn đã sử dụng tính năng Phân phát TensorFlow để thêm các tính năng phát hiện đối tượng vào ứng dụng của bạn!

Tìm hiểu thêm