使用 TensorFlow Lite 模型製作工具訓練留言垃圾內容偵測模型

1. 事前準備

在本程式碼研究室中,您會查看使用 TensorFlow 和 TensorFlow Lite Model Maker 建立的程式碼,並以註解垃圾內容為基礎的資料集建立模型。原始資料可在 Kaggle 取得。並匯整為單一 CSV 檔案,移除損毀的文字、標記、重複的字詞等內容。這樣一來,您就能更專注於模型,而非文字。

您要檢查的程式碼會顯示在這裡,但我們強烈建議您在 Colaboratory 中一併查看程式碼

必要條件

課程內容

  • 如何使用 Colab 安裝 TensorFlow Lite Model Maker。
  • 如何將資料從 Colab 伺服器下載到裝置。
  • 如何使用資料載入器。
  • 如何建構模型。

軟硬體需求

2. 安裝 TensorFlow Lite 模型製作工具

  • 開啟這個 Colab。筆記本中的第一個儲存格會為您安裝 TensorFlow Lite Model Maker:
!pip install -q tflite-model-maker

完成後,請前往下一個儲存格。

3. 匯入程式碼

下一個儲存格包含筆記本中程式碼需要使用的匯入項目:

import numpy as np
import os
from tflite_model_maker import configs
from tflite_model_maker import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import text_classifier
from tflite_model_maker.text_classifier import DataLoader

import tensorflow as tf
assert tf.__version__.startswith('2')
tf.get_logger().setLevel('ERROR')

這也會檢查您是否執行 TensorFlow 2.x,這是使用 Model Maker 的必要條件。

4. 下載資料

接著,您要將資料從 Colab 伺服器下載到裝置,並將 data_file 變數設為指向本機檔案:

data_file = tf.keras.utils.get_file(fname='comment-spam.csv', 
  origin='https://storage.googleapis.com/laurencemoroney-blog.appspot.com/lmblog_comments.csv', 
  extract=False)

Model Maker 可以使用這類簡單的 CSV 檔案訓練模型。您只需要指定哪些欄位包含文字,哪些欄位包含標籤,本程式碼研究室稍後會說明如何操作。

5. 預先學習的嵌入

一般來說,使用 Model Maker 時,您不會從頭開始建構模型,使用現有模型,並根據需求自訂。

這類語言模型會使用預先學習的嵌入。嵌入的背後概念是將字詞轉換為數字,整體語料庫中的每個字詞都會獲得一個數字。嵌入是向量,用於建立字詞的「方向」,藉此判斷字詞的情緒。舉例來說,在留言垃圾內容中經常使用的字詞,其向量會指向類似的方向;而不會出現在留言垃圾內容中的字詞,其向量則會指向相反的方向。

使用預先學習的嵌入時,您可以從已從大量文字中學習情緒的字詞語料庫或集合開始,因此比從零開始更快找到解決方案。

Model Maker 提供多種預先學習的嵌入內容,但最簡單快速的入門選項是 average_word_vec

程式碼如下:

spec = model_spec.get('average_word_vec')
spec.num_words = 2000
spec.seq_len = 20
spec.wordvec_dim = 7

num_words 參數

您也可以指定模型要使用的字數。

您可能會認為「越多越好」,但一般來說,每個字詞的使用頻率都有適當的次數。如果您使用整個語料庫中的每個字詞,模型可能會嘗試學習並建立只使用一次的字詞方向。在任何文字語料庫中,許多字詞只會使用一次或兩次,因此納入模型並不值得,因為這些字詞對整體情緒的影響微乎其微。

您可以根據所需字數,使用 num_words 參數調整模型。較小的數字可能會提供較小且較快的模型,但由於可辨識的字詞較少,準確度可能會降低。另一方面,較大的數字可能會提供較大且較慢的模型。找出最佳平衡點非常重要!

wordvec_dim 參數

wordved_dim 參數是您要用於每個字詞向量的維度數量。根據研究結果,經驗法則是字數的四次方根。舉例來說,如果使用 2,000 字,建議從 7 開始。如果變更使用的字數,也可以變更這項設定。

seq_len 參數

模型通常對輸入值非常嚴格,以語言模型來說,這表示語言模型可以分類特定靜態長度的句子。這取決於 seq_len 參數或序列長度

將字詞轉換為數字或符記後,句子就會變成這些符記的序列。在本例中,模型經過訓練,可分類及辨識含有 20 個權杖的句子。如果句子長度超過此限制,系統會截斷句子。如果較短,則會填補空白。您可以在用於此用途的語料庫中看到專屬的 <PAD> 權杖。

6. 使用資料載入器

您先前已下載 CSV 檔案。現在可以使用資料載入器,將這項資料轉換為模型可辨識的訓練資料:

data = DataLoader.from_csv(
    filename=data_file,
    text_column='commenttext',
    label_column='spam',
    model_spec=spec,
    delimiter=',',
    shuffle=True,
    is_training=True)

train_data, test_data = data.split(0.9)

如果您在編輯器中開啟 CSV 檔案,會發現每一行只有兩個值,而檔案第一行則以文字說明這些值。通常每個項目都會視為一欄。

您會看到第一欄的描述元是 commenttext,且每行的第一個項目都是留言文字。同樣地,第二欄的描述元是 spam,您會看到每行第二個項目是 TrueFalse,,表示該文字是否視為留言垃圾內容。其他屬性會設定您先前建立的 model_spec 變數,以及分隔符號字元 (在本例中為半形逗號,因為檔案是以半形逗號分隔)。您將使用這些資料訓練模型,因此 is_Training 會設為 True

您會想保留部分資料,用於測試模型。分割資料,其中 90% 用於訓練,另外 10% 用於測試/評估。由於我們要這麼做,因此請務必隨機選擇測試資料,而非資料集的「底部」10%,因此請在載入資料時使用 shuffle=True 隨機選擇。

7. 建構模型

下一個儲存格只是用來建構模型,而且只有一行程式碼:

# Build the model
model = text_classifier.create(train_data, model_spec=spec, epochs=50, 
                               validation_data=test_data)

這段程式碼會使用 Model Maker 建立文字分類器模型,並指定要使用的訓練資料 (如第四個步驟所述)、模型規格 (如第四個步驟所述) 和訓練週期數 (本例為 50)。

機器學習的基本原則是模式比對。一開始,系統會載入字詞的預先訓練權重,並嘗試將字詞分組,預測哪些字詞分組後會指出垃圾內容,哪些不會。第一次時,模型才剛開始運作,因此可能會平均分配流量。

c42755151d511ce.png

接著,模型會評估這個訓練週期的結果,並執行最佳化程式碼來調整預測,然後再次嘗試。這是紀元。因此,指定 epochs=50 時,系統會執行 50 次「迴圈」。

7d0ee06a5246b58d.png

當您達到第 50 個訓練週期時,模型回報的準確率會大幅提升。在本例中,系統會顯示 99%

驗證準確度通常會比訓練準確度略低,因為驗證準確度代表模型分類先前未見資料的準確度。這項作業會使用您先前保留的 10% 測試資料。

f063ff6e1d2add67.png

8. 匯出模型

  1. 執行這個儲存格,指定目錄並匯出模型:
model.export(export_dire='/mm_spam_savedmodel', export_format=[ExportFormat.LABEL, ExportFormat.VOCAB, ExportFormat.SAVED_MODEL])
  1. 壓縮整個 /mm_spam_savedmodel 資料夾,然後下載產生的 mm_spam_savedmodel.zip 檔案,您會在下一個程式碼研究室中使用這個檔案。
# Rename the SavedModel subfolder to a version number
!mv /mm_spam_savedmodel/saved_model /mm_spam_savedmodel/123
!zip -r mm_spam_savedmodel.zip /mm_spam_savedmodel/

9. 恭喜

本程式碼研究室逐步說明如何使用 Python 程式碼建構及匯出模型。現在您有一個 SavedModel,以及結尾的標籤和詞彙。在下一個程式碼研究室中,您會瞭解如何使用這個模型,開始分類垃圾留言。

瞭解詳情