建立 Flutter 應用程式來將文字分類

1. 事前準備

在本程式碼研究室中,您將瞭解如何使用 TensorFlow 和 gRPC,透過 TensorFlow Serving 從 Flutter 應用程式執行文字分類推論。

必要條件

課程內容

  • 瞭解如何使用 TensorFlow Serving (REST 和 gRPC) 建構簡單的 Flutter 應用程式並將文字分類。
  • 如何在使用者介面中顯示結果。

軟硬體需求

2. 設定 Flutter 開發環境

如要進行 Flutter 開發,您必須安裝兩個軟體:Flutter SDK編輯器

您可以透過下列任一裝置執行程式碼研究室:

  • iOS 模擬工具 (必須安裝 Xcode 工具)。
  • Android 模擬器 (需在 Android Studio 中設定)。
  • 瀏覽器 (需要安裝 Chrome 才能進行偵錯)。
  • 提供 WindowsLinuxmacOS 電腦版應用程式。您必須在預計部署的平台上進行開發。因此,如果您要開發 Windows 桌面應用程式,就必須在 Windows 上開發,才能存取適當的建構鏈。如要瞭解作業系統相關規定,請參閱 docs.flutter.dev/desktop

3. 做好準備

若要下載此程式碼研究室的程式碼:

  1. 前往這項程式碼研究室的 GitHub 存放區
  2. 按一下 [程式碼 >下載 zip],即可下載這個程式碼研究室的所有程式碼。

2cd45599f51fb8a2.png

  1. 將下載的 ZIP 檔案解壓縮,解壓縮您需要的所有 codelabs-main 根資料夾。

在這個程式碼研究室中,您只需要存放區的 tfserving-flutter/codelab2 子目錄中的檔案,其中包含兩個資料夾:

  • starter 資料夾包含您為這個程式碼研究室建立的範例程式碼。
  • finished 資料夾包含已完成範例應用程式的範例程式碼。

4. 下載專案的依附元件

  1. 在 VS Code 中,按一下 [File > OpenFolder],然後從先前下載的原始碼中選取 starter 資料夾。
  2. 如果畫面上出現對話方塊,提示您下載啟動程式應用程式所需的套件,請按一下 [取得套件]
  3. 如果您沒有看到這個對話方塊,請開啟終端機,然後在 starter 資料夾中執行 flutter pub get 指令。

7ada07c300f166a6.png

5. 執行啟動應用程式

  1. 在 VS Code 中,確認 Android Emulator 或 iOS 模擬工具已正確設定並顯示在狀態列中。

舉例來說,以下為搭配 Android Emulator 使用 Pixel 5 時所顯示的內容:

9767649231898791.png

透過 iOS 模擬器使用 iPhone 13 時會看到下列資訊:

95529e3a682268b2.png

  1. 點選 a19a0c68bc4046e6.png [開始偵錯]

執行並探索應用程式

應用程式應該會在 Android Emulator 或 iOS 模擬工具上啟動。使用者介面相當簡單,這裡有文字欄位,可讓使用者輸入文字。使用者可以選擇要使用 REST 或 gRPC 將資料傳送到後端。後端使用 TensorFlow 模型對預先處理的輸入執行文字分類,並將分類結果傳回用戶端應用程式,進而更新使用者介面。

b298f605d64dc132.png d3ef3ccd3c338108.png

假如您按一下 [分類],系統就沒有任何作用,因為這個代理程式尚未與後端通訊。

6. 使用 TensorFlow Serving 部署文字分類模型

文字分類是一種常見的機器學習工作,可將文字分類到預先定義的類別。在這個程式碼研究室中,您將使用 TensorFlow Serving,透過使用 TensorFlow Lite Model Maker 程式碼研究室訓練垃圾留言偵測模型部署預先訓練的模型,並呼叫 Flutter 前端的後端,將輸入的文字分類為垃圾內容非垃圾內容

啟動 TensorFlow Serving

  • 在終端機中,以 Docker 啟動 TensorFlow Serving,但將 PATH/TO/SAVEDMODEL 預留位置改成您電腦上 mm_spam_savedmodel 資料夾的絕對路徑。
docker pull tensorflow/serving

docker run -it --rm -p 8500:8500 -p 8501:8501 -v "PATH/TO/SAVEDMODEL:/models/spam-detection" -e MODEL_NAME=spam-detection tensorflow/serving

Docker 會先自動下載 TensorFlow Serving 映像檔,這可能需要一分鐘的時間。之後,TensorFlow Serving 會隨即啟動。記錄應如下所示:

2022-02-25 06:01:12.513231: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:206] Restoring SavedModel bundle.
2022-02-25 06:01:12.585012: I external/org_tensorflow/tensorflow/core/platform/profile_utils/cpu_utils.cc:114] CPU Frequency: 3000000000 Hz
2022-02-25 06:01:13.395083: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:190] Running initialization op on SavedModel bundle at path: /models/ssd_mobilenet_v2_2/123
2022-02-25 06:01:13.837562: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:277] SavedModel load for tags { serve }; Status: success: OK. Took 1928700 microseconds.
2022-02-25 06:01:13.877848: I tensorflow_serving/servables/tensorflow/saved_model_warmup_util.cc:59] No warmup data file found at /models/ssd_mobilenet_v2_2/123/assets.extra/tf_serving_warmup_requests
2022-02-25 06:01:13.929844: I tensorflow_serving/core/loader_harness.cc:87] Successfully loaded servable version {name: spam-detection version: 123}
2022-02-25 06:01:13.985848: I tensorflow_serving/model_servers/server_core.cc:486] Finished adding/updating models
2022-02-25 06:01:13.985987: I tensorflow_serving/model_servers/server.cc:367] Profiler service is enabled
2022-02-25 06:01:13.988994: I tensorflow_serving/model_servers/server.cc:393] Running gRPC ModelServer at 0.0.0.0:8500 ...
[warn] getaddrinfo: address family for nodename not supported
2022-02-25 06:01:14.033872: I tensorflow_serving/model_servers/server.cc:414] Exporting HTTP/REST API at:localhost:8501 ...
[evhttp_server.cc : 245] NET_LOG: Entering the event loop ...

7. 將輸入語句化

後端已就緒,因此你幾乎準備好傳送用戶端要求到 TensorFlow Serving,但必須先將輸入語句化代碼。當您檢查模型的輸入張量時,會發現模型預期會有 20 個整數,而非原始字串。「憑證化」是指您根據詞彙字典,將應用程式中的個別字詞對應至整數清單,然後再傳送至後端進行分類。舉例來說,如果您輸入 buy book online to learn more,代碼化程序就會對應至 [32, 79, 183, 10, 224, 631, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]。具體數字會因詞彙字典而異。

  1. lib/main.dart 檔案中,將此程式碼加入 predict() 方法以建立 _vocabMap 詞彙字典。
// Build _vocabMap if empty.
if (_vocabMap.isEmpty) {
  final vocabFileString = await rootBundle.loadString(vocabFile);
  final lines = vocabFileString.split('\n');
  for (final l in lines) {
    if (l != "") {
      var wordAndIndex = l.split(' ');
      (_vocabMap)[wordAndIndex[0]] = int.parse(wordAndIndex[1]);
    }
  }
} 
  1. 加入前一個程式碼片段後,立即加入這段程式碼以導入代碼化:
// Tokenize the input sentence.
final inputWords = _inputSentenceController.text
    .toLowerCase()
    .replaceAll(RegExp('[^a-z ]'), '')
    .split(' ');
// Initialize with padding token.
_tokenIndices = List.filled(maxSentenceLength, 0);
var i = 0;
for (final w in inputWords) {
  if ((_vocabMap).containsKey(w)) {
    _tokenIndices[i] = (_vocabMap)[w]!;
    i++;
  }

  // Truncate the string if longer than maxSentenceLength.
  if (i >= maxSentenceLength - 1) {
    break;
  }
}

此程式碼會使句子字串小寫,移除非字母字元,並根據字彙表格將字詞對應到 20 個整數索引。

8. 透過 REST 將 Flutter 應用程式與 TensorFlow Serving 連結

向 TensorFlow Serving 傳送要求的方法有兩種:

  • REST
  • gRPC

透過 REST 傳送要求及接收回應

透過 REST 傳送要求及接收回應的三個簡單步驟如下:

  1. 建立 REST 要求。
  2. 將 REST 要求傳送至 TensorFlow Serving。
  3. 從 REST 回應中擷取預測結果並顯示 UI。

您已完成 main.dart 檔案中的步驟。

建立 REST 要求並傳送至 TensorFlow Serving

  1. predict() 函式目前無法向 REST Serving 傳送 REST 要求。您必須導入 REST 分支來建立 REST 要求:
if (_connectionMode == ConnectionModeType.rest) {
  // TODO: Create and send the REST request.

}
  1. 將下列程式碼新增至 REST 分支版本:
//Create the REST request.
final response = await http.post(
  Uri.parse('http://' +
      _server +
      ':' +
      restPort.toString() +
      '/v1/models/' +
      modelName +
      ':predict'),
  body: jsonEncode(<String, List<List<int>>>{
    'instances': [_tokenIndices],
  }),
);

處理 TensorFlow Serving 中的 REST 回應

  • 加入下列程式碼,以處理 REST 回應:
// Process the REST response.
if (response.statusCode == 200) {
  Map<String, dynamic> result = jsonDecode(response.body);
  if (result['predictions']![0][1] >= classificationThreshold) {
    return 'This sentence is spam. Spam score is ' +
        result['predictions']![0][1].toString();
  }
  return 'This sentence is not spam. Spam score is ' +
      result['predictions']![0][1].toString();
} else {
  throw Exception('Error response');
}

後續處理代碼會擷取輸入語句為垃圾訊息的機率,並在使用者介面中顯示分類結果。

執行

  1. 按一下 a19a0c68bc4046e6.png [開始偵錯],然後等待應用程式載入。
  2. 輸入文字,然後選取 [REST > Classify]

8e21d795af36d07a.png e79a0367a03c2169.png

9. 透過 gRPC 連結 Flutter 應用程式與 TensorFlow Serving

除了 REST 以外,TensorFlow Serving 也支援 gRPC

b6f4449c2c850b0e.png

gRPC 是現代化、開放原始碼的高效能遠端程序呼叫 (RPC) 架構,可在任何環境中執行。這項服務具備可連接的負載平衡、追蹤、健康狀態檢查和驗證等功能,可透過高效率的方式連結資料中心內外的服務。發現在實驗中,gRPC 的成效比 REST 更好。

使用 gRPC 傳送要求及接收回應

透過 gRPC 傳送要求並接收回應的四個簡單步驟:

  1. 選用:產生 gRPC 用戶端 stub 程式碼。
  2. 建立 gRPC 要求。
  3. 將 gRPC 要求傳送至 TensorFlow Serving。
  4. 從 gRPC 回應中擷取預測結果,然後顯示 UI。

您已完成 main.dart 檔案中的步驟。

選用:產生 gRPC 用戶端 stub 程式碼

如要搭配 TensorFlow Serving 使用 gRPC,您需要遵守 gRPC 工作流程。如要進一步瞭解相關細節,請參閱 gRPC 說明文件

a9d0e5cb543467b4.png

TensorFlow Serving 和 TensorFlow 會為您定義 .proto 檔案。自 TensorFlow 和 TensorFlow Serving 2.8 起,下列 .proto 檔案為必要檔案:

tensorflow/core/example/example.proto
tensorflow/core/example/feature.proto
tensorflow/core/protobuf/struct.proto
tensorflow/core/protobuf/saved_object_graph.proto
tensorflow/core/protobuf/saver.proto
tensorflow/core/protobuf/trackable_object_graph.proto
tensorflow/core/protobuf/meta_graph.proto
tensorflow/core/framework/node_def.proto
tensorflow/core/framework/attr_value.proto
tensorflow/core/framework/function.proto
tensorflow/core/framework/types.proto
tensorflow/core/framework/tensor_shape.proto
tensorflow/core/framework/full_type.proto
tensorflow/core/framework/versions.proto
tensorflow/core/framework/op_def.proto
tensorflow/core/framework/graph.proto
tensorflow/core/framework/tensor.proto
tensorflow/core/framework/resource_handle.proto
tensorflow/core/framework/variable.proto

tensorflow_serving/apis/inference.proto
tensorflow_serving/apis/classification.proto
tensorflow_serving/apis/predict.proto
tensorflow_serving/apis/regression.proto
tensorflow_serving/apis/get_model_metadata.proto
tensorflow_serving/apis/input.proto
tensorflow_serving/apis/prediction_service.proto
tensorflow_serving/apis/model.proto

google/protobuf/any.proto
google/protobuf/wrappers.proto
  • 在終端機中,前往 starter/lib/proto/ 資料夾並產生 stub:
bash generate_grpc_stub_dart.sh

建立 gRPC 要求

與 REST 要求類似,您會在 gRPC 分支中建立 gRPC 要求。

if (_connectionMode == ConnectionModeType.rest) {

} else {
  // TODO: Create and send the gRPC request.

}
  • 新增這段程式碼以建立 gRPC 要求:
//Create the gRPC request.
final channel = ClientChannel(_server,
    port: grpcPort,
    options:
        const ChannelOptions(credentials: ChannelCredentials.insecure()));
_stub = PredictionServiceClient(channel,
    options: CallOptions(timeout: const Duration(seconds: 10)));

ModelSpec modelSpec = ModelSpec(
  name: 'spam-detection',
  signatureName: 'serving_default',
);

TensorShapeProto_Dim batchDim = TensorShapeProto_Dim(size: Int64(1));
TensorShapeProto_Dim inputDim =
    TensorShapeProto_Dim(size: Int64(maxSentenceLength));
TensorShapeProto inputTensorShape =
    TensorShapeProto(dim: [batchDim, inputDim]);
TensorProto inputTensor = TensorProto(
    dtype: DataType.DT_INT32,
    tensorShape: inputTensorShape,
    intVal: _tokenIndices);

// If you train your own model, update the input and output tensor names.
const inputTensorName = 'input_3';
const outputTensorName = 'dense_5';
PredictRequest request = PredictRequest(
    modelSpec: modelSpec, inputs: {inputTensorName: inputTensor});

注意:即使模型架構相同,輸入和輸出張量名稱也可以隨著模型而有所不同。訓練自己的模型時,請記得更新模型。

將 gRPC 要求傳送至 TensorFlow Serving

  • 將這段程式碼加到前一個程式碼片段後方,以將 gRPC 要求傳送至 TensorFlow Serving:
// Send the gRPC request.
PredictResponse response = await _stub.predict(request);

處理 TensorFlow Serving 中的 gRPC 回應

  • 將這段程式碼加到前一個程式碼片段後方,以執行回呼函式來處理回應:
// Process the response.
if (response.outputs.containsKey(outputTensorName)) {
  if (response.outputs[outputTensorName]!.floatVal[1] >
      classificationThreshold) {
    return 'This sentence is spam. Spam score is ' +
        response.outputs[outputTensorName]!.floatVal[1].toString();
  } else {
    return 'This sentence is not spam. Spam score is ' +
        response.outputs[outputTensorName]!.floatVal[1].toString();
  }
} else {
  throw Exception('Error response');
}

現在,後處理程式碼會從回應中擷取分類結果,並將其顯示在使用者介面中。

執行

  1. 按一下 a19a0c68bc4046e6.png [開始偵錯],然後等待應用程式載入。
  2. 輸入一些文字,然後選取 [gRPC > Classify]

e44e6e9a5bde2188.png 92644d723f61968c.png

10. 恭喜

您使用 TensorFlow Serving 在應用程式中新增文字分類功能!

在下一個程式碼研究室中,您將強化模型,以便偵測目前的應用程式無法偵測到的特定垃圾郵件。

瞭解詳情