ディープラーニング入門 - TensorFlow を使ってモデルをトレーニングする方法を教えます

ディープラーニング入門 - TensorFlow を使ってモデルをトレーニングする方法を教えます

[[206688]]

導入

Tensorflow はバージョン 1.0 へのアップデート後に多くの新機能が追加され、tf フレームワーク (https://github.com/tensorflow/models) で記述された多くのディープ ネットワーク構造がリリースされ、開発の難易度が大幅に軽減されました。既製のネットワーク構造を使用すると、微調整と再トレーニングの両方がはるかに便利になります。最近、ついに TensorFlow Object Detection API の ssd_mobilenet_v1 モデルを実行しました。ここでは、データの準備からモデルの使用までのプロセス全体を実行する方法を記録します。これは、私やクラスメートにとって役立つと思います。

オブジェクト検出 API は、5 つのネットワーク構造に対して事前トレーニング済みの重みを提供します。これらはすべて、COCO データセットを使用してトレーニングされています。5 つのモデルは、SSD+mobilenet、SSD+inception_v2、R-FCN+resnet101、faster RCNN+resnet101、faster RCNN+inception+resnet101 です。各モデルの精度と計算時間は以下のとおりです。以下では、オブジェクト検出を使用して独自のモデルをトレーニングする方法について説明します。

ここでは TensorFlow のインストールについては説明しません。インターネット上には多くのチュートリアルがあり、TensorFlow のインストールに関する非常に詳細なドキュメントを見つけることができます。

トレーニング前の準備:

protobuf を使用してモデルとトレーニング パラメータを設定するため、API を正常に使用する前に protobuf ライブラリをコンパイルする必要があります。直接コンパイルされた pb ライブラリをここ (https://github.com/google/protobuf/releases) からダウンロードし、圧縮パッケージを解凍して、環境変数に protoc を追加できます。

  1. $ cd tensorflow/モデル
  2.  
  3. $ protoc object_detection/protos/*.proto --python_out=.  

(環境変数に protoc を追加したところ、*.proto ファイルが見つからないというエラーが発生しました。その後、protoc.exe を models/object_detection ディレクトリに配置して再実行しました。)

次に、モデルとスリム (tf アドバンス フレームワーク) を Python 環境変数に追加します。

  1. PYTHONPATH=$PYTHONPATH:/your/path/から/tensorflow/models:/your/path/から/tensorflow/models/slim

データ準備:

データセットは PASCAL VOC 構造に変換する必要があります。API は、VOC 構造データセットを .record 形式に変換するための create_pascal_tf_record.py を提供します。しかし、より簡単な方法を見つけました。Datitran は、.record 形式を生成するより簡単な方法を提供します。

まず、画像に適切なラベルを付ける必要があります。ここでは labelImg ツールを使用できます。サンプルに注釈が付けられるたびに、XML 注釈ファイルが生成されます。次に、これらの注釈付き XML ファイルを、トレーニング セットと検証セットに応じてそれぞれ 2 つのディレクトリに配置します。xml_to_csv.py スクリプトは Datitran で提供されています。ここでは、マークされたディレクトリ名のみを指定する必要があります。次に、対応する csv 形式を .record 形式に変換する必要があります。

  1. main() を定義します:
  2. # image_path = os.path.join (os.getcwd(), 'annotations' ) です
  3. image_path = r 'D:\training-sets\object-detection\sunglasses\label\test'  
  4. xml_df = xml_to_csv(画像パス)
  5. xml_df.to_csv( 'sunglasses_test_labels.csv' インデックス=なし)
  6. print( 'xml を csv に正常に変換しました。' )

generate_tfrecord.py を呼び出し、2 つのパラメーター –csv_input と –output_path を必ず指定してください。次のコマンドを実行します。

  1. Python generate_tfrecord.py --csv_input=sunglasses_test_labels.csv --output_path=sunglass_test.record  

これにより、トレーニングと検証用の train.record と test.record が生成されます。次にラベル名を指定します。model/object_detection/data/pet_label_map.pbtxtに従ってファイルを再作成し、ラベル名を指定します。

  1. アイテム {
  2. id: 1
  3. 名前 「サングラス」  
  4. }

電車:

必要に応じて、coco データセットで事前トレーニングされたモデルを選択し、トレーニングするディレクトリにプレフィックス model.ckpt を配置します。ここで、メタファイルはグラフとメタデータを保存し、ckpt はネットワークの重みを保存します。これらのファイルは、事前トレーニング済みモデルの初期状態を表します。

ssd_mobilenet_v1_pets.config ファイルを開き、次の変更を加えます。

num_classes: 独自のクラス番号に変更する

すべての PATH_TO_BE_CONFIGURED の場所を、以前に設定したパスに変更します (合計 5 つ)

その他のパラメータはすべてデフォルトのままです。

上記のファイルを準備したら、トレーニング用のトレーニングファイルを直接呼び出すことができます。

  1. Python オブジェクト検出/train.py \
  2. --logtostderr \  
  3. --pipeline_config_path= D:/training-sets /data-translate/training/ssd_mobilenet_v1_pets.config \  
  4. --train_dir=D:/トレーニングセット/データ変換/トレーニング 

TensorBoard モニタリング:

トレーニングプロセスは、Tensorboard ツールを使用して監視できます。次のコマンドを入力した後、ブラウザに localhost:6006 (デフォルト) を入力します。

  1. テンソルボード--logdir= D:/training-sets/data-translate/training  

インジケーター曲線やモデル ネットワーク アーキテクチャも多数あります。著者はまだ多くのインジケーターの意味を理解していませんが、TensorBoard ツールは非常に強力であるはずだと感じています。ただし、Total_Loss を使用すると、全体的なトレーニング状況を確認できます。

全体的に、損失曲線は確かに収束しており、全体的なトレーニング効果は満足のいくものです。さらに、TensorFlow では、トレーニング中に精度を検証するために検証セットを使用する機能も提供されていますが、著者はそれを呼び出すときにまだいくつかの問題に遭遇しました。これについては、ここでは今のところ詳しく説明しません。

フリーズモデルのエクスポート:

モデルの実際の効果を確認する前に、トレーニング プロセス ファイルをエクスポートし、.pb モデル ファイルを生成する必要があります。本来、tensorflow/python/tools/freeze_graph.py はモデルをフリーズするための API を提供していますが、出力の最終ノード名(通常は softmax などの最終層の活性化関数の名前)を提供する必要があります。物体検出 API は事前学習済みのネットワークを提供しており、最終ノード名が簡単には見つからないため、object_detection ディレクトリに export_inference_graph.py も提供しています。

  1. python export_inference_graph.py \
  2. --input_type 画像テンソル 
  3. --pipeline_config_path D:/training-sets /data-translate/training/ssd_mobilenet_v1_pets.config \  
  4. --trained_checkpoint_prefix D:/training-sets /data-translate/training/ssd_mobilenet_v1_pets.config /model.ckpt-* \  
  5. --出力ディレクトリ D:/トレーニングセット/データ変換/トレーニング/結果 

エクスポートが完了すると、output_directory の下に、frozen_inference_graph.pb、model.ckpt.data-00000-of-00001、model.ckpt.meta、および model.ckpt.data ファイルが生成されます。

生成されたモデルを呼び出します。

ディレクトリ自体に呼び出し例がありますが、次のように少し変更されています。

  1. cv2をインポート
  2. numpyをnpとしてインポートする
  3. テンソルフローをtfとしてインポートする
  4. object_detection.utilsからlabel_map_util をインポートします
  5. object_detection.utilsからvisualization_utils をvis_utilとしてインポートします
  6.  
  7.  
  8. クラスTOD(オブジェクト):
  9. __init__(self)を定義します。
  10. self.PATH_TO_CKPT = r 'D:\lib\tf-model\models-master\object_detection\training\frozen_inference_graph.pb'  
  11. self.PATH_TO_LABELS = r 'D:\lib\tf-model\models-master\object_detection\training\sunglasses_label_map.pbtxt'  
  12. 自己.NUM_CLASSES = 1
  13. self.detection_graph = self._load_model()
  14. self.category_index = self._load_label_map()
  15.  
  16. def _load_model(自己):
  17. 検出グラフ = tf.Graph()
  18. detection_graph.as_default()を使用する場合:
  19. od_graph_def = tf.GraphDef()
  20. tf.gfile.GFile(self.PATH_TO_CKPT, 'rb' )fidとして使用します:
  21. シリアル化されたグラフ = fid.read ()
  22. od_graph_def.ParseFromString(シリアル化されたグラフ)
  23. tf.import_graph_def(od_graph_def、名前= '' )
  24. 検出グラフを返す
  25.  
  26. _load_label_map(self)を定義します。
  27. label_map = label_map_util.load_labelmap(self.PATH_TO_LABELS)
  28. カテゴリ = label_map_util.convert_label_map_to_categories(label_map,
  29. max_num_classes=self.NUM_CLASSES、
  30. use_display_name = True )
  31. category_index = label_map_util.create_category_index(カテゴリ)
  32. カテゴリインデックスを返す
  33.  
  34. defdetect(自己、画像):
  35. self.detection_graph.as_default()を使用する場合:
  36. tf.Session(graph=self.detection_graph)sessとして:
  37. # モデルは画像の形状[1, None, None, 3] であると想定しているため、次元を拡張します。
  38. image_np_expanded = np.expand_dims(画像、軸=0)
  39. image_tensor = self.detection_graph.get_tensor_by_name( 'image_tensor:0' )
  40. ボックス = self.detection_graph.get_tensor_by_name( 'detection_boxes:0' )
  41. スコア = self.detection_graph.get_tensor_by_name( 'detection_scores:0' )
  42. クラス = self.detection_graph.get_tensor_by_name( 'detection_classes:0' )
  43. num_detections = self.detection_graph.get_tensor_by_name( 'num_detections:0' )
  44. # 実際の検出。
  45. (ボックス、スコア、クラス、検出数) = sess.run(
  46. [ボックス、スコア、クラス、検出数]、
  47. feed_dict={image_tensor: image_np_expanded})
  48. #検出結果視覚化
  49. vis_util.visualize_boxes_and_labels_on_image_array(
  50. 画像、
  51. np.squeeze(ボックス),
  52. np.squeeze(クラス).astype(np.int32)、
  53. np.squeeze(スコア),
  54. 自己.カテゴリインデックス、
  55. use_normalized_coordinates = True
  56. 線の太さ=8)
  57.  
  58. cv2.namedWindow( "検出" , cv2.WINDOW_NORMAL)
  59. cv2.imshow( "検出" , 画像)
  60. cv2.waitKey(0)
  61.  
  62. __name__ == '__main__'の場合:
  63. 画像 = cv2.imread( 'image.jpg' )
  64. 検出 = TOD()
  65. detecotr.detect(画像)

認識効果の写真をいくつか紹介します。

終わり。

<<:  研究によると、AIはより多くの雇用を生み出している

>>:  中国チームがボストン・ダイナミクスに対抗する四足歩行ロボットを発表

ブログ    

推薦する

ピュー研究所の報告:2025年までにAIのせいで7500万人が解雇される

[[253650]]テクノロジー専門家の約 37% は、人工知能 (AI) と関連技術の進歩により、...

MIT は Google と提携して 7 台のマルチタスク ロボットをトレーニングし、9,600 のタスクで 89% の成功率を達成しました。

タスクの数が増えるにつれて、現在の計算方法を使用して汎用の日常的なロボットを構築するコストは法外なも...

世界中で人気のGPT-3がなぜ人々の仕事を破壊しているのか?

この記事はAI新メディアQuantum Bit(公開アカウントID:QbitAI)より許可を得て転載...

...

Transformer ニューラル ネットワーク モデルを 1 つの記事で理解する

こんにちは、皆さん。私は Luga です。今日は、人工知能 (AI) エコシステムに関連するテクノロ...

1 つの記事でニューラル ネットワークを理解する

[51CTO.com からのオリジナル記事]人工知能は近年非常に人気の高い技術です。99 歳から歩け...

...

ChatGPTが使用する機械学習技術

著者 |ブライト・リャオ「プログラマーから見たChatGPT」の記事では、開発者のChatGPTに対...

...

...

百度がナレッジグラフをひっそりとリリース、次世代検索エンジンのプロトタイプを公開

一部のネットユーザーが「Crazy Guess the Idiom」ゲームを解読する最新の戦略を明ら...

Aurora の 1 億ドルの買収の背後にあるもの: RISC-V の創始者が「中国製チップ」を開発するという野望

2月27日、米国の著名な自動運転企業であるAuroraは、ライダーチップ企業OURSを1億ドルで買収...

美団のドローンの暴露:インターネットはインターネットに別れを告げる

美団ドローンは、ドローンそのもの以上のものを見せてくれるだけでなく、インターネットがインターネットに...

JDロジスティクスは知能を高めつつ、宅配業者から仕事を奪っている

JD.comは早くも2017年8月に、陝西省の地域をカバーする中国初のドローン空域の承認を取得しまし...

...