1. Prima di iniziare
In questo codelab, imparerai a eseguire un'inferenza di rilevamento degli oggetti da un'app per Android utilizzando TensorFlow Serving con REST e gRPC.
Prerequisiti
- Conoscenza di base dello sviluppo per Android con Java
- Conoscenza di base del machine learning con TensorFlow, ad esempio addestramento e deployment
- Conoscenza di base di terminali e Docker
Obiettivi didattici
- Come trovare modelli di rilevamento di oggetti preaddestrati su TensorFlow Hub.
- Come creare una semplice app per Android ed eseguire previsioni con il modello di rilevamento degli oggetti scaricato tramite TensorFlow Serving (REST e gRPC).
- Come visualizzare il risultato del rilevamento nell'UI.
Che cosa ti serve
- L'ultima versione di Android Studio
- Docker
- Bash
2. Configurazione
Per scaricare il codice per questo codelab:
- Vai al repository GitHub per questo codelab.
- Fai clic su Code > Download zip per scaricare tutto il codice di questo codelab.

- Decomprimi il file ZIP scaricato per estrarre una cartella principale
codelabscon tutte le risorse necessarie.
Per questo codelab, ti servono solo i file nella sottodirectory TFServing/ObjectDetectionAndroid del repository, che contiene due cartelle:
- La cartella
startercontiene il codice iniziale su cui si basa questo codelab. - La cartella
finishedcontiene il codice completato per l'app di esempio finita.
3. Aggiungi le dipendenze al progetto
Importa l'app iniziale in Android Studio
- In Android Studio, fai clic su File > Nuovo > Importa progetto e poi scegli la cartella
starterdal codice sorgente che hai scaricato in precedenza.
Aggiungi le dipendenze per OkHttp e gRPC
- Nel file
app/build.gradledel tuo progetto, verifica la presenza delle dipendenze.
dependencies {
// ...
implementation 'com.squareup.okhttp3:okhttp:4.9.0'
implementation 'javax.annotation:javax.annotation-api:1.3.2'
implementation 'io.grpc:grpc-okhttp:1.29.0'
implementation 'io.grpc:grpc-protobuf-lite:1.29.0'
implementation 'io.grpc:grpc-stub:1.29.0'
}
Sincronizzare il progetto con i file Gradle
- Seleziona
Sincronizza progetto con file Gradle dal menu di navigazione.
4. Esegui l'app di base
- Avvia l'emulatore Android,quindi fai clic su
Esegui "app" nel menu di navigazione.
Eseguire ed esplorare l'app
L'app dovrebbe avviarsi sul tuo dispositivo Android. L'interfaccia utente è piuttosto semplice: c'è un'immagine di un gatto in cui vuoi rilevare gli oggetti e l'utente può scegliere il modo per inviare i dati al backend, con REST o gRPC. Il backend esegue il rilevamento degli oggetti nell'immagine e restituisce i risultati del rilevamento all'app client, che esegue di nuovo il rendering dell'interfaccia utente.

Al momento, se fai clic su Esegui inferenza, non succede nulla. Questo perché non può ancora comunicare con il backend.
5. Esegui il deployment di un modello di rilevamento degli oggetti con TensorFlow Serving
Il rilevamento degli oggetti è un'attività di machine learning molto comune e il suo scopo è rilevare gli oggetti all'interno delle immagini, ovvero prevedere le possibili categorie degli oggetti e i riquadri di delimitazione intorno a essi. Ecco un esempio di risultato del rilevamento:

Google ha pubblicato una serie di modelli preaddestrati su TensorFlow Hub. Per visualizzare l'elenco completo, visita la pagina object_detection. Per questo codelab utilizzi il modello SSD MobileNet V2 FPNLite 320x320 relativamente leggero, quindi non devi necessariamente utilizzare una GPU per eseguirlo.
Per eseguire il deployment del modello di rilevamento degli oggetti con TensorFlow Serving:
- Scarica il file del modello.
- Decomprimi il file
.tar.gzscaricato con uno strumento di decompressione, ad esempio 7-Zip. - Crea una cartella
ssd_mobilenet_v2_2_320e poi una sottocartella123al suo interno. - Inserisci la cartella
variablese il filesaved_model.pbestratti nella sottocartella123.
Puoi fare riferimento alla cartella ssd_mobilenet_v2_2_320 come cartella SavedModel. 123 è un numero di versione di esempio. Se vuoi, puoi scegliere un altro numero.
La struttura delle cartelle dovrebbe essere simile a quella nell'immagine:

Avvia TensorFlow Serving
- Nel terminale, avvia TensorFlow Serving con Docker, ma sostituisci il segnaposto
PATH/TO/SAVEDMODELcon il percorso assoluto della cartellassd_mobilenet_v2_2_320sul computer.
docker pull tensorflow/serving docker run -it --rm -p 8500:8500 -p 8501:8501 -v "PATH/TO/SAVEDMODEL:/models/ssd_mobilenet_v2_2" -e MODEL_NAME=ssd_mobilenet_v2_2 tensorflow/serving
Docker scarica automaticamente l'immagine TensorFlow Serving, il che richiede un minuto. Dopodiché, TensorFlow Serving dovrebbe avviarsi. Il log dovrebbe avere l'aspetto seguente:
2022-02-25 06:01:12.513231: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:206] Restoring SavedModel bundle.
2022-02-25 06:01:12.585012: I external/org_tensorflow/tensorflow/core/platform/profile_utils/cpu_utils.cc:114] CPU Frequency: 3000000000 Hz
2022-02-25 06:01:13.395083: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:190] Running initialization op on SavedModel bundle at path: /models/ssd_mobilenet_v2_2/123
2022-02-25 06:01:13.837562: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:277] SavedModel load for tags { serve }; Status: success: OK. Took 1928700 microseconds.
2022-02-25 06:01:13.877848: I tensorflow_serving/servables/tensorflow/saved_model_warmup_util.cc:59] No warmup data file found at /models/ssd_mobilenet_v2_2/123/assets.extra/tf_serving_warmup_requests
2022-02-25 06:01:13.929844: I tensorflow_serving/core/loader_harness.cc:87] Successfully loaded servable version {name: ssd_mobilenet_v2_2 version: 123}
2022-02-25 06:01:13.985848: I tensorflow_serving/model_servers/server_core.cc:486] Finished adding/updating models
2022-02-25 06:01:13.985987: I tensorflow_serving/model_servers/server.cc:367] Profiler service is enabled
2022-02-25 06:01:13.988994: I tensorflow_serving/model_servers/server.cc:393] Running gRPC ModelServer at 0.0.0.0:8500 ...
[warn] getaddrinfo: address family for nodename not supported
2022-02-25 06:01:14.033872: I tensorflow_serving/model_servers/server.cc:414] Exporting HTTP/REST API at:localhost:8501 ...
[evhttp_server.cc : 245] NET_LOG: Entering the event loop ...
6. Collega l'app per Android a TensorFlow Serving tramite REST
Il backend è ora pronto, quindi puoi inviare richieste client a TensorFlow Serving per rilevare oggetti all'interno delle immagini. Esistono due modi per inviare richieste a TensorFlow Serving:
- REST
- gRPC
Inviare richieste e ricevere risposte tramite REST
Ci sono tre semplici passaggi:
- Crea la richiesta REST.
- Invia la richiesta REST a TensorFlow Serving.
- Estrai il risultato previsto dalla risposta REST e visualizza la UI.
Raggiungerai questi obiettivi tra MainActivity.java.
Crea la richiesta REST
Al momento, nel file MainActivity.java è presente una funzione createRESTRequest() vuota. Implementa questa funzione per creare una richiesta REST.
private Request createRESTRequest() {
}
TensorFlow Serving prevede una richiesta POST che contenga il tensore dell'immagine per il modello SSD MobileNet che utilizzi, quindi devi estrarre i valori RGB da ogni pixel dell'immagine in un array e poi racchiudere l'array in un JSON, che è il payload della richiesta.
- Aggiungi questo codice alla funzione
createRESTRequest():
//Create the REST request.
int[] inputImg = new int[INPUT_IMG_HEIGHT * INPUT_IMG_WIDTH];
int[][][][] inputImgRGB = new int[1][INPUT_IMG_HEIGHT][INPUT_IMG_WIDTH][3];
inputImgBitmap.getPixels(inputImg, 0, INPUT_IMG_WIDTH, 0, 0, INPUT_IMG_WIDTH, INPUT_IMG_HEIGHT);
int pixel;
for (int i = 0; i < INPUT_IMG_HEIGHT; i++) {
for (int j = 0; j < INPUT_IMG_WIDTH; j++) {
// Extract RBG values from each pixel; alpha is ignored
pixel = inputImg[i * INPUT_IMG_WIDTH + j];
inputImgRGB[0][i][j][0] = ((pixel >> 16) & 0xff);
inputImgRGB[0][i][j][1] = ((pixel >> 8) & 0xff);
inputImgRGB[0][i][j][2] = ((pixel) & 0xff);
}
}
RequestBody requestBody =
RequestBody.create("{\"instances\": " + Arrays.deepToString(inputImgRGB) + "}", JSON);
Request request =
new Request.Builder()
.url("http://" + SERVER + ":" + REST_PORT + "/v1/models/" + MODEL_NAME + ":predict")
.post(requestBody)
.build();
return request;
Invia la richiesta REST a TensorFlow Serving
L'app consente all'utente di scegliere REST o gRPC per comunicare con TensorFlow Serving, quindi ci sono due rami nel listener onClick(View view).
predictButton.setOnClickListener(
new View.OnClickListener() {
@Override
public void onClick(View view) {
if (requestRadioGroup.getCheckedRadioButtonId() == R.id.rest) {
// TODO: REST request
}
else {
}
}
}
)
- Aggiungi questo codice al ramo REST del listener
onClick(View view)per utilizzare OkHttp per inviare la richiesta a TensorFlow Serving:
// Send the REST request.
Request request = createRESTRequest();
try {
client =
new OkHttpClient.Builder()
.connectTimeout(20, TimeUnit.SECONDS)
.writeTimeout(20, TimeUnit.SECONDS)
.readTimeout(20, TimeUnit.SECONDS)
.callTimeout(20, TimeUnit.SECONDS)
.build();
Response response = client.newCall(request).execute();
JSONObject responseObject = new JSONObject(response.body().string());
postprocessRESTResponse(responseObject);
} catch (IOException | JSONException e) {
Log.e(TAG, e.getMessage());
responseTextView.setText(e.getMessage());
return;
}
Elaborare la risposta REST di TensorFlow Serving
Il modello SSD MobileNet restituisce una serie di risultati, tra cui:
num_detections: il numero di rilevamentidetection_scores: punteggi di rilevamentodetection_classes: l'indice della classe di rilevamentodetection_boxes: le coordinate del riquadro di selezione
Implementa la funzione postprocessRESTResponse() per gestire la risposta.
private void postprocessRESTResponse(Predict.PredictResponse response) {
}
- Aggiungi questo codice alla funzione
postprocessRESTResponse():
// Process the REST response.
JSONArray predictionsArray = responseObject.getJSONArray("predictions");
//You only send one image, so you directly extract the first element.
JSONObject predictions = predictionsArray.getJSONObject(0);
// Argmax
int maxIndex = 0;
JSONArray detectionScores = predictions.getJSONArray("detection_scores");
for (int j = 0; j < predictions.getInt("num_detections"); j++) {
maxIndex =
detectionScores.getDouble(j) > detectionScores.getDouble(maxIndex + 1) ? j : maxIndex;
}
int detectionClass = predictions.getJSONArray("detection_classes").getInt(maxIndex);
JSONArray boundingBox = predictions.getJSONArray("detection_boxes").getJSONArray(maxIndex);
double ymin = boundingBox.getDouble(0);
double xmin = boundingBox.getDouble(1);
double ymax = boundingBox.getDouble(2);
double xmax = boundingBox.getDouble(3);
displayResult(detectionClass, (float) ymin, (float) xmin, (float) ymax, (float) xmax);
Ora la funzione di post-elaborazione estrae i valori previsti dalla risposta, identifica la categoria più probabile dell'oggetto e le coordinate dei vertici del riquadro di delimitazione e infine esegue il rendering del riquadro di delimitazione del rilevamento nell'interfaccia utente.
Esegui
- Fai clic su
Esegui "app" nel menu di navigazione e poi attendi il caricamento dell'app. - Seleziona REST > Run inference.
Sono necessari alcuni secondi prima che l'app esegua il rendering del riquadro di selezione del gatto e mostri 17 come categoria dell'oggetto, che corrisponde all'oggetto cat nel set di dati COCO.

7. Collega l'app per Android a TensorFlow Serving tramite gRPC
Oltre a REST, TensorFlow Serving supporta anche gRPC.

gRPC è un framework RPC (Remote Procedure Call) moderno, open source e ad alte prestazioni che può essere eseguito in qualsiasi ambiente. Può connettere in modo efficiente i servizi all'interno e tra i data center con supporto plug-in per bilanciamento del carico, tracciamento, controllo di integrità e autenticazione. È stato osservato che gRPC è più performante di REST nella pratica.
Inviare richieste e ricevere risposte con gRPC
Esistono quattro semplici passaggi:
- [Facoltativo] Genera il codice stub del client gRPC.
- Crea la richiesta gRPC.
- Invia la richiesta gRPC a TensorFlow Serving.
- Estrai il risultato previsto dalla risposta gRPC e visualizza la UI.
Raggiungerai questi obiettivi tra MainActivity.java.
(Facoltativo) Genera il codice dello stub client gRPC
Per utilizzare gRPC con TensorFlow Serving, devi seguire il flusso di lavoro gRPC. Per saperne di più, consulta la documentazione di gRPC.

TensorFlow Serving e TensorFlow definiscono i file .proto per te. A partire da TensorFlow e TensorFlow Serving 2.8, questi sono i file .proto necessari:
tensorflow/core/example/example.proto
tensorflow/core/example/feature.proto
tensorflow/core/protobuf/struct.proto
tensorflow/core/protobuf/saved_object_graph.proto
tensorflow/core/protobuf/saver.proto
tensorflow/core/protobuf/trackable_object_graph.proto
tensorflow/core/protobuf/meta_graph.proto
tensorflow/core/framework/node_def.proto
tensorflow/core/framework/attr_value.proto
tensorflow/core/framework/function.proto
tensorflow/core/framework/types.proto
tensorflow/core/framework/tensor_shape.proto
tensorflow/core/framework/full_type.proto
tensorflow/core/framework/versions.proto
tensorflow/core/framework/op_def.proto
tensorflow/core/framework/graph.proto
tensorflow/core/framework/tensor.proto
tensorflow/core/framework/resource_handle.proto
tensorflow/core/framework/variable.proto
tensorflow_serving/apis/inference.proto
tensorflow_serving/apis/classification.proto
tensorflow_serving/apis/predict.proto
tensorflow_serving/apis/regression.proto
tensorflow_serving/apis/get_model_metadata.proto
tensorflow_serving/apis/input.proto
tensorflow_serving/apis/prediction_service.proto
tensorflow_serving/apis/model.proto
- Per generare lo stub, aggiungi questo codice al file
app/build.gradle.
apply plugin: 'com.google.protobuf'
protobuf {
protoc { artifact = 'com.google.protobuf:protoc:3.11.0' }
plugins {
grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.29.0'
}
}
generateProtoTasks {
all().each { task ->
task.builtins {
java { option 'lite' }
}
task.plugins {
grpc { option 'lite' }
}
}
}
}
Crea la richiesta gRPC
Analogamente alla richiesta REST, crei la richiesta gRPC nella funzione createGRPCRequest().
private Request createGRPCRequest() {
}
- Aggiungi questo codice alla funzione
createGRPCRequest():
if (stub == null) {
channel = ManagedChannelBuilder.forAddress(SERVER, GRPC_PORT).usePlaintext().build();
stub = PredictionServiceGrpc.newBlockingStub(channel);
}
Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
modelSpecBuilder.setName(MODEL_NAME);
modelSpecBuilder.setVersion(Int64Value.of(MODEL_VERSION));
modelSpecBuilder.setSignatureName(SIGNATURE_NAME);
Predict.PredictRequest.Builder builder = Predict.PredictRequest.newBuilder();
builder.setModelSpec(modelSpecBuilder);
TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder();
tensorProtoBuilder.setDtype(DataType.DT_UINT8);
TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(INPUT_IMG_HEIGHT));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(INPUT_IMG_WIDTH));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(3));
tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build());
int[] inputImg = new int[INPUT_IMG_HEIGHT * INPUT_IMG_WIDTH];
inputImgBitmap.getPixels(inputImg, 0, INPUT_IMG_WIDTH, 0, 0, INPUT_IMG_WIDTH, INPUT_IMG_HEIGHT);
int pixel;
for (int i = 0; i < INPUT_IMG_HEIGHT; i++) {
for (int j = 0; j < INPUT_IMG_WIDTH; j++) {
// Extract RBG values from each pixel; alpha is ignored.
pixel = inputImg[i * INPUT_IMG_WIDTH + j];
tensorProtoBuilder.addIntVal((pixel >> 16) & 0xff);
tensorProtoBuilder.addIntVal((pixel >> 8) & 0xff);
tensorProtoBuilder.addIntVal((pixel) & 0xff);
}
}
TensorProto tensorProto = tensorProtoBuilder.build();
builder.putInputs("input_tensor", tensorProto);
builder.addOutputFilter("num_detections");
builder.addOutputFilter("detection_boxes");
builder.addOutputFilter("detection_classes");
builder.addOutputFilter("detection_scores");
return builder.build();
Invia la richiesta gRPC a TensorFlow Serving
Ora puoi completare l'onClick(View view).
predictButton.setOnClickListener(
new View.OnClickListener() {
@Override
public void onClick(View view) {
if (requestRadioGroup.getCheckedRadioButtonId() == R.id.rest) {
}
else {
// TODO: gRPC request
}
}
}
)
- Aggiungi questo codice al ramo gRPC:
try {
Predict.PredictRequest request = createGRPCRequest();
Predict.PredictResponse response = stub.predict(request);
postprocessGRPCResponse(response);
} catch (Exception e) {
Log.e(TAG, e.getMessage());
responseTextView.setText(e.getMessage());
return;
}
Elaborare la risposta gRPC da TensorFlow Serving
Come per gRPC, implementa la funzione postprocessGRPCResponse() per gestire la risposta.
private void postprocessGRPCResponse(Predict.PredictResponse response) {
}
- Aggiungi questo codice alla funzione
postprocessGRPCResponse():
// Process the response.
float numDetections = response.getOutputsMap().get("num_detections").getFloatValList().get(0);
List<Float> detectionScores = response.getOutputsMap().get("detection_scores").getFloatValList();
int maxIndex = 0;
for (int j = 0; j < numDetections; j++) {
maxIndex = detectionScores.get(j) > detectionScores.get(maxIndex + 1) ? j : maxIndex;
}
Float detectionClass = response.getOutputsMap().get("detection_classes").getFloatValList().get(maxIndex);
List<Float> boundingBoxValues = response.getOutputsMap().get("detection_boxes").getFloatValList();
float ymin = boundingBoxValues.get(maxIndex * 4);
float xmin = boundingBoxValues.get(maxIndex * 4 + 1);
float ymax = boundingBoxValues.get(maxIndex * 4 + 2);
float xmax = boundingBoxValues.get(maxIndex * 4 + 3);
displayResult(detectionClass.intValue(), ymin, xmin, ymax, xmax);
Ora la funzione di post-elaborazione può estrarre i valori previsti dalla risposta e visualizzare il riquadro di selezione del rilevamento nella UI.
Esegui
- Fai clic su
Esegui "app" nel menu di navigazione e poi attendi il caricamento dell'app. - Seleziona gRPC > Run inference.
Occorrono alcuni secondi prima che l'app esegua il rendering del riquadro di selezione del gatto e mostri 17 come categoria dell'oggetto, che corrisponde alla categoria cat nel set di dati COCO.
8. Complimenti
Hai utilizzato TensorFlow Serving per aggiungere funzionalità di rilevamento degli oggetti alla tua app.