1. Trước khi bắt đầu
Trong lớp học lập trình này, bạn sẽ tìm hiểu cách chạy suy luận phát hiện đối tượng từ một ứng dụng Android bằng TensorFlow Serving với REST và gRPC.
Điều kiện tiên quyết
- Kiến thức cơ bản về cách phát triển Android bằng Java
- Kiến thức cơ bản về học máy bằng TensorFlow, chẳng hạn như huấn luyện và triển khai
- Kiến thức cơ bản về thiết bị đầu cuối và Docker
Kiến thức bạn sẽ học được
- Cách tìm các mô hình phát hiện vật thể được huấn luyện trước trên TensorFlow Hub.
- Cách tạo một ứng dụng Android đơn giản và đưa ra dự đoán bằng mô hình phát hiện đối tượng đã tải xuống thông qua TensorFlow Serving (REST và gRPC).
- Cách hiển thị kết quả phát hiện trong giao diện người dùng.
Bạn cần có
- Phiên bản mới nhất của Android Studio
- Docker
- Bash
2. Bắt đầu thiết lập
Cách tải mã xuống cho lớp học lập trình này:
- Chuyển đến kho lưu trữ GitHub cho lớp học lập trình này.
- Nhấp vào Code > Download zip (Mã > Tải xuống tệp ZIP) để tải tất cả mã cho lớp học lập trình này xuống.
- Giải nén tệp zip đã tải xuống để giải nén một thư mục gốc
codelabs
có tất cả tài nguyên bạn cần.
Đối với lớp học lập trình này, bạn chỉ cần các tệp trong thư mục con TFServing/ObjectDetectionAndroid
trong kho lưu trữ. Thư mục này chứa 2 thư mục:
- Thư mục
starter
chứa mã khởi đầu mà bạn sẽ dùng để xây dựng trong lớp học lập trình này. - Thư mục
finished
chứa mã hoàn chỉnh cho ứng dụng mẫu đã hoàn thành.
3. Thêm các phần phụ thuộc vào dự án
Nhập ứng dụng khởi đầu vào Android Studio
- Trong Android Studio, hãy nhấp vào File > New > Import project (Tệp > Mới > Nhập dự án) rồi chọn thư mục
starter
trong mã nguồn mà bạn đã tải xuống trước đó.
Thêm các phần phụ thuộc cho OkHttp và gRPC
- Trong tệp
app/build.gradle
của dự án, hãy xác nhận sự hiện diện của các phần phụ thuộc.
dependencies {
// ...
implementation 'com.squareup.okhttp3:okhttp:4.9.0'
implementation 'javax.annotation:javax.annotation-api:1.3.2'
implementation 'io.grpc:grpc-okhttp:1.29.0'
implementation 'io.grpc:grpc-protobuf-lite:1.29.0'
implementation 'io.grpc:grpc-stub:1.29.0'
}
Đồng bộ hoá dự án với các tệp Gradle
- Chọn
Đồng bộ hoá dự án với tệp Gradle trong trình đơn điều hướng.
4. Chạy ứng dụng khởi đầu
- Khởi động Trình mô phỏng Android,sau đó nhấp vào
Chạy "ứng dụng" trong trình đơn điều hướng.
Chạy và khám phá ứng dụng
Ứng dụng sẽ khởi chạy trên thiết bị Android của bạn. Giao diện người dùng khá đơn giản: có một hình ảnh con mèo mà bạn muốn phát hiện các đối tượng và người dùng có thể chọn cách gửi dữ liệu đến phần phụ trợ, bằng REST hoặc gRPC. Phần phụ trợ thực hiện quy trình phát hiện đối tượng trên hình ảnh và trả về kết quả phát hiện cho ứng dụng khách. Ứng dụng khách sẽ kết xuất lại giao diện người dùng.
Hiện tại, nếu bạn nhấp vào Run inference (Chạy suy luận), thì sẽ không có gì xảy ra. Điều này là do ứng dụng chưa giao tiếp được với phần phụ trợ.
5. Triển khai mô hình phát hiện đối tượng bằng TensorFlow Serving
Phát hiện đối tượng là một nhiệm vụ học máy rất phổ biến và mục tiêu của nhiệm vụ này là phát hiện các đối tượng trong hình ảnh, cụ thể là dự đoán các danh mục có thể có của đối tượng và các khung hình chữ nhật xung quanh đối tượng. Dưới đây là ví dụ về kết quả phát hiện:
Google đã xuất bản một số mô hình được huấn luyện trước trên TensorFlow Hub. Để xem danh sách đầy đủ, hãy truy cập vào trang object_detection. Bạn sử dụng mô hình SSD MobileNet V2 FPNLite 320x320 tương đối nhẹ cho lớp học lập trình này để không nhất thiết phải dùng GPU để chạy mô hình.
Cách triển khai mô hình phát hiện vật thể bằng TensorFlow Serving:
- Tải tệp mô hình xuống.
- Giải nén tệp
.tar.gz
đã tải xuống bằng một công cụ giải nén, chẳng hạn như 7-Zip. - Tạo một thư mục
ssd_mobilenet_v2_2_320
rồi tạo một thư mục con123
bên trong thư mục đó. - Đặt thư mục
variables
và tệpsaved_model.pb
đã giải nén vào thư mục con123
.
Bạn có thể tham khảo thư mục ssd_mobilenet_v2_2_320
dưới dạng thư mục SavedModel
. 123
là một ví dụ về số phiên bản. Nếu muốn, bạn có thể chọn một số khác.
Cấu trúc thư mục sẽ có dạng như hình ảnh sau:
Bắt đầu TensorFlow Serving
- Trong thiết bị đầu cuối, hãy bắt đầu TensorFlow Serving bằng Docker, nhưng thay thế phần giữ chỗ
PATH/TO/SAVEDMODEL
bằng đường dẫn tuyệt đối của thư mụcssd_mobilenet_v2_2_320
trên máy tính.
docker pull tensorflow/serving docker run -it --rm -p 8500:8500 -p 8501:8501 -v "PATH/TO/SAVEDMODEL:/models/ssd_mobilenet_v2_2" -e MODEL_NAME=ssd_mobilenet_v2_2 tensorflow/serving
Trước tiên, Docker sẽ tự động tải hình ảnh TensorFlow Serving xuống. Quá trình này mất một phút. Sau đó, TensorFlow Serving sẽ bắt đầu. Nhật ký sẽ có dạng như đoạn mã sau:
2022-02-25 06:01:12.513231: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:206] Restoring SavedModel bundle. 2022-02-25 06:01:12.585012: I external/org_tensorflow/tensorflow/core/platform/profile_utils/cpu_utils.cc:114] CPU Frequency: 3000000000 Hz 2022-02-25 06:01:13.395083: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:190] Running initialization op on SavedModel bundle at path: /models/ssd_mobilenet_v2_2/123 2022-02-25 06:01:13.837562: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:277] SavedModel load for tags { serve }; Status: success: OK. Took 1928700 microseconds. 2022-02-25 06:01:13.877848: I tensorflow_serving/servables/tensorflow/saved_model_warmup_util.cc:59] No warmup data file found at /models/ssd_mobilenet_v2_2/123/assets.extra/tf_serving_warmup_requests 2022-02-25 06:01:13.929844: I tensorflow_serving/core/loader_harness.cc:87] Successfully loaded servable version {name: ssd_mobilenet_v2_2 version: 123} 2022-02-25 06:01:13.985848: I tensorflow_serving/model_servers/server_core.cc:486] Finished adding/updating models 2022-02-25 06:01:13.985987: I tensorflow_serving/model_servers/server.cc:367] Profiler service is enabled 2022-02-25 06:01:13.988994: I tensorflow_serving/model_servers/server.cc:393] Running gRPC ModelServer at 0.0.0.0:8500 ... [warn] getaddrinfo: address family for nodename not supported 2022-02-25 06:01:14.033872: I tensorflow_serving/model_servers/server.cc:414] Exporting HTTP/REST API at:localhost:8501 ... [evhttp_server.cc : 245] NET_LOG: Entering the event loop ...
6. Kết nối ứng dụng Android với TensorFlow Serving thông qua REST
Giờ đây, phần phụ trợ đã sẵn sàng, vì vậy bạn có thể gửi yêu cầu của ứng dụng đến TensorFlow Serving để phát hiện các đối tượng trong hình ảnh. Có hai cách để gửi yêu cầu đến TensorFlow Serving:
- REST
- gRPC
Gửi yêu cầu và nhận phản hồi thông qua REST
Bạn chỉ cần thực hiện 3 bước đơn giản:
- Tạo yêu cầu REST.
- Gửi yêu cầu REST đến TensorFlow Serving.
- Trích xuất kết quả dự đoán từ phản hồi REST và hiển thị giao diện người dùng.
Bạn sẽ đạt được những mục tiêu này trong MainActivity.java.
Tạo yêu cầu REST
Hiện tại, có một hàm createRESTRequest()
trống trong tệp MainActivity.java
. Bạn triển khai hàm này để tạo một yêu cầu REST.
private Request createRESTRequest() {
}
TensorFlow Serving dự kiến sẽ có một yêu cầu POST chứa tensor hình ảnh cho mô hình SSD MobileNet mà bạn sử dụng. Vì vậy, bạn cần trích xuất các giá trị RGB từ mỗi điểm ảnh của hình ảnh vào một mảng, rồi bao bọc mảng đó trong một JSON. Đây là tải trọng của yêu cầu.
- Thêm mã này vào hàm
createRESTRequest()
:
//Create the REST request.
int[] inputImg = new int[INPUT_IMG_HEIGHT * INPUT_IMG_WIDTH];
int[][][][] inputImgRGB = new int[1][INPUT_IMG_HEIGHT][INPUT_IMG_WIDTH][3];
inputImgBitmap.getPixels(inputImg, 0, INPUT_IMG_WIDTH, 0, 0, INPUT_IMG_WIDTH, INPUT_IMG_HEIGHT);
int pixel;
for (int i = 0; i < INPUT_IMG_HEIGHT; i++) {
for (int j = 0; j < INPUT_IMG_WIDTH; j++) {
// Extract RBG values from each pixel; alpha is ignored
pixel = inputImg[i * INPUT_IMG_WIDTH + j];
inputImgRGB[0][i][j][0] = ((pixel >> 16) & 0xff);
inputImgRGB[0][i][j][1] = ((pixel >> 8) & 0xff);
inputImgRGB[0][i][j][2] = ((pixel) & 0xff);
}
}
RequestBody requestBody =
RequestBody.create("{\"instances\": " + Arrays.deepToString(inputImgRGB) + "}", JSON);
Request request =
new Request.Builder()
.url("http://" + SERVER + ":" + REST_PORT + "/v1/models/" + MODEL_NAME + ":predict")
.post(requestBody)
.build();
return request;
Gửi yêu cầu REST đến TensorFlow Serving
Ứng dụng này cho phép người dùng chọn REST hoặc gRPC để giao tiếp với TensorFlow Serving, vì vậy, có 2 nhánh trong trình nghe onClick(View view)
.
predictButton.setOnClickListener(
new View.OnClickListener() {
@Override
public void onClick(View view) {
if (requestRadioGroup.getCheckedRadioButtonId() == R.id.rest) {
// TODO: REST request
}
else {
}
}
}
)
- Thêm mã này vào nhánh REST của trình nghe
onClick(View view)
để sử dụng OkHttp nhằm gửi yêu cầu đến TensorFlow Serving:
// Send the REST request.
Request request = createRESTRequest();
try {
client =
new OkHttpClient.Builder()
.connectTimeout(20, TimeUnit.SECONDS)
.writeTimeout(20, TimeUnit.SECONDS)
.readTimeout(20, TimeUnit.SECONDS)
.callTimeout(20, TimeUnit.SECONDS)
.build();
Response response = client.newCall(request).execute();
JSONObject responseObject = new JSONObject(response.body().string());
postprocessRESTResponse(responseObject);
} catch (IOException | JSONException e) {
Log.e(TAG, e.getMessage());
responseTextView.setText(e.getMessage());
return;
}
Xử lý phản hồi REST từ TensorFlow Serving
Mô hình SSD MobileNet trả về một số kết quả, bao gồm:
num_detections
: số lần phát hiệndetection_scores
: điểm phát hiệndetection_classes
: chỉ mục lớp phát hiệndetection_boxes
: toạ độ của hộp giới hạn
Bạn triển khai hàm postprocessRESTResponse()
để xử lý phản hồi.
private void postprocessRESTResponse(Predict.PredictResponse response) {
}
- Thêm mã này vào hàm
postprocessRESTResponse()
:
// Process the REST response.
JSONArray predictionsArray = responseObject.getJSONArray("predictions");
//You only send one image, so you directly extract the first element.
JSONObject predictions = predictionsArray.getJSONObject(0);
// Argmax
int maxIndex = 0;
JSONArray detectionScores = predictions.getJSONArray("detection_scores");
for (int j = 0; j < predictions.getInt("num_detections"); j++) {
maxIndex =
detectionScores.getDouble(j) > detectionScores.getDouble(maxIndex + 1) ? j : maxIndex;
}
int detectionClass = predictions.getJSONArray("detection_classes").getInt(maxIndex);
JSONArray boundingBox = predictions.getJSONArray("detection_boxes").getJSONArray(maxIndex);
double ymin = boundingBox.getDouble(0);
double xmin = boundingBox.getDouble(1);
double ymax = boundingBox.getDouble(2);
double xmax = boundingBox.getDouble(3);
displayResult(detectionClass, (float) ymin, (float) xmin, (float) ymax, (float) xmax);
Giờ đây, hàm xử lý hậu kỳ sẽ trích xuất các giá trị được dự đoán từ phản hồi, xác định danh mục có khả năng xảy ra nhất của đối tượng và toạ độ của các đỉnh hộp giới hạn, sau cùng là hiển thị hộp giới hạn phát hiện trên giao diện người dùng.
Chạy ứng dụng
- Nhấp vào
Chạy "ứng dụng" trong trình đơn điều hướng rồi đợi ứng dụng tải.
- Chọn REST > Run inference (REST > Chạy suy luận).
Mất vài giây trước khi ứng dụng kết xuất khung hình chữ nhật của con mèo và cho thấy 17
là danh mục của đối tượng, ánh xạ đến đối tượng cat
trong tập dữ liệu COCO.
7. Kết nối ứng dụng Android với TensorFlow Serving thông qua gRPC
Ngoài REST, TensorFlow Serving còn hỗ trợ gRPC.
gRPC là một khung Lệnh gọi thủ tục từ xa (RPC) hiện đại, nguồn mở, hiệu suất cao, có thể chạy trong mọi môi trường. Nền tảng này có thể kết nối hiệu quả các dịch vụ trong và trên các trung tâm dữ liệu với khả năng hỗ trợ có thể cắm cho việc cân bằng tải, theo dõi, kiểm tra tình trạng và xác thực. Theo quan sát, gRPC hoạt động hiệu quả hơn REST trong thực tế.
Gửi yêu cầu và nhận phản hồi bằng gRPC
Có 4 bước đơn giản:
- [Không bắt buộc] Tạo mã giả lập ứng dụng gRPC.
- Tạo yêu cầu gRPC.
- Gửi yêu cầu gRPC đến TensorFlow Serving.
- Trích xuất kết quả dự đoán từ phản hồi gRPC và hiển thị giao diện người dùng.
Bạn sẽ đạt được những mục tiêu này trong MainActivity.java.
Không bắt buộc: Tạo mã giả lập ứng dụng gRPC
Để sử dụng gRPC với TensorFlow Serving, bạn cần làm theo quy trình gRPC. Để tìm hiểu thêm về thông tin chi tiết, hãy xem tài liệu về gRPC.
TensorFlow Serving và TensorFlow xác định các tệp .proto
cho bạn. Kể từ TensorFlow và TensorFlow Serving 2.8, đây là những tệp .proto
cần thiết:
tensorflow/core/example/example.proto
tensorflow/core/example/feature.proto
tensorflow/core/protobuf/struct.proto
tensorflow/core/protobuf/saved_object_graph.proto
tensorflow/core/protobuf/saver.proto
tensorflow/core/protobuf/trackable_object_graph.proto
tensorflow/core/protobuf/meta_graph.proto
tensorflow/core/framework/node_def.proto
tensorflow/core/framework/attr_value.proto
tensorflow/core/framework/function.proto
tensorflow/core/framework/types.proto
tensorflow/core/framework/tensor_shape.proto
tensorflow/core/framework/full_type.proto
tensorflow/core/framework/versions.proto
tensorflow/core/framework/op_def.proto
tensorflow/core/framework/graph.proto
tensorflow/core/framework/tensor.proto
tensorflow/core/framework/resource_handle.proto
tensorflow/core/framework/variable.proto
tensorflow_serving/apis/inference.proto
tensorflow_serving/apis/classification.proto
tensorflow_serving/apis/predict.proto
tensorflow_serving/apis/regression.proto
tensorflow_serving/apis/get_model_metadata.proto
tensorflow_serving/apis/input.proto
tensorflow_serving/apis/prediction_service.proto
tensorflow_serving/apis/model.proto
- Để tạo phần khai báo, hãy thêm mã này vào tệp
app/build.gradle
.
apply plugin: 'com.google.protobuf'
protobuf {
protoc { artifact = 'com.google.protobuf:protoc:3.11.0' }
plugins {
grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.29.0'
}
}
generateProtoTasks {
all().each { task ->
task.builtins {
java { option 'lite' }
}
task.plugins {
grpc { option 'lite' }
}
}
}
}
Tạo yêu cầu gRPC
Tương tự như yêu cầu REST, bạn tạo yêu cầu gRPC trong hàm createGRPCRequest()
.
private Request createGRPCRequest() {
}
- Thêm mã này vào hàm
createGRPCRequest()
:
if (stub == null) {
channel = ManagedChannelBuilder.forAddress(SERVER, GRPC_PORT).usePlaintext().build();
stub = PredictionServiceGrpc.newBlockingStub(channel);
}
Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
modelSpecBuilder.setName(MODEL_NAME);
modelSpecBuilder.setVersion(Int64Value.of(MODEL_VERSION));
modelSpecBuilder.setSignatureName(SIGNATURE_NAME);
Predict.PredictRequest.Builder builder = Predict.PredictRequest.newBuilder();
builder.setModelSpec(modelSpecBuilder);
TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder();
tensorProtoBuilder.setDtype(DataType.DT_UINT8);
TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(INPUT_IMG_HEIGHT));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(INPUT_IMG_WIDTH));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(3));
tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build());
int[] inputImg = new int[INPUT_IMG_HEIGHT * INPUT_IMG_WIDTH];
inputImgBitmap.getPixels(inputImg, 0, INPUT_IMG_WIDTH, 0, 0, INPUT_IMG_WIDTH, INPUT_IMG_HEIGHT);
int pixel;
for (int i = 0; i < INPUT_IMG_HEIGHT; i++) {
for (int j = 0; j < INPUT_IMG_WIDTH; j++) {
// Extract RBG values from each pixel; alpha is ignored.
pixel = inputImg[i * INPUT_IMG_WIDTH + j];
tensorProtoBuilder.addIntVal((pixel >> 16) & 0xff);
tensorProtoBuilder.addIntVal((pixel >> 8) & 0xff);
tensorProtoBuilder.addIntVal((pixel) & 0xff);
}
}
TensorProto tensorProto = tensorProtoBuilder.build();
builder.putInputs("input_tensor", tensorProto);
builder.addOutputFilter("num_detections");
builder.addOutputFilter("detection_boxes");
builder.addOutputFilter("detection_classes");
builder.addOutputFilter("detection_scores");
return builder.build();
Gửi yêu cầu gRPC đến TensorFlow Serving
Bây giờ, bạn có thể hoàn tất trình nghe onClick(View view)
.
predictButton.setOnClickListener(
new View.OnClickListener() {
@Override
public void onClick(View view) {
if (requestRadioGroup.getCheckedRadioButtonId() == R.id.rest) {
}
else {
// TODO: gRPC request
}
}
}
)
- Thêm mã này vào nhánh gRPC:
try {
Predict.PredictRequest request = createGRPCRequest();
Predict.PredictResponse response = stub.predict(request);
postprocessGRPCResponse(response);
} catch (Exception e) {
Log.e(TAG, e.getMessage());
responseTextView.setText(e.getMessage());
return;
}
Xử lý phản hồi gRPC từ TensorFlow Serving
Tương tự như gRPC, bạn triển khai hàm postprocessGRPCResponse()
để xử lý phản hồi.
private void postprocessGRPCResponse(Predict.PredictResponse response) {
}
- Thêm mã này vào hàm
postprocessGRPCResponse()
:
// Process the response.
float numDetections = response.getOutputsMap().get("num_detections").getFloatValList().get(0);
List<Float> detectionScores = response.getOutputsMap().get("detection_scores").getFloatValList();
int maxIndex = 0;
for (int j = 0; j < numDetections; j++) {
maxIndex = detectionScores.get(j) > detectionScores.get(maxIndex + 1) ? j : maxIndex;
}
Float detectionClass = response.getOutputsMap().get("detection_classes").getFloatValList().get(maxIndex);
List<Float> boundingBoxValues = response.getOutputsMap().get("detection_boxes").getFloatValList();
float ymin = boundingBoxValues.get(maxIndex * 4);
float xmin = boundingBoxValues.get(maxIndex * 4 + 1);
float ymax = boundingBoxValues.get(maxIndex * 4 + 2);
float xmax = boundingBoxValues.get(maxIndex * 4 + 3);
displayResult(detectionClass.intValue(), ymin, xmin, ymax, xmax);
Giờ đây, hàm xử lý hậu kỳ có thể trích xuất các giá trị dự đoán từ phản hồi và hiển thị khung phát hiện trong giao diện người dùng.
Chạy ứng dụng
- Nhấp vào
Chạy "ứng dụng" trong trình đơn điều hướng rồi đợi ứng dụng tải.
- Chọn gRPC > Run inference (gRPC > Chạy suy luận).
Ứng dụng sẽ mất vài giây để kết xuất khung hình chữ nhật của con mèo và cho thấy 17
là danh mục của đối tượng, ánh xạ đến danh mục cat
trong tập dữ liệu COCO.
8. Xin chúc mừng
Bạn đã sử dụng TensorFlow Serving để thêm khả năng phát hiện đối tượng vào ứng dụng của mình!