TensorFlow でトレーニングしたモデルを保存および復元する方法

TensorFlow でトレーニングしたモデルを保存および復元する方法

ディープ ニューラル ネットワーク モデルの複雑さが非常に高い場合、保有するデータの量、モデルを実行しているハードウェアなどに応じて、トレーニングにかなりの時間がかかることがあります。ほとんどの場合、中断した(またはミスをした)場合でもミスなく中断したところから続行できるように、ファイルを保存して実験の安定性を確保する必要があります。

さらに重要なのは、TensorFlow のようなディープラーニング フレームワークでは、トレーニングが成功した後、モデルの学習したパラメータを再利用して新しいデータに対する予測を行う必要があることです。

[[208282]]

この記事では、TensorFlow モデルを保存および復元する方法について説明します。最も便利な方法を紹介し、いくつかの例を示します。

1. まずTensorFlowモデルを簡単に紹介します

TensorFlow の主な機能は、NumPy の多次元配列に似たテンソルを通じて基本的なデータ構造を伝達し、グラフがデータ計算を表すことです。これはシンボリック ライブラリであり、グラフとテンソルを定義するとモデルのみが作成され、具体的な値の取得とテンソルの操作はセッションで実行され、グラフ内でモデリング操作を実行するためのメカニズムです。セッションが閉じられるとテンソルの具体的な値はすべて失われます。これが、セッションの実行後にモデルをファイルに保存するもう 1 つの理由です。

例を見たほうが理解しやすいので、2 次元データの線形回帰用の簡単な TensorFlow モデルを作成しましょう。

まず、ライブラリをインポートします。

  1. テンソルフローをtfとしてインポートする
  2. numpyをnpとしてインポートする
  3. matplotlib.pyplot を plt としてインポートします。
  4. %matplotlib インライン

次のステップはモデルを作成することです。次の形式で、二次関数の水平方向と垂直方向の変位を推定するモデルを生成します。

  1. y = (x - h) ^ 2 + v

ここで、h は水平方向の変化、v は垂直方向の変化です。

モデルの生成方法は次のとおりです (詳細についてはコード内のコメントを参照してください)。

  1. # 変数の重複を避けるために、実行ごとに現在のグラフをクリアします
  2. tf.reset_default_graph()
  3. # x と y の点のプレースホルダーを作成する
  4. X = tf .placeholder("float")
  5. Y = tf .placeholder("float")
  6. # 学習する必要がある2つのパラメータを初期化します
  7. h_est = tf .Variable(0.0,名前= 'hor_estimate' )
  8. v_est = tf .Variable(0.0,名前= 'ver_estimate' )
  9. # y_estはy軸上の推定値を保持します
  10. y_est = tf .square(X - h_est) + v_est
  11. # コスト関数をYとy_estの間の距離の2乗として定義します
  12. コスト= (tf.pow(Y - y_est, 2))
  13. # コスト関数を最小化するためのトレーニング操作。
  14. # 学習率は0.001です
  15. trainop = tf.train.GradientDescentOptimizer (0.001).minimize(コスト)

モデルを作成するプロセスでは、セッションでモデルを実行し、実際のデータを渡す必要があります。いくつかの二次データを生成し、それにノイズを追加します。

  1. # 水平および垂直シフトにいくつかの値を使用する
  2. h = 1    
  3. v = -2
  4. # ノイズを含むトレーニングデータを生成する
  5. x_train = np.linspace (-2,4,201)
  6. ノイズ= np .random.randn(*x_train.shape) * 0.4
  7. y_train = (x_train - h) ** 2 + v + ノイズ
  8. # データを視覚化する
  9. plt.rcParams['figure.figsize'] = (10, 6)
  10. plt.scatter(x_train, y_train)
  11. plt.xlabel('x_train')
  12. plt.ylabel('y_train')

2. セーバークラス

Saver クラスは TensorFlow ライブラリによって提供されるクラスです。グラフ構造と変数を保存するのに最適な方法です。

(1)モデルを保存する

次のコード行では、Saver オブジェクトを定義し、train_graph() 関数で 100 回の反復にわたってコスト関数を最小化します。次に、各反復と最適化が完了したら、モデルをディスクに保存します。ディスク上に作成される各保存は、「チェックポイント」と呼ばれるバイナリ ファイルと呼ばれます。

  1. # Saverオブジェクトを作成する
  2. セーバー= tf.train.Saver ()
  3.  
  4. init = tf .global_variables_initializer()
  5.  
  6. # セッションを実行します。コストを最小限に抑えるために100回の反復を実行します。
  7. train_graph()を定義します:
  8. tf.Session() を sess として使用:
  9. セッションの実行(初期化)
  10. iが範囲(100)内にある場合:
  11. zip(x_train, y_train)内の(x, y)の場合:
  12.  
  13. # 実際のデータを列車運行に供給する
  14. sess.run(trainop、 feed_dict ={X: x、Y: y})を実行します。
  15.  
  16. # 繰り返しごとにチェックポイントを作成する
  17. saver.save(sess, 'model_iter', global_step = i ) は、
  18.  
  19. # 最終モデルを保存する
  20. セーバー.save(sess, 'model_final')
  21. h_ =セッション実行(h_est)
  22. v_ =セッション実行(v_est)
  23. h_、v_を返す

それでは、上記の関数を使用してモデルをトレーニングし、トレーニングされたパラメータを出力してみましょう。

  1. 結果= train_graph ()
  2. print(" h_est = %.2f, v_est = %.2f" % 結果)
  3.  
  4. $ python tf_save.py
  5. h_est = 1.01 v_est = -1.96

はい、パラメータは非常に正確です。ファイル システムを確認すると、最終モデルだけでなく、最後の 4 回の反復から保存されたファイルがあります。

モデルを保存するときに、保存に必要なファイルが 4 種類あることに気付くでしょう。

  • 「.meta」ファイル: グラフ構造が含まれます。
  • 「.data」ファイル: 変数の値が含まれます。
  • 「.index」ファイル: チェックポイントを識別します。
  • 「チェックポイント」ファイル: 最近のチェックポイントのリストを含むプロトコル バッファー。

図1: ディスクに保存されたチェックポイントファイル

すべての変数をファイルに保存するには、上記のように tf.train.Saver() メソッドを呼び出します。変数のサブセットをリストまたは辞書として渡して保存します。例: tf.train.Saver({'hor_estimate': h_est})。

プロセス全体を制御できる Saver コンストラクターのその他の便利なパラメーターは次のとおりです。

  • max_to_keep: 保持するチェックポイントの最大数。
  • keep_checkpoint_every_n_hours: チェックポイントを保存する時間間隔。さらに詳しく知りたい場合は、Saver クラスの公式ドキュメントを参照してください。このドキュメントには、他にも役立つ情報が記載されています。
  • モデルの復元

TensorFlow モデルを復元するときに最初に行うことは、「.meta」ファイルから現在のグラフにグラフ構造を読み込むことです。

  1. tf.reset_default_graph()
  2. インポートされたメタ= tf .train.import_meta_graph("model_final.meta")

tf.get_default_graph() を使用して現在のグラフを探索することもできます。 2 番目のステップは、変数の値をロードすることです。注意: 値はセッション内にのみ存在します。

  1. tf.Session() を sess として使用:
  2. インポートされたメタデータを復元します(sess、tf.train.latest_checkpoint('./'))
  3. h_est2 = sess .run('hor_estimate:0')
  4. v_est2 = sess .run('ver_estimate:0')
  5. print("h_est: %.2f, v_est: %.2f" % (h_est2, v_est2))
  1. $ python tf_restore.py
  2. INFO:tensorflow:./model_final からパラメータを復元しています
  3. h_est: 1.01、v_est: -1.96

前述したように、このアプローチではグラフ構造と変数のみが保存されるため、プレースホルダー「X」と「Y」を通じて入力されたトレーニング データは保存されません。

とにかく、この例では、定義したトレーニング データ tf を使用して、モデルの適合を視覚化します。

  1. plt.scatter(x_train, y_train, label = 'トレーニングデータ' )
  2. plt.plot(x_train, (x_train - h_est2) ** 2 + v_est2,= '赤' ラベル= 'モデル' )
  3. plt.xlabel('x_train')
  4. plt.ylabel('y_train')
  5. plt.凡例()

Saver クラスを使用すると、TensorFlow モデル (グラフと変数) をファイルに簡単に保存および復元したり、作業の複数のチェックポイントを保存したりできるため、トレーニング中にモデルを微調整するのに役立ちます。

4. SavedModel 形式

TensorFlow でモデルを保存および復元する新しい方法は、SavedModel、Builder、および loader 関数を使用することです。このメソッドは、実際には Saver によって提供される高レベルのシリアル化であり、ビジネス目的に適しています。

この SavedModel アプローチは開発者に完全に受け入れられているようには見えませんが、作成者は「これは明らかに未来だ」と述べています。主に変数に焦点を当てた Saver クラスと比較して、SavedModel は、Signatures (入力と出力のセットを含むグラフを保存できるようにする) や Assets (初期化に使用される外部ファイルを含む) など、いくつかの便利な機能を 1 つのパッケージに含めるようにしています。

(1) SavedModel Builderを使用してモデルを保存する

次に、SavedModelBuilder クラスを使用してモデルを保存してみます。この例では、シンボルは使用していませんが、プロセスを説明するには十分です。

  1. tf.reset_default_graph()
  2. # 2つの変数を再初期化する
  3. h_est = tf .Variable(h_est2,名前= 'hor_estimate2' )
  4. v_est = tf .Variable(v_est2、名前= 'ver_estimate2' )
  5.  
  6. # ビルダーを作成する
  7. ビルダー= tf .saved_model.builder.SavedModelBuilder('./SavedModel/')
  8.  
  9. # グラフと変数をビルダーに追加して保存する
  10. tf.Session() を sess として使用:
  11. sess.run(h_est.initializer)
  12. sess.run(v_est.initializer)
  13. ビルダー.add_meta_graph_and_variables(sess,
  14. [tf.saved_model.tag_constants.TRAINING]、
  15. signature_def_map =なし
  16. 資産コレクション=なし)
  17. ビルダー.save()
  1. $ python tf_saved_model_builder.py
  2. INFO:tensorflow:保存するアセットがありません。
  3. INFO:tensorflow:書き込むアセットがありません。
  4. INFO:tensorflow:SavedModel が次の場所に書き込まれました: b'./SavedModel/saved_model.pb'

このコードを実行すると、モデルが「./SavedModel/saved_model.pb」にあるファイルに保存されていることがわかります。

(2)SavedModel Loaderプログラムを使用してモデルを復元する

モデルの復元では tf.saved_model.loader が使用され、セッション スコープに保存された変数とシンボルを復元できます。

次の例では、モデルをロードし、2 つの係数 (h_est と v_est) の値を出力します。値は予想どおりで、モデルは正常に回復されました。

  1. tf.Session() を sess として使用:
  2. tf.saved_model.loader.load(sess、[tf.saved_model.tag_constants.TRAINING], './SavedModel/') をロードします。
  3. h_est = sess .run('hor_estimate2:0')
  4. v_est = sess .run('ver_estimate2:0')
  5. print("h_est: %.2f, v_est: %.2f" % (h_est, v_est))
  1. $ python tf_saved_model_loader.py
  2. INFO:tensorflow:b'./SavedModel/variables/variables' からパラメータを復元しています
  3. h_est: 1.01、v_est: -1.96

5. 結論

ディープラーニング ネットワークのトレーニングに長い時間がかかる可能性がある場合は、TensorFlow モデルの保存と復元が非常に役立ちます。このトピックは範囲が広すぎるため、1 つのブログ投稿で詳細を説明することはできません。とにかく、この投稿では、Saver と SavedModel ビルダー/ローダーという 2 つのツールを紹介し、ファイル構造を作成し、単純な線形回帰を使用して例を説明しました。これらが、より優れたニューラル ネットワーク モデルのトレーニングに役立つことを願っています。

<<:  リソースインベントリ: 便利な自動データサイエンスおよび機械学習ソフトウェア

>>:  第一回美団クラウド人工知能サミットが開幕、エコパートナーと協力して最もオープンなAIプラットフォームを構築

ブログ    

推薦する

Quora は機械学習をどのように活用していますか?

[[202181]] 2015年、同社のエンジニアリング担当副社長であるXavier Amatri...

ニュースローン賞受賞者 宋 樹蘭: 視覚の観点からロボットの「目」を構築する

この記事はLeiphone.comから転載したものです。転載する場合は、Leiphone.com公式...

今日は秋分の日で収穫の季節。ドローンがショーの中心です。

9月22日は秋分の日であり、私の国では3回目の「農民の収穫祭」でもあります。収穫の季節と重なる黄金...

推奨に値する 7 つの優れたオープンソース AI ライブラリ

[[406029]] [51CTO.com クイック翻訳]人工知能 (AI) 研究の分野では、Ten...

...

ケーススタディ | 埋め込みに基づく特徴セキュアな計算

[[331789]]序文従来のデータの公開と共有の方法の多くは、生のデータをプレーンテキストで直接出...

コミック版:ディープラーニングって何?

Google はどのようにしてわずか数秒で Web ページ全体をさまざまな言語に翻訳するのか、ある...

...

Chain World: シンプルで効果的な人間行動エージェントモデル強化学習フレームワーク

強化学習は、エージェントが環境と対話し、蓄積された報酬を最大化するために最適なアクションを選択する方...

ローコード プラットフォームに関する不完全な推奨事項!

ソフトウェア開発者向けのローコード機能それでは、ソフトウェア開発者に機械学習機能を提供するローコード...

...

機械学習愛好家必読ガイド

[[273182]]このガイドは、機械学習 (ML) に興味があるが、どこから始めればよいかわからな...

GPT-4 パラメータは 10 兆に達します!この表は、新しい言語モデルのパラメータが GPT-3 の 57 倍になると予測しています。

機械学習の場合、パラメータはアルゴリズムの鍵となります。パラメータは、履歴入力データであり、モデルト...

...