使用 TensorFlow Lite Model Maker 建立自訂文字分類模型

1. 事前準備

在這個程式碼研究室中,您將瞭解如何更新從原始 Blog-spam-comments 資料集建構的文字分類模型,但是以您自己的意見進行強化,以建立適用於您資料的模型。

必要條件

本程式碼研究室是「Flutter 應用程式途徑中使用文字分類」的一環。這個課程中的程式碼研究室是循序漸進的。您先前也應該在建構程式碼和應用程式的過程中,建置您原本就要建置的應用程式和模型。如果你還沒完成之前的活動,請立即停止吧!

課程內容

軟硬體需求

  • 您在先前活動中觀察到並建構的 Flutter 應用程式和垃圾郵件篩選器。

2. 加強文字分類

  1. 如要取得程式碼,只要複製這個存放區,然後從 tfserving-flutter/codelab2/finished 資料夾載入應用程式即可。
  2. 啟動 TensorFlow Serving Docker 映像檔後,請在您所建立的應用程式中輸入 buy my book to learn online trading,然後按一下 [gRPC > Classify]

8f1e1974522f274d.png

這個應用程式產生的垃圾內容分數偏低,因為原始資料集中的線上交易發生次數不多,因此模型還未意識到垃圾內容。在這個程式碼研究室中,您會將模型更新為新資料,讓模型識別出相同的垃圾郵件。

2bd68691a26aa3da.png

3. 編輯 CSV 檔案

為了訓練原始模型,我們建立了一個 CSV 格式 (lmblog_comments.csv) 的資料集,其中含有數千個標示為垃圾內容或不垃圾留言的留言。(如要檢查,請在文字編輯器中開啟 CSV 檔案)。

CSV 檔案的結構包含第一列描述資料欄,並分別標示為 commenttextspam。每個後續資料列都會遵循以下格式:

62025273971c9a7f.png

右邊的標籤會指定垃圾郵件的 true 值,而不是垃圾郵件的 false 值。例如,第三行視為垃圾內容。

如果有人在網路上收到關於線上交易的訊息,您可以在網站底部新增垃圾評論的例子。例如:

online trading can be highly highly effective,true
online trading can be highly effective,true
online trading now,true
online trading here,true
online trading for the win,true
  • 以新名稱 (例如 lmblog_comments.csv) 儲存檔案,以便用於訓練新的模型。

在這個程式碼研究室的後續部分,您將使用線上提供、修改和託管的範例,並搭配線上交易的更新。如要使用自己的資料集,可以變更程式碼中的網址。

4. 使用新資料重新訓練模型

如要重新訓練模型,只要重複使用 (SpamCommentsModelMaker.ipynb) 的程式碼,但將其指向名為「lmblog_comments_extras.csv」的新 CSV 資料集即可。如要更新含有完整內容的完整筆記本,您可以使用 SpamCommentsUpdateModelMaker.ipynb.

如果您有 Colaboratory 的存取權,可以直接啟動 Colaboratory。否則,請從存放區取得程式碼,然後在您選擇的筆記本環境中執行。

更新後的程式碼看起來像這樣:

training_data = tf.keras.utils.get_file(fname='comments-spam-extras.csv',   
          origin='https://storage.googleapis.com/laurencemoroney-blog.appspot.com/
                  lmblog_comments_extras.csv', 
          extract=False)

進行訓練時,您應該會發現模型仍會以高精確度進行訓練:

96a1547ddb6edf5b.png

壓縮 /mm_update_spam_savedmodel 整個資料夾,並減少系統產生的 mm_update_spam_savedmodel.zip 檔案。

# Rename the SavedModel subfolder to a version number
!mv /mm_update_spam_savedmodel/saved_model /mm_update_spam_savedmodel/123
!zip -r mm_update_spam_savedmodel.zip /mm_update_spam_savedmodel/

5. 啟動 Docker 並更新 Flutter 應用程式

  1. 將下載的 mm_update_spam_savedmodel.zip 檔案解壓縮到資料夾中,然後停止先前的程式碼研究室中的 Docker 容器執行個體,並重新啟動,但將 PATH/TO/UPDATE/SAVEDMODEL 預留位置改成代管檔案所在資料夾的絕對路徑:
docker run -it --rm -p 8500:8500 -p 8501:8501 -v "PATH/TO/UPDATE/SAVEDMODEL:/models/spam-detection" -e MODEL_NAME=spam-detection tensorflow/serving
  1. 使用您偏好的程式碼編輯器開啟 lib/main.dart 檔案,然後找出定義 inputTensorNameoutTensorName 變數的部分:
const inputTensorName = 'input_3';
const outputTensorName = 'dense_5';
  1. inputTensorName 變數重新指派給「input_1'」值,並將 outputTensorName 變數重新指派給 'dense_1' 值:
const inputTensorName = 'input_1';
const outputTensorName = 'dense_1';
  1. 複製您下載至 lib/assets/ 資料夾中的 vocab.txt 檔案,取代現有的檔案。
  2. 手動移除 Android 模擬器中的 Text Classification Flutter 應用程式。
  3. 在終端機執行 'flutter run' 指令以啟動應用程式。
  4. 在應用程式中輸入 buy my book to learn online trading,然後按一下 [gRPC > Classify]

現在我們改善了偵測模型,將「將我的書籍購買線上交易」視為垃圾內容。

6. 恭喜

使用新的資料重新訓練模型,同時將模型與 Flutter 應用程式進行整合,並更新了偵測新垃圾郵件功能的功能!

瞭解詳情