mediapipe_model_maker.text_classifier.TextClassifier

API for creating and training a text classification model.

Args
model_spec: Specification for the model.
label_names: A list of label names for the classes.
shuffle: Whether the dataset should be shuffled.

Methods

create

Factory function that creates and trains a text classifier.

Note that train_data and validation_data are expected to share the same label_names since they should be split from the same dataset.

Args
train_data: Training data.
validation_data: Validation data.
options: Options for creating and training the text classifier.

Returns
A text classifier.

Raises
ValueError if train_data and validation_data do not have the same label_names, or if options contains an unknown supported_model.
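
The label-consistency rule in the note above can be illustrated with a minimal, library-independent sketch. The helper `check_same_label_names` is hypothetical: it only mirrors the documented precondition (raise ValueError when the two splits disagree) and is not MediaPipe's implementation. The comparison is order-sensitive on the assumption that a label's position is its class index.

```python
from typing import Sequence


def check_same_label_names(
    train_label_names: Sequence[str], validation_label_names: Sequence[str]
) -> None:
    """Hypothetical check mirroring the documented create() precondition.

    Raises ValueError if the two splits do not share the same label names
    in the same order (position is treated as the class index).
    """
    if list(train_label_names) != list(validation_label_names):
        raise ValueError(
            f"Training label names {list(train_label_names)} differ from "
            f"validation label names {list(validation_label_names)}"
        )


# Splitting one dataset keeps the label names identical, so this passes.
check_same_label_names(["negative", "positive"], ["negative", "positive"])

# Splits built from different sources may disagree, which is rejected.
try:
    check_same_label_names(["negative", "positive"], ["neg", "pos"])
except ValueError as err:
    print("rejected:", err)
```

Splitting both sets from one source dataset is the simplest way to satisfy this rule, which is why the note recommends it.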

evaluate

Overrides Classifier.evaluate().

Args
data: Evaluation dataset. Must be a TextClassifier Dataset.
batch_size: Number of samples per evaluation step.
desired_precisions: If specified, adds one RecallAtPrecision metric per entry desired_precisions[i], which tracks the best recall achievable while precision is at least desired_precisions[i]. Only supported for binary classification.
desired_recalls: If specified, adds one PrecisionAtRecall metric per entry desired_recalls[i], which tracks the best precision achievable while recall is at least desired_recalls[i]. Only supported for binary classification.

Returns
The loss value and accuracy.

Raises
ValueError if data is not a TextClassifier Dataset.
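
For binary classification, the metrics added through desired_precisions and desired_recalls choose an operating threshold on the positive-class score. The pure-Python sketch below shows the idea under simplifying assumptions: it treats every distinct score as a candidate threshold and reports the best recall whose precision meets the target (and the reverse). Keras's `tf.keras.metrics.RecallAtPrecision` and `PrecisionAtRecall` evaluate a fixed grid of `num_thresholds` thresholds instead, so their values can differ slightly. The function names here are hypothetical.

```python
from typing import List, Sequence, Tuple


def _precision_recall_points(
    scores: Sequence[float], labels: Sequence[int]
) -> List[Tuple[float, float]]:
    """Returns (precision, recall) for each distinct score used as a threshold.

    A sample is predicted positive when its score is >= the threshold.
    """
    num_positives = sum(labels)
    points = []
    for threshold in sorted(set(scores)):
        predicted = [s >= threshold for s in scores]
        tp = sum(1 for p, y in zip(predicted, labels) if p and y == 1)
        fp = sum(1 for p, y in zip(predicted, labels) if p and y == 0)
        # tp + fp >= 1: the sample(s) scoring exactly `threshold` are positive.
        precision = tp / (tp + fp)
        recall = tp / num_positives if num_positives else 0.0
        points.append((precision, recall))
    return points


def recall_at_precision(scores, labels, min_precision: float) -> float:
    """Best recall among thresholds whose precision is >= min_precision."""
    points = _precision_recall_points(scores, labels)
    return max((r for p, r in points if p >= min_precision), default=0.0)


def precision_at_recall(scores, labels, min_recall: float) -> float:
    """Best precision among thresholds whose recall is >= min_recall."""
    points = _precision_recall_points(scores, labels)
    return max((p for p, r in points if r >= min_recall), default=0.0)


scores = [0.9, 0.8, 0.7, 0.4, 0.3, 0.1]  # positive-class probabilities
labels = [1, 1, 0, 1, 0, 0]              # ground truth (1 = positive)
print(recall_at_precision(scores, labels, 0.9))  # 0.6666666666666666
print(precision_at_recall(scores, labels, 1.0))  # 0.75
```

Passing several targets, for example desired_precisions=[0.9, 0.95], corresponds to evaluating this kind of metric once per target.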

export_labels

Exports classification labels into a label file.

Args
export_dir: The directory to save exported files.
label_filename: File name to save the labels to. The full export path is {export_dir}/{label_filename}.
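
A minimal sketch of a label file export, assuming one label name per line in class-index order (a layout commonly used for TFLite classifier label files). The function `write_label_file` and its default file name are hypothetical; the sketch only shows how export_dir and label_filename combine into the output path.

```python
import os
import tempfile
from typing import Sequence


def write_label_file(
    label_names: Sequence[str], export_dir: str, label_filename: str = "labels.txt"
) -> str:
    """Hypothetical helper: writes one label per line and returns the file path."""
    os.makedirs(export_dir, exist_ok=True)
    # The full export path is {export_dir}/{label_filename}.
    path = os.path.join(export_dir, label_filename)
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(label_names))
    return path


with tempfile.TemporaryDirectory() as tmp:
    path = write_label_file(["negative", "positive"], tmp)
    with open(path, encoding="utf-8") as f:
        print(f.read().splitlines())  # ['negative', 'positive']
```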

export_model

Converts and saves the model to a TFLite file with metadata included.

Only the TFLite file is needed for deployment. This function also saves a metadata.json file in the same directory as the TFLite file; that JSON file can be used to interpret the metadata embedded in the TFLite file.

Args
model_name: File name for the TFLite model with metadata. The full export path is {self._hparams.export_dir}/{model_name}.
quantization_config: The configuration for model quantization.
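
Both files written by export_model land in one directory. The sketch below only computes those locations, assuming a plain `export_dir` argument stands in for self._hparams.export_dir and using `model.tflite` as an illustrative model name; `exported_file_paths` is a hypothetical helper, not part of the library.

```python
import os
from typing import NamedTuple


class ExportedFiles(NamedTuple):
    tflite_path: str    # deployable model with embedded metadata
    metadata_path: str  # JSON description of that embedded metadata


def exported_file_paths(
    export_dir: str, model_name: str = "model.tflite"
) -> ExportedFiles:
    """Hypothetical helper: paths of the files described for export_model()."""
    return ExportedFiles(
        tflite_path=os.path.join(export_dir, model_name),
        # metadata.json is written next to the TFLite file.
        metadata_path=os.path.join(export_dir, "metadata.json"),
    )


files = exported_file_paths("exported_model")
print(files.tflite_path)    # exported_model/model.tflite on POSIX
print(files.metadata_path)  # exported_model/metadata.json on POSIX
```

Only tflite_path needs to ship with an application; metadata_path is useful for inspecting what was embedded.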

export_tflite

Converts the model to the TFLite format and saves it.

Args
export_dir: The directory to save exported files.
tflite_filename: File name for the TFLite model. The full export path is {export_dir}/{tflite_filename}.
quantization_config: The configuration for model quantization.
preprocess: A callable to preprocess the representative dataset for quantization. The callable takes three arguments, in order: feature, label, and is_training.
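
The preprocess contract can be sketched with plain Python values. This assumes the callable returns a (feature, label) pair and that representative (calibration) data is passed with is_training set to False; `lowercase_preprocess` and `representative_examples` are hypothetical names, and real Model Maker pipelines operate on tensors rather than strings.

```python
from typing import Callable, Iterable, Iterator, Tuple

Example = Tuple[str, int]


def lowercase_preprocess(feature: str, label: int, is_training: bool) -> Example:
    """Hypothetical callable taking (feature, label, is_training) in that order."""
    # Normalizes text identically for training and calibration; a real callable
    # could branch on is_training, e.g. to augment only during training.
    del is_training  # unused in this sketch
    return feature.strip().lower(), label


def representative_examples(
    dataset: Iterable[Example],
    preprocess: Callable[[str, int, bool], Example],
) -> Iterator[str]:
    """Yields preprocessed features for quantization calibration (not training)."""
    for feature, label in dataset:
        processed_feature, _ = preprocess(feature, label, False)
        yield processed_feature


data = [("  Great movie!", 1), ("BORING plot", 0)]
print(list(representative_examples(data, lowercase_preprocess)))
# ['great movie!', 'boring plot']
```

Keeping calibration preprocessing identical to inference preprocessing matters because quantization ranges are estimated from these examples.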

load_bert_classifier

save_model

Saves the model in SavedModel format.

For more information, see https://www.tensorflow.org/guide/saved_model

Args
model_name: Name of the saved model.

summary

Prints a summary of the model.