mp.tasks.text.TextClassifier

Class that performs classification on text.

This API expects a TFLite model with (optional) TFLite Model Metadata that describes the mandatory input and output tensors listed below, plus the optional (but recommended) category labels, provided as AssociatedFiles of type TENSOR_AXIS_LABELS for each output classification tensor. Metadata is required for models with int32 input tensors, because it contains the input process unit for the model's tokenizer. Models with string input tensors need no metadata.

Input tensors: (kTfLiteInt32)

  • 3 input tensors of size [batch_size x bert_max_seq_len] representing the input ids, segment ids, and mask ids
  • or 1 input tensor of size [batch_size x max_seq_len] representing the input ids

or (kTfLiteString)

  • 1 input tensor that is shapeless or has shape [1] containing the input string
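For the int32 case, the three BERT-style tensors can be pictured with the minimal pure-Python sketch below. It is not MediaPipe code: the helper `make_bert_inputs`, the padding id 0, and the single-segment convention (every segment id is 0) are assumptions, and the token ids are arbitrary example values assumed to already include any special tokens. In practice, the tokenizer described by the model metadata produces these tensors from the raw text.

```python
from typing import List, Tuple

Matrix = List[List[int]]


def make_bert_inputs(
    token_ids: List[List[int]], max_seq_len: int, pad_id: int = 0
) -> Tuple[Matrix, Matrix, Matrix]:
    """Builds [batch_size x max_seq_len] input ids, segment ids and mask ids.

    Hypothetical helper: assumes each row is already tokenized and belongs
    to a single segment.
    """
    input_ids, segment_ids, mask_ids = [], [], []
    for row in token_ids:
        row = row[:max_seq_len]  # truncate over-long sequences
        pad = max_seq_len - len(row)
        input_ids.append(row + [pad_id] * pad)
        segment_ids.append([0] * max_seq_len)  # single segment: all zeros
        mask_ids.append([1] * len(row) + [0] * pad)  # 1 = real token, 0 = padding
    return input_ids, segment_ids, mask_ids


ids, segs, mask = make_bert_inputs([[101, 7592, 102], [101, 102]], max_seq_len=4)
print(ids)   # [[101, 7592, 102, 0], [101, 102, 0, 0]]
print(mask)  # [[1, 1, 1, 0], [1, 1, 0, 0]]
```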

At least one output tensor with: (kTfLiteFloat32/kBool)

  • [1 x N] array, where N is the number of categories.
  • optional (but recommended) category labels as AssociatedFiles with type TENSOR_AXIS_LABELS, containing one label per line. The first such AssociatedFile (if any) is used to fill the category_name field of the results. The display_name field is filled from the AssociatedFile (if any) whose locale matches the display_names_locale field of the TextClassifierOptions used at creation time ("en" by default, i.e. English). If none of these are available, only the index field of the results will be filled.
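The label rules above can be illustrated with a short pure-Python sketch. It mirrors the documented behavior rather than MediaPipe's implementation: `CategorySketch`, `fill_categories`, and the representation of each label file as a `(locale, lines)` pair are hypothetical, and each file is assumed to hold exactly one label per category.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

# Hypothetical in-memory label file: (locale or None, one label per line).
LabelFile = Tuple[Optional[str], List[str]]


@dataclass
class CategorySketch:
    index: int
    score: float
    category_name: Optional[str] = None
    display_name: Optional[str] = None


def fill_categories(
    scores: List[float], label_files: List[LabelFile], display_names_locale: str = "en"
) -> List[CategorySketch]:
    # category_name comes from the first label file, if there is one.
    names = label_files[0][1] if label_files else None
    # display_name comes from the file whose locale matches display_names_locale.
    display = next(
        (labels for locale, labels in label_files if locale == display_names_locale),
        None,
    )
    # Without any label file, only index (and score) are filled.
    return [
        CategorySketch(
            index=i,
            score=s,
            category_name=names[i] if names else None,
            display_name=display[i] if display else None,
        )
        for i, s in enumerate(scores)
    ]


scores = [0.2, 0.8]  # one row of a [1 x N] output tensor, N = 2
files = [(None, ["negative", "positive"]), ("fr", ["négatif", "positif"])]
print(fill_categories(scores, files, display_names_locale="fr")[1])
# CategorySketch(index=1, score=0.8, category_name='positive', display_name='positif')
```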

Args
graph_config The MediaPipe text task graph config proto.

Methods

classify

Performs classification on the input text.

Args
text The input text.

Returns
A TextClassifierResult object that contains a list of text classifications.

Raises
ValueError If any of the input arguments is invalid.
RuntimeError If text classification failed to run.
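A typical call picks the highest-scoring category from the returned result. The sketch below assumes the result exposes `classifications[0].categories` and that each category has `category_name` and `score` attributes, as in MediaPipe's published text classification examples. `top_category` and `FakeClassifier` are hypothetical; the stand-in only lets the sketch run without a model, and in real use you would pass a `TextClassifier` instance.

```python
from types import SimpleNamespace
from typing import Tuple


def top_category(classifier, text: str) -> Tuple[str, float]:
    """Returns (category_name, score) of the best category for text.

    Only classifier.classify() is used, so any TextClassifier created with
    one of the create_from_* methods can be passed in.
    """
    result = classifier.classify(text)
    categories = result.classifications[0].categories  # first classification head
    best = max(categories, key=lambda c: c.score)  # does not rely on sort order
    return best.category_name, best.score


class FakeClassifier:
    """Stand-in for TextClassifier so the sketch runs without a model file."""

    def classify(self, text):
        cats = [
            SimpleNamespace(category_name="negative", score=0.1),
            SimpleNamespace(category_name="positive", score=0.9),
        ]
        return SimpleNamespace(classifications=[SimpleNamespace(categories=cats)])


print(top_category(FakeClassifier(), "What a great movie!"))  # ('positive', 0.9)
```

Taking the `max` by score avoids depending on the order in which `categories` is returned.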

close

Shuts down the MediaPipe text task instance.

Raises
RuntimeError If the MediaPipe text task failed to close.
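If you do not use the classifier as a context manager, call `close()` yourself once you are finished. Wrapping the work in `try`/`finally`, as in this sketch, ensures `close()` runs even when `classify()` raises. `classify_all` and `RecordingClassifier` are hypothetical; the stand-in records whether it was closed and raises `ValueError` on empty input purely for illustration.

```python
from typing import Iterable, List


def classify_all(classifier, texts: Iterable[str]) -> List[object]:
    """Classifies every text, then releases the classifier's resources."""
    try:
        return [classifier.classify(t) for t in texts]
    finally:
        classifier.close()  # runs on success and when classify() raises


class RecordingClassifier:
    """Stand-in for TextClassifier that records whether close() was called."""

    def __init__(self):
        self.closed = False

    def classify(self, text):
        if not text:
            raise ValueError("empty input")  # illustrative invalid-input case
        return text.upper()  # placeholder "result"

    def close(self):
        self.closed = True


c = RecordingClassifier()
print(classify_all(c, ["a", "b"]), c.closed)  # ['A', 'B'] True
```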

create_from_model_path

Creates a TextClassifier object from a TensorFlow Lite model and the default TextClassifierOptions.

Args
model_path Path to the model.

Returns
TextClassifier object that's created from the model file and the default TextClassifierOptions.

Raises
ValueError If the TextClassifier object cannot be created from the provided file, e.g. because the file path is invalid.
RuntimeError If any other type of error occurs.
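The two documented exception types lend themselves to different handling: a `ValueError` typically means the model file itself is unusable (for instance, a wrong path), while a `RuntimeError` signals some other failure. The sketch below, which assumes you want to fall back to `None` only in the first case, is one possible pattern rather than part of the API. `load_classifier` and `FakeTextClassifier` are hypothetical; `classifier_cls` defaults to `mp.tasks.text.TextClassifier` and exists only so a stand-in can be passed when MediaPipe is not installed.

```python
from typing import Optional


def load_classifier(model_path: str, classifier_cls=None) -> Optional[object]:
    """Creates a classifier from model_path, or returns None if the file is unusable."""
    if classifier_cls is None:
        import mediapipe as mp  # imported lazily; only needed for real use
        classifier_cls = mp.tasks.text.TextClassifier
    # RuntimeError is deliberately not caught: it signals a different kind of failure.
    try:
        return classifier_cls.create_from_model_path(model_path)
    except ValueError as err:  # e.g. an invalid file path
        print(f"Could not load {model_path!r}: {err}")
        return None


class FakeTextClassifier:
    """Stand-in that rejects any path not ending in .tflite (illustration only)."""

    @classmethod
    def create_from_model_path(cls, model_path):
        if not model_path.endswith(".tflite"):
            raise ValueError("invalid file path")
        return cls()


print(load_classifier("missing.txt", FakeTextClassifier))
# Could not load 'missing.txt': invalid file path
# None
```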

create_from_options

Creates the TextClassifier object from text classifier options.

Args
options Options for the text classifier task.

Returns
TextClassifier object that's created from options.

Raises
ValueError If the TextClassifier object cannot be created from the given TextClassifierOptions, e.g. because the model is missing.
RuntimeError If any other type of error occurs.
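Creating the classifier from options lets you set fields such as display_names_locale, which controls how the display_name field is filled. A minimal sketch, assuming the `mp.tasks.BaseOptions` and `mp.tasks.text.TextClassifierOptions` classes of the MediaPipe Tasks Python package; the function name `create_classifier` is hypothetical, and MediaPipe is imported inside the function so that defining it does not require the package.

```python
def create_classifier(model_path: str, display_names_locale: str = "en"):
    """Builds TextClassifierOptions and creates a TextClassifier from them."""
    import mediapipe as mp  # imported lazily; required only when called

    options = mp.tasks.text.TextClassifierOptions(
        base_options=mp.tasks.BaseOptions(model_asset_path=model_path),
        # Locale of the label file used to fill display_name ("en" by default).
        display_names_locale=display_names_locale,
    )
    # Raises ValueError if, for example, the options do not point to a usable model.
    return mp.tasks.text.TextClassifier.create_from_options(options)
```

For example, `create_classifier("path/to/model.tflite", "fr")` would request French display names, provided the model metadata includes a label file with that locale; the path is a placeholder.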

__enter__

Returns self upon entering the runtime context.

__exit__

Shuts down the MediaPipe text task instance when the context manager exits.

Raises
RuntimeError If the MediaPipe text task failed to close.
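Because `__enter__` returns the classifier and `__exit__` shuts it down, a `with` statement takes care of calling `close()`. In this minimal sketch, `classify_texts` and `FakeManagedClassifier` are hypothetical; `classifier_cls` defaults to `mp.tasks.text.TextClassifier` and can be swapped for the stand-in, as in the usage lines, so the sketch runs without MediaPipe or a model file.

```python
from typing import List, Sequence


def classify_texts(model_path: str, texts: Sequence[str], classifier_cls=None) -> List[object]:
    """Classifies texts with a classifier that is shut down automatically."""
    if classifier_cls is None:
        import mediapipe as mp  # imported lazily; only needed for real use
        classifier_cls = mp.tasks.text.TextClassifier
    # __enter__ returns the classifier itself; __exit__ closes it when the
    # block is left, whether normally or because classify() raised.
    with classifier_cls.create_from_model_path(model_path) as classifier:
        return [classifier.classify(t) for t in texts]


class FakeManagedClassifier:
    """Stand-in that follows the documented __enter__/__exit__ behavior."""

    created = []  # every instance, so the usage below can inspect it

    def __init__(self):
        self.closed = False

    @classmethod
    def create_from_model_path(cls, model_path):
        instance = cls()
        cls.created.append(instance)
        return instance

    def classify(self, text):
        return len(text)  # placeholder "result"

    def close(self):
        self.closed = True

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()  # returns None, so exceptions from the block propagate


print(classify_texts("model.tflite", ["hi", "hello"], FakeManagedClassifier))  # [2, 5]
print(FakeManagedClassifier.created[-1].closed)  # True
```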