Create a Flutter app to classify texts

1. Before you begin

In this codelab, you learn how to run a text-classification inference from a Flutter app with TensorFlow Serving through REST and gRPC.

Prerequisites

What you'll learn

  • How to build a simple Flutter app and classify texts through TensorFlow Serving (REST and gRPC).
  • How to display the results in the UI.

What you'll need

2. Get set up

To download the code for this codelab:

  1. Navigate to the GitHub repository for this codelab.
  2. Click Code > Download zip to download all the code for this codelab.

  3. Unzip the downloaded zip file to unpack a codelabs-master root folder with all of the resources that you need.

For this codelab, you only need the files in the tfserving-flutter/codelab2 subdirectory in the repository, which contains two folders:

  • The starter folder contains the starter code that you build upon for this codelab.
  • The finished folder contains the completed code for the finished sample app.

3. Download the dependencies for the project

  1. In VS Code, click File > Open folder and then select the starter folder from the source code that you downloaded earlier.
  2. If you see a dialog appear that prompts you to download the required packages for the starter app, click Get packages.
  3. If you don't see this dialog, open your terminal and then run the flutter pub get command in the starter folder.

4. Run the starter app

  1. In VS Code, ensure that the Android Emulator or iOS Simulator is properly set up and appears in the status bar.

For example, if you use a Pixel 5 with the Android Emulator or an iPhone 13 with the iOS Simulator, the name of that device appears in the status bar.

  2. Click Start debugging.

Run and explore the app

The app should launch on your Android Emulator or iOS Simulator. The UI is pretty straightforward. There's a text field that lets the user type in the text. The user can choose whether to send the data to the backend with REST or gRPC. The backend uses a TensorFlow model to perform text classification on the preprocessed input and returns the classification result to the client app, which updates the UI in turn.

If you click Classify, nothing happens because it can't communicate with the backend yet.

5. Deploy a text-classification model with TensorFlow Serving

Text classification is a very common machine learning task that classifies texts into predefined categories. In this codelab, you deploy the pretrained model from the Train a comment-spam detection model with TensorFlow Lite Model Maker codelab with TensorFlow Serving and call the backend from your Flutter frontend to classify the input text as spam or not spam.

Start TensorFlow Serving

  • In your terminal, start TensorFlow Serving with Docker, but replace the PATH/TO/SAVEDMODEL placeholder with the absolute path of the mm_spam_savedmodel folder on your computer.
docker pull tensorflow/serving

docker run -it --rm -p 8500:8500 -p 8501:8501 -v "PATH/TO/SAVEDMODEL:/models/spam-detection" -e MODEL_NAME=spam-detection tensorflow/serving

The docker pull command downloads the TensorFlow Serving image first, which can take a minute. Afterward, docker run starts TensorFlow Serving. The log should look similar to this:

2022-02-25 06:01:12.513231: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:206] Restoring SavedModel bundle.
2022-02-25 06:01:12.585012: I external/org_tensorflow/tensorflow/core/platform/profile_utils/cpu_utils.cc:114] CPU Frequency: 3000000000 Hz
2022-02-25 06:01:13.395083: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:190] Running initialization op on SavedModel bundle at path: /models/ssd_mobilenet_v2_2/123
2022-02-25 06:01:13.837562: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:277] SavedModel load for tags { serve }; Status: success: OK. Took 1928700 microseconds.
2022-02-25 06:01:13.877848: I tensorflow_serving/servables/tensorflow/saved_model_warmup_util.cc:59] No warmup data file found at /models/ssd_mobilenet_v2_2/123/assets.extra/tf_serving_warmup_requests
2022-02-25 06:01:13.929844: I tensorflow_serving/core/loader_harness.cc:87] Successfully loaded servable version {name: spam-detection version: 123}
2022-02-25 06:01:13.985848: I tensorflow_serving/model_servers/server_core.cc:486] Finished adding/updating models
2022-02-25 06:01:13.985987: I tensorflow_serving/model_servers/server.cc:367] Profiler service is enabled
2022-02-25 06:01:13.988994: I tensorflow_serving/model_servers/server.cc:393] Running gRPC ModelServer at 0.0.0.0:8500 ...
[warn] getaddrinfo: address family for nodename not supported
2022-02-25 06:01:14.033872: I tensorflow_serving/model_servers/server.cc:414] Exporting HTTP/REST API at:localhost:8501 ...
[evhttp_server.cc : 245] NET_LOG: Entering the event loop ...
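If you'd rather confirm from code that the model finished loading, you can query the TensorFlow Serving model status endpoint, GET /v1/models/<model name>, on the REST port. The following Dart sketch is not part of the starter app: modelStatusUri, isModelAvailable, and checkModel are hypothetical helpers, and the sketch assumes the server is reachable at the given host on port 8501 and returns the documented model_version_status JSON shape.

```dart
import 'dart:convert';
import 'dart:io';

/// Builds the REST model status URL, e.g. http://localhost:8501/v1/models/spam-detection.
Uri modelStatusUri(String host, int restPort, String modelName) =>
    Uri.parse('http://$host:$restPort/v1/models/$modelName');

/// Returns true if any version listed in a model status response is AVAILABLE.
/// Assumes a body like {"model_version_status": [{"version": "...", "state": "..."}]}.
bool isModelAvailable(String responseBody) {
  final Map<String, dynamic> status = jsonDecode(responseBody);
  final versions =
      (status['model_version_status'] as List<dynamic>?) ?? const [];
  return versions
      .any((v) => (v as Map<String, dynamic>)['state'] == 'AVAILABLE');
}

/// Fetches the status with dart:io and reports whether the model is ready.
Future<bool> checkModel(String host, int restPort, String modelName) async {
  final client = HttpClient();
  try {
    final request =
        await client.getUrl(modelStatusUri(host, restPort, modelName));
    final response = await request.close();
    final body = await response.transform(utf8.decoder).join();
    return response.statusCode == 200 && isModelAvailable(body);
  } finally {
    client.close();
  }
}
```

Run from your computer, await checkModel('localhost', 8501, 'spam-detection') should complete with true once the log reports that the servable version loaded successfully.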

6. Tokenize input sentence

The backend is ready now, so you're almost ready to send client requests to TensorFlow Serving, but first you need to tokenize the input sentence. If you inspect the input tensor of the model, you can see that it expects a list of 20 integer numbers instead of raw strings. Tokenization is when you map the individual words you type in the app to a list of integers based on a vocabulary dictionary before you send them to the backend for classification. For example, if you type buy book online to learn more, the tokenization process maps it to [32, 79, 183, 10, 224, 631, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]. The specific numbers can vary based on the vocabulary dictionary.

  1. In the lib/main.dart file, add this code to the predict() method to build the _vocabMap vocabulary dictionary.
// Build _vocabMap if empty.
if (_vocabMap.isEmpty) {
  final vocabFileString = await rootBundle.loadString(vocabFile);
  final lines = vocabFileString.split('\n');
  for (final l in lines) {
    if (l.isNotEmpty) {
      // Each line has the form "<word> <index>".
      final wordAndIndex = l.split(' ');
      _vocabMap[wordAndIndex[0]] = int.parse(wordAndIndex[1]);
    }
  }
}
  2. Immediately after the previous code snippet, add this code to implement tokenization:
// Tokenize the input sentence.
final inputWords = _inputSentenceController.text
    .toLowerCase()
    .replaceAll(RegExp('[^a-z ]'), '')
    .split(' ');
// Initialize with padding token.
_tokenIndices = List.filled(maxSentenceLength, 0);
var i = 0;
for (final w in inputWords) {
  // Skip words that aren't in the vocabulary.
  if (_vocabMap.containsKey(w)) {
    _tokenIndices[i] = _vocabMap[w]!;
    i++;
  }

  // Truncate the sentence once all maxSentenceLength slots are filled.
  if (i >= maxSentenceLength) {
    break;
  }
}

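This code lowercases the sentence string, removes non-alphabetic characters, skips words that aren't in the vocabulary, and maps the remaining words to integer indices based on the vocabulary table. The result always has 20 entries: longer sentences are truncated and shorter ones are padded with 0.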
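To experiment with tokenization outside Flutter, the same steps can be sketched in plain Dart. This is an illustrative version, not the app's code: parseVocab and tokenize are hypothetical names, the vocabulary file is passed in as a string instead of being loaded with rootBundle, and the tiny vocabulary used with it is made up so that it reproduces the buy book online to learn more example from earlier.

```dart
/// Parses a vocabulary file in which each non-empty line is "<word> <index>".
Map<String, int> parseVocab(String vocabFileString) {
  final vocab = <String, int>{};
  for (final line in vocabFileString.split('\n')) {
    final trimmed = line.trim(); // Also drops a trailing '\r' on Windows files.
    if (trimmed.isEmpty) continue;
    final wordAndIndex = trimmed.split(' ');
    vocab[wordAndIndex[0]] = int.parse(wordAndIndex[1]);
  }
  return vocab;
}

/// Lowercases the sentence, strips non-letters, looks up each word, skips
/// unknown words, and pads with 0 up to [maxSentenceLength] entries.
List<int> tokenize(String sentence, Map<String, int> vocab,
    {int maxSentenceLength = 20}) {
  final tokenIndices = List<int>.filled(maxSentenceLength, 0);
  final words =
      sentence.toLowerCase().replaceAll(RegExp('[^a-z ]'), '').split(' ');
  var i = 0;
  for (final w in words) {
    if (i >= maxSentenceLength) break; // Truncate long sentences.
    final index = vocab[w];
    if (index != null) {
      tokenIndices[i] = index;
      i++;
    }
  }
  return tokenIndices;
}
```

With the made-up vocabulary "buy 32", "book 79", "online 183", "to 10", "learn 224", "more 631", tokenizing Buy book online to learn more! yields the six indices followed by fourteen zeros.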

7. Connect the Flutter app with TensorFlow Serving through REST

There are two ways to send requests to TensorFlow Serving:

  • REST
  • gRPC

Send requests and receive responses through REST

There are three simple steps to send requests and receive responses through REST:

  1. Create the REST request.
  2. Send the REST request to TensorFlow Serving.
  3. Extract the predicted result from the REST response and render the UI.

You complete these steps in the main.dart file.

Create and send the REST request to TensorFlow Serving

  1. Right now, the predict() function doesn't send the REST request to TensorFlow Serving. You need to implement the REST branch to create a REST request:
if (_connectionMode == ConnectionModeType.rest) {
  // TODO: Create and send the REST request.

}
  2. Add this code to the REST branch:
// Create the REST request.
final response = await http.post(
  Uri.parse('http://$_server:$restPort/v1/models/$modelName:predict'),
  body: jsonEncode(<String, List<List<int>>>{
    'instances': [_tokenIndices],
  }),
);
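The REST request above is just a POST of a small JSON document to the model's :predict endpoint. The following sketch isolates the URL and body construction so you can inspect the exact payload; predictUri and predictBody are hypothetical helper names, not functions from the starter app.

```dart
import 'dart:convert';

/// Builds the predict endpoint: http://<host>:<port>/v1/models/<model>:predict
Uri predictUri(String host, int restPort, String modelName) =>
    Uri.parse('http://$host:$restPort/v1/models/$modelName:predict');

/// Encodes one tokenized sentence in the "instances" (row) format.
/// The outer list is the batch, so a single sentence becomes [[...]].
String predictBody(List<int> tokenIndices) =>
    jsonEncode(<String, List<List<int>>>{
      'instances': [tokenIndices],
    });
```

For example, predictBody([32, 79, 0]) produces {"instances":[[32,79,0]]}, and the pair can be passed to http.post(predictUri(...), body: predictBody(...)).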

Process the REST response from TensorFlow Serving

  • Add this code right after the previous code snippet to handle the REST response:
// Process the REST response.
if (response.statusCode == 200) {
  final Map<String, dynamic> result = jsonDecode(response.body);
  // predictions[0] is the output row for the single input sentence;
  // index 1 of that row is the spam score.
  final spamScore = (result['predictions'][0][1] as num).toDouble();
  if (spamScore >= classificationThreshold) {
    return 'This sentence is spam. Spam score is $spamScore';
  }
  return 'This sentence is not spam. Spam score is $spamScore';
} else {
  throw Exception('Error response: ${response.statusCode}');
}

The postprocessing code extracts the probability that the input sentence is a spam message from the response and displays the classification result in the UI.
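The postprocessing step can also be tested on its own against a sample response body. This sketch assumes, as the app code does, that TensorFlow Serving returns one output row per instance under predictions and that index 1 of that row is the spam score. spamScoreFromRest and describe are hypothetical names, and the threshold is passed in explicitly because the app's classificationThreshold value isn't shown here.

```dart
import 'dart:convert';

/// Extracts the spam score (index 1 of the first prediction row).
double spamScoreFromRest(String responseBody) {
  final Map<String, dynamic> result = jsonDecode(responseBody);
  final predictions = result['predictions'] as List<dynamic>;
  final firstRow = predictions[0] as List<dynamic>;
  // JSON numbers may decode as int or double, so go through num.
  return (firstRow[1] as num).toDouble();
}

/// Formats the result the same way the app's REST branch does.
String describe(double spamScore, double threshold) => spamScore >= threshold
    ? 'This sentence is spam. Spam score is $spamScore'
    : 'This sentence is not spam. Spam score is $spamScore';
```

For a body of {"predictions": [[0.25, 0.75]]}, spamScoreFromRest returns 0.75, which describe reports as not spam for a threshold of 0.8.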

Run it

  1. Click Start debugging and then wait for the app to load.
  2. Enter some text and then select REST > Classify.

8. Connect the Flutter app with TensorFlow Serving through gRPC

In addition to REST, TensorFlow Serving also supports gRPC.

gRPC is a modern, open-source, high-performance Remote Procedure Call (RPC) framework that can run in any environment. It can efficiently connect services in and across data centers, with pluggable support for load balancing, tracing, health checking, and authentication. In practice, gRPC often performs better than REST, partly because it sends compact binary Protocol Buffers messages over HTTP/2 instead of JSON text.

Send requests and receive responses with gRPC

There are four simple steps to send requests and receive responses with gRPC:

  1. Optional: Generate the gRPC client stub code.
  2. Create the gRPC request.
  3. Send the gRPC request to TensorFlow Serving.
  4. Extract the predicted result from the gRPC response and render the UI.

You complete these steps in the main.dart file.

Optional: Generate the gRPC client stub code

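To use gRPC with TensorFlow Serving, you follow the standard gRPC workflow: define the service and its messages in .proto files, compile those files into client stub code for your language (Dart in this case), and then call the service through the generated stub. For details, see the gRPC documentation.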

TensorFlow Serving and TensorFlow define the .proto files for you. As of TensorFlow and TensorFlow Serving 2.8, these .proto files are the ones needed:

tensorflow/core/example/example.proto
tensorflow/core/example/feature.proto
tensorflow/core/protobuf/struct.proto
tensorflow/core/protobuf/saved_object_graph.proto
tensorflow/core/protobuf/saver.proto
tensorflow/core/protobuf/trackable_object_graph.proto
tensorflow/core/protobuf/meta_graph.proto
tensorflow/core/framework/node_def.proto
tensorflow/core/framework/attr_value.proto
tensorflow/core/framework/function.proto
tensorflow/core/framework/types.proto
tensorflow/core/framework/tensor_shape.proto
tensorflow/core/framework/full_type.proto
tensorflow/core/framework/versions.proto
tensorflow/core/framework/op_def.proto
tensorflow/core/framework/graph.proto
tensorflow/core/framework/tensor.proto
tensorflow/core/framework/resource_handle.proto
tensorflow/core/framework/variable.proto

tensorflow_serving/apis/inference.proto
tensorflow_serving/apis/classification.proto
tensorflow_serving/apis/predict.proto
tensorflow_serving/apis/regression.proto
tensorflow_serving/apis/get_model_metadata.proto
tensorflow_serving/apis/input.proto
tensorflow_serving/apis/prediction_service.proto
tensorflow_serving/apis/model.proto

google/protobuf/any.proto
google/protobuf/wrappers.proto
  • In your terminal, navigate to the starter/lib/proto/ folder and generate the stub:
bash generate_grpc_stub_dart.sh

Create the gRPC request

Similar to the REST request, you create the gRPC request in the gRPC branch.

if (_connectionMode == ConnectionModeType.rest) {

} else {
  // TODO: Create and send the gRPC request.

}
  • Add this code to create the gRPC request:
// Create the gRPC request.
final channel = ClientChannel(_server,
    port: grpcPort,
    options:
        const ChannelOptions(credentials: ChannelCredentials.insecure()));
_stub = PredictionServiceClient(channel,
    options: CallOptions(timeout: const Duration(seconds: 10)));

// The name must match the MODEL_NAME that TensorFlow Serving was started with.
ModelSpec modelSpec = ModelSpec(
  name: 'spam-detection',
  signatureName: 'serving_default',
);

TensorShapeProto_Dim batchDim = TensorShapeProto_Dim(size: Int64(1));
TensorShapeProto_Dim inputDim =
    TensorShapeProto_Dim(size: Int64(maxSentenceLength));
TensorShapeProto inputTensorShape =
    TensorShapeProto(dim: [batchDim, inputDim]);
TensorProto inputTensor = TensorProto(
    dtype: DataType.DT_INT32,
    tensorShape: inputTensorShape,
    intVal: _tokenIndices);

// If you train your own model, update the input and output tensor names.
const inputTensorName = 'input_3';
const outputTensorName = 'dense_5';
PredictRequest request = PredictRequest(
    modelSpec: modelSpec, inputs: {inputTensorName: inputTensor});

Note: The input and output tensor names could differ from model to model, even if the model architectures are the same. Make sure to update them if you train your own model.
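A TensorProto carries its values as one flat list, and the declared shape tells the server how to interpret them. Here the shape is [1, maxSentenceLength] (a batch of one sentence), so intVal must contain exactly maxSentenceLength values. The following sketch illustrates that rule; elementCount, fitsShape, and flattenBatch are hypothetical helpers, and the batching example goes beyond what the app does, since the app always sends a single sentence.

```dart
/// Number of elements implied by a tensor shape, e.g. [1, 20] -> 20.
int elementCount(List<int> dims) => dims.fold(1, (product, d) => product * d);

/// Checks that a flat value list matches the declared shape before sending.
bool fitsShape(List<int> dims, List<int> flatValues) =>
    elementCount(dims) == flatValues.length;

/// Concatenates several tokenized sentences row by row, so that N sentences
/// of length L could be sent with shape [N, L].
List<int> flattenBatch(List<List<int>> batch) =>
    batch.expand((row) => row).toList();
```

Because the tokenization step always produces a list of maxSentenceLength entries, the request in this codelab satisfies fitsShape([1, maxSentenceLength], _tokenIndices).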

Send the gRPC request to TensorFlow Serving

  • Add this code after the previous code snippet to send the gRPC request to TensorFlow Serving:
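// Send the gRPC request.
PredictResponse response = await _stub.predict(request);
// Close the channel created for this request so that connections don't accumulate.
await channel.shutdown();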

Process the gRPC response from TensorFlow Serving

  • Add this code after the previous code snippet to process the response:
// Process the response.
if (response.outputs.containsKey(outputTensorName)) {
  // floatVal[1] is the spam score for the single input sentence.
  final spamScore = response.outputs[outputTensorName]!.floatVal[1];
  // Use the same comparison as the REST branch.
  if (spamScore >= classificationThreshold) {
    return 'This sentence is spam. Spam score is $spamScore';
  }
  return 'This sentence is not spam. Spam score is $spamScore';
} else {
  throw Exception('Error response');
}

Now the postprocessing code extracts the classification result from the response and displays it in the UI.
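Why floatVal[1]? Like the input, the output tensor arrives as a flat, row-major list. Assuming the output has shape [batchSize, numClasses] with two classes, as the REST branch's predictions[0][1] suggests, the score for a given row and class sits at row * numClasses + class, which for one sentence and the spam class is 0 * 2 + 1 = 1. The helper below is a hypothetical illustration of that indexing, not part of the app.

```dart
/// Reads one score from a flattened, row-major output tensor of shape
/// [batchSize, numClasses].
double scoreAt(List<double> floatVal, int numClasses, int row, int cls) =>
    floatVal[row * numClasses + cls];
```

For a single sentence, scoreAt(floatVal, 2, 0, 1) reads the same element as floatVal[1] in the code above.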

Run it

  1. Click Start debugging and then wait for the app to load.
  2. Enter some text and then select gRPC > Classify.

9. Congratulations

You used TensorFlow Serving to add text-classification capabilities to your app!

In the next codelab, you will enhance the model so that you can detect specific spam messages that cannot be detected by the current app.

Learn more