mediapipe_model_maker.object_detector.ObjectDetector

ObjectDetector for building an object detection model.

Args
model_spec Specifications for the model.
label_names A list of label names for the classes.
hparams The hyperparameters for training the object detector.
model_options Options for creating the object detector model.

Methods

create

Creates and trains an ObjectDetector.

Loads the data and trains the model for object detection.

Args
train_data Training data.
validation_data Validation data.
options Configurations for creating and training the object detector.

Returns
An instance of ObjectDetector.
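A minimal usage sketch of the create flow. The dataset paths, model choice, and hyperparameter values below are illustrative assumptions, not prescribed defaults:

```python
from mediapipe_model_maker import object_detector

# Load COCO-format training and validation data (paths are hypothetical).
train_data = object_detector.Dataset.from_coco_folder("data/train")
validation_data = object_detector.Dataset.from_coco_folder("data/validation")

# Pick a supported architecture and training hyperparameters.
options = object_detector.ObjectDetectorOptions(
    supported_model=object_detector.SupportedModels.MOBILENET_V2,
    hparams=object_detector.HParams(epochs=30, export_dir="exported_model"),
)

# create() trains a float model and saves a float checkpoint to
# {export_dir}/float_ckpt for later use by restore_float_ckpt().
model = object_detector.ObjectDetector.create(
    train_data=train_data,
    validation_data=validation_data,
    options=options,
)
```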

evaluate

Overrides Classifier.evaluate to calculate COCO metrics.
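For example, assuming a trained model and a validation dataset (the batch size and the metric keys shown are illustrative; the method returns the loss alongside a dictionary of COCO metrics):

```python
# Evaluate the trained detector on held-out data.
loss, coco_metrics = model.evaluate(validation_data, batch_size=4)
print(f"Validation loss: {loss}")
print(f"COCO metrics: {coco_metrics}")
```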

export_labels

Exports classification labels into a label file.

Args
export_dir The directory to save exported files.
label_filename File name to save the labels to. The full export path is {export_dir}/{label_filename}.

export_model

Converts and saves the model to a TFLite file with metadata included.

The model export format is set automatically based on whether quantization aware training (QAT) was run. The model exports to float32 by default, and exports to an int8 quantized model if QAT was run. To export a float32 model after running QAT, call restore_float_ckpt before this method. For custom post-training quantization without QAT, use the quantization_config parameter.

Note that only the TFLite file is needed for deployment. This function also saves a metadata.json file to the same directory as the TFLite file which can be used to interpret the metadata content in the TFLite file.

Args
model_name File name to save TFLite model with metadata. The full export path is {self._hparams.export_dir}/{model_name}.
quantization_config The configuration for model quantization. Note that int8 quantization aware training is automatically applied when possible. This parameter is used to specify other post-training quantization options such as fp16 and int8 without QAT.

Raises
ValueError If a custom quantization_config is specified when the model has quantization aware training enabled.
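A sketch of the export options, assuming the quantization module shipped with mediapipe_model_maker; the file names are illustrative:

```python
from mediapipe_model_maker import quantization

# Default export: float32 TFLite (or int8 if QAT was run beforehand).
model.export_model()

# Post-training float16 quantization without QAT.
model.export_model(
    model_name="model_fp16.tflite",
    quantization_config=quantization.QuantizationConfig.for_float16(),
)

# Post-training int8 quantization without QAT,
# calibrated with a representative dataset.
int8_config = quantization.QuantizationConfig.for_int8(
    representative_data=validation_data,
)
model.export_model(model_name="model_int8.tflite", quantization_config=int8_config)
```

Note that passing a quantization_config after QAT has been run raises a ValueError, as described above.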

export_tflite

Converts the model to TFLite format and saves it.

Args
export_dir The directory to save exported files.
tflite_filename File name to save TFLite model. The full export path is {export_dir}/{tflite_filename}.
quantization_config The configuration for model quantization.
preprocess A callable to preprocess the representative dataset for quantization. The callable takes three arguments in order: feature, label, and is_training.

quantization_aware_training

Runs quantization aware training (QAT) on the model.

The QAT step happens after training a regular float model with the create method. This additional step fine-tunes the model at lower precision in order to mimic the behavior of a quantized model. The resulting quantized model generally performs better than a model quantized without running QAT.

Just like training the float model using the create method, the QAT step also requires some manual tuning of hyperparameters. In order to run QAT more than once for purposes such as hyperparameter tuning, use the restore_float_ckpt method to restore the model state to the trained float checkpoint without having to rerun the create method.

Args
train_data Training dataset.
validation_data Validation dataset.
qat_hparams Configuration for QAT.
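For instance, after training a float model with create (the QATHParams values below are illustrative, not recommended defaults):

```python
from mediapipe_model_maker import object_detector

# Fine-tune the trained float model with quantization aware training.
qat_hparams = object_detector.QATHParams(
    learning_rate=0.3,
    batch_size=4,
    epochs=10,
    decay_steps=6,
    decay_rate=0.96,
)
model.quantization_aware_training(train_data, validation_data, qat_hparams=qat_hparams)
```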

restore_float_ckpt

Loads a float checkpoint of the model from {hparams.export_dir}/float_ckpt.

The float checkpoint at {hparams.export_dir}/float_ckpt is automatically saved after training an ObjectDetector using the create method. This method is used to restore the trained float checkpoint state of the model in order to run quantization_aware_training multiple times. Example usage:

# Train a model
model = object_detector.create(...)

# Run QAT
model.quantization_aware_training(...)
model.evaluate(...)

# Restore the float checkpoint to run QAT again
model.restore_float_ckpt()

# Run QAT with different parameters
model.quantization_aware_training(...)
model.evaluate(...)

summary

Prints a summary of the model.