TFとPyTorchだけを知っているだけでは不十分です。PyTorchから自動微分ツールJAXに切り替える方法を見てみましょう。

TFとPyTorchだけを知っているだけでは不十分です。PyTorchから自動微分ツールJAXに切り替える方法を見てみましょう。

現在のディープラーニング フレームワークに関しては、TensorFlow と PyTorch を避けることができないことがよくあります。しかし、これら 2 つのフレームワークに加えて、いくつかの新しい力も過小評価すべきではありません。その 1 つが JAX です。前方および後方の自動微分機能を備えており、高次導関数の計算に非常に優れています。この注目のフレームワークはどれほど便利なのでしょうか? ニューラル ネットワーク内の複雑な勾配更新とバックプロパゲーションを実証するには、どのように使用できるでしょうか? この記事は、Jax の基礎となるロジックを理解し、PyTorch や他のフレームワークからの移行を容易にするためのチュートリアル記事です。

[[326161]]

Jax は、機械学習と数学計算用に Google が開発した Python ライブラリです。起動すると、Jax は Python + NumPy パッケージとして定義されました。微分化、ベクトル化、TPU および GPU での JIT 言語の使用などの機能があります。つまり、これは自動微分も可能な numpy の GPU バージョンです。 Skye Wanderman-Milne 氏のような研究者も、昨年の NeurlPS 2019 カンファレンスで Jax を紹介しました。

しかし、開発者にとって、すでに使い慣れている PyTorch や TensorFlow 2.X から Jax に切り替えることは間違いなく大きな変化です。この 2 つは計算とバックプロパゲーションの構築方法が根本的に異なるからです。 PyTorch は計算グラフを構築し、順方向パスと逆方向パスを計算します。結果ノードの勾配は、中間ノードの勾配から累積されます。

Jax は違います。計算を Python 関数として表現し、grad() を使用して評価可能な勾配関数に変換します。しかし、結果ではなく、結果の勾配が示されます。両者の比較を以下に示します。

これにより、プログラミングとモデルの構築方法が変わります。したがって、テープベースの自動差別化方法を使用し、ステートフル オブジェクトを使用できます。しかし、Jax は grad() 関数を実行するときに微分処理を関数のように動作させるので、驚くかもしれません。

おそらく、flax、trax、haiku などの Jax ベースのツールを検討してみることにしたでしょう。 ResNet のような例を見ると、他のフレームワークのコードとは異なることがわかります。レイヤーを定義してトレーニングを実行する以外に、基礎となるロジックは何ですか? これらの小さな NumPy プログラムはどのようにして巨大なアーキテクチャをトレーニングするのでしょうか?

この記事は、Jax を使用してモデルを構築する方法に関するチュートリアルです。Machine Heart では、次の 2 つの部分を取り上げました。

  • PyTorch での LSTM-LM アプリケーションの簡単なレビュー。
  • PyTorch スタイルのコード (mutate 状態に基づく) を見て、純粋な関数がモデルを構築する方法 (Jax) を学びます。

PyTorch 上の LSTM 言語モデル

まず、PyTorch を使用して LSTM 言語モデルを実装します。コードは次のとおりです。

  1. 輸入トーチ
  2. クラス LSTMCell(torch.nn.Module):
  3. def __init__(self, in_dim, out_dim):
  4. super(LSTMCell、self).__init__()
  5. self.weight_ih = torch.nn.Parameter (torch.rand(4*out_dim, in_dim))
  6. self.weight_hh = torch.nn.Parameter (torch.rand(4*out_dim, out_dim))
  7. 自己バイアス= torch.nn.パラメータ(torch.zeros(4*out_dim,))
  8.  
  9. def forward(自己、入力、h、c):
  10. ifgo = self .weight_ih@inputs + self .weight_hh@h + self .bias
  11. i、f、g、 o = torch.chunk (ifgo、4)
  12. i =トーチ.シグモイド(i)
  13. f =トーチシグモイド(f)
  14. g =トーチ.tanh (g)
  15. o =トーチ.シグモイド(o)
  16. new_c = f * c + i * g
  17. new_h = o * torch.tanh(new_c)
  18. 戻り値 (new_h, new_c)

次に、この LSTM ニューロンに基づいて単層ネットワークを構築します。ここには埋め込みレイヤーがあり、学習可能な (h,c)0 と組み合わせることで、個々のパラメータがどのように変化するかを示します。

  1. クラス LSTMLM(torch.nn.Module):
  2. def __init__(self, vocab_size, dim = 17 ):
  3. スーパー().__init__()
  4. セルをLSTMCell囲みます。
  5. 自己埋め込み= torch.nn.パラメータ(torch.rand(vocab_size, dim))
  6. self.c_0 = torch.nn.パラメータ(torch.zeros(dim))
  7.  
  8. @財産
  9. hc_0(自分)を定義します。
  10. 戻り値 (torch.tanh(self.c_0), self.c_0)
  11.  
  12. def forward(self, seq, hc):
  13. 損失= torch.tensor (0.)
  14. seq内のidxの場合:
  15. 損失- = torch.log_softmax (self.embeddings@hc[0], dim =-1)[idx]
  16. hc =自己.セル(自己.埋め込み[idx,:], *hc)
  17. リターンロス、HC
  18.  
  19. greedy_argmax(self, hc,長さ= 6 )を定義します:
  20. torch.no_grad() の場合:
  21. idxs = []
  22. i が範囲(長さ)内にある場合:
  23. idx = torch.argmax (self.embeddings@hc[0])
  24. idxs.append(idx.item())
  25. hc =自己.セル(自己.埋め込み[idx,:], *hc)
  26. idxを返す

構築後、トレーニング:

  1. トーチ.マニュアル_シード(0)
  2. # トレーニングデータとして、単語/単語部分/文字のインデックスを用意します。
  3. # これらはトークン化され、整数化されていると仮定します (もちろん、これはおもちゃの例です)。
  4. jax.numpyをjnpとしてインポートする
  5. vocab_size = 43 # プライムトリック! :)
  6. トレーニングデータ= jnp.array ([4, 8, 15, 16, 23, 42])
  7.  
  8. lm = LSTMLM (語彙サイズ語彙サイズ= 語彙サイズ)
  9. print("前のサンプル:", lm.greedy_argmax(lm.hc_0))
  10.  
  11. bptt_length = 3 # hc.detach-ingを説明するため
  12.  
  13. 範囲(101)のエポックの場合:
  14. hc = lm.hc_0
  15. 総損失= 0
  16. 開始範囲(0、len(training_data)、bptt_length)の場合:
  17. バッチ=トレーニングデータ[開始:開始+bptt_length]
  18. 損失、(h, c) = lm(バッチ、hc)
  19. hc = (h.detach()、c.detach()) です。
  20. エポック% 50 == 0の場合:
  21. 総損失 += 損失.item()
  22. 損失.後方()
  23. lm.named_pa​​rameters() の name、param の場合:
  24. param.gradがNoneでない場合:
  25. パラメータデータ- = 0.1 * パラメータ勾配
  26. del param.grad
  27. 全損の場合:
  28. print("損失:", totalloss)
  29.           
  30. print("後のサンプル:", lm.greedy_argmax(lm.hc_0))
  31. 以前のサンプル: [42, 34, 34, 34, 34, 34]
  32. 損失: 25.953862190246582
  33. 損失: 3.7642268538475037
  34. 損失: 1.9537211656570435
  35. 後のサンプル: [4, 8, 15, 16, 23, 42]

ご覧のとおり、PyTorch コードは比較的明確ですが、まだいくつか問題があります。非常に注意していますが、計算グラフ内のノードの数に注意を払うことは依然として重要です。これらの中間ノードは適切なタイミングでクリアする必要があります。

純粋関数

JAX がこれをどのように処理するかを理解するには、まず純粋関数の概念を理解する必要があります。これまでに関数型プログラミングを行ったことがある場合、純粋関数は数学の関数や数式のようなものだという概念に馴染みがあるかもしれません。特定の入力値から出力値を取得する方法を定義します。重要なのは、この関数には「副作用」がないこと、つまり関数のどの部分もグローバル状態にアクセスしたり変更したりしないことです。

Pytorch で記述するコードには中間変数や状態が多数含まれており、これらの状態は頻繁に変化するため、推論と最適化が非常に難しくなります。したがって、JAX はプログラマーを純粋関数のスコープ内に制限し、上記の状況が発生しないようにすることを選択します。

JAX について詳しく説明する前に、純粋関数の例をいくつか見てみましょう。純粋関数は次の条件を満たす必要があります。

  • 関数を実行する状況と実行時期は出力に影響を与えません。入力が変わらない限り、出力も変わりません。
  • 関数を 0 回実行したか、1 回実行したか、あるいはそれ以上実行したかは、後で区別がつかなくなるはずです。

次の不純な関数は、上記の条件の少なくとも 1 つに違反します。

  1. ランダムにインポート
  2. インポート時間
  3. 実行回数= 0  
  4.  
  5. pure_fn_1(x)を定義します:
  6. 2 * xを返す
  7.      
  8. pure_fn_2(xs)を定義します:
  9. ys = []
  10. x が xs 内にある場合:
  11. # 関数 *内部* で状態のある変数を変更しても問題ありません。
  12. ys.append(2 * x)
  13. ysを返す
  14.  
  15. 不純なfn_1(xs)を定義します:
  16. # 引数を変更すると、関数の外部に永続的な影響が生じます! :(
  17. xs.append(合計(xs))
  18. xsを返す
  19.  
  20. 不純なfn_2(x)を定義します:
  21. # 明らかに変異している
  22. 地球の状態は悪いです...地球
  23. 実行回数 nr_executions += 1
  24. 2 * xを返す
  25.  
  26. 不純なfn_3(x)を定義します:
  27. # ...しかし、アクセスするだけでも、関数は
  28. # 実行コンテキスト!
  29. nr_executions * x を返す
  30.  
  31. 不純なfn_4(x)を定義します:
  32. # IO のようなものは不純度の典型的な例です。
  33. # 次の 3 行はすべて純粋性の違反です。
  34. print("こんにちは!")
  35. ユーザー入力= 入力()
  36. 実行時間= time.time()
  37. 2 * xを返す
  38.  
  39. 不純なfn_5(x)を定義します:
  40. # これはどちらの制約に違反しますか? 実は両方です! 現在の
  41. # ランダム性の状態 *そして* 数値ジェネレーターを進化させます!
  42. p =ランダム.ランダム()
  43. p * x を返す
  44. JAX が操作する純粋な関数を見てみましょう。導入図の例です。
  45.  
  46. # (ほぼ)1次元線形回帰
  47. f(w, x)を定義します。
  48. w * xを返す
  49.  
  50. 印刷(f(13., 42.))
  51. 546.0

今のところ何も起こっていません。 JAX を使用すると、次の関数を、結果を返す代わりに、関数の最初の引数に対する関数の結果の勾配を返す別の関数に変換できるようになりました。

  1. jaxをインポートする
  2. jax.numpyをjnpとしてインポートする
  3.  
  4. # 勾配: 重みに関して! JAX はデフォルトで最初の引数を使用します。
  5. df_dw = jax.grad (f)
  6.  
  7. def manual_df_dw(w, x):
  8. xを返す
  9.      
  10. df_dw(13., 42.) == manual_df_dw(13., 42.) をアサートする
  11.  
  12. 印刷(df_dw(13., 42.))
  13. 42.0

ここまでで、JAX README ドキュメントの内容をすべてご覧になったと思いますが、その内容は妥当なものです。しかし、PyTorch コードのように大きなモジュールにジャンプするにはどうすればよいでしょうか?

まず、バイアス項を追加し、1 次元の線形回帰変数を、使い慣れたオブジェクト (LinearRegressor「レイヤー」) にラップしてみます。

  1. クラス LinearRegressor():
  2. __init__(self, w, b)を定義します。
  3. 自己.w = w
  4. 自己.b = b
  5.      
  6. def predict(自己, x):
  7. self.w * x + self.b を返す
  8.          
  9. 定義rms(self, xs: jnp.ndarray, ys: jnp.ndarray):
  10. jnp.sqrt(jnp.sum(jnp.square(self.w * xs + self.b - ys))) を返します。
  11.          
  12. my_regressor =線形回帰(13., 0.)
  13.  
  14. # トレーニングに使用される損失関数の一種
  15. xs = jnp .array([42.0])
  16. ys = jnp.array ([500.0])
  17. 印刷(my_regressor.rms(xs, ys))
  18.  
  19. # テストデータの予測
  20. 印刷(my_regressor.predict(42.))
  21. 46.0
  22. 546.0

トレーニングに勾配をどのように使用するのでしょうか? モデルの重みを入力パラメータとして受け取る純粋な関数が必要です。次のようなものになります。

  1. loss_fn(w, b, xs, ys)を定義します。
  2. my_regressor =線形回帰(w, b)
  3. my_regressor.rms( xs xs =xs, ys ys =ys)を返します
  4.      
  5. # argnums = (0, 1) を使ってJAXに渡すよう指示します
  6. # 最初のパラメータと 2 番目のパラメータに関するグラデーション。
  7. grad_fn = jax.grad (loss_fn、引数=(0, 1))
  8.  
  9. 印刷(loss_fn(13., 0., xs, ys))
  10. print(grad_fn(13., 0., xs, ys))
  11. 46.0
  12. (デバイス配列(42., dtype = float32 ), デバイス配列(1., dtype = float32 ))

これが正しいのだと自分自身を納得させなければなりません。さて、これは機能しますが、明らかに、 loss_fn の定義部分ですべてのパラメーターを列挙することは実現可能ではありません。

幸いなことに、JAX はスカラー、ベクトル、行列だけでなく、多くのツリーのようなデータ構造も区別できます。この構造は pytree と呼ばれ、Python 辞書で構成されています。

  1. loss_fn(パラメータ、xs、ys)を定義します。
  2. my_regressor =線形回帰(params['w'], params['b'])
  3. my_regressor.rms( xs xs =xs, ys ys =ys)を返します
  4.  
  5. grad_fn = jax.grad (損失_fn)
  6.  
  7. 印刷(loss_fn({'w': 13., 'b': 0.}, xs, ys))
  8. print(grad_fn({'w': 13., 'b': 0.}, xs, ys))
  9. 46.0
  10. {'b': DeviceArray(1., dtype = float32 ), 'w': DeviceArray(42., dtype = float32 )} これで見た目も良くなりました! トレーニング ループは次のように記述できます。

これで見た目はずっと良くなりました! トレーニング ループは次のように記述できます。

  1. パラメータ= {'w': 13., 'b': 0.}
  2.  
  3. _ が範囲内(15)の場合:
  4. 印刷(loss_fn(params, xs, ys))
  5. grads = grad_fn (パラメータ、xs、ys)
  6. params.keys() 内の名前の場合:
  7. params[名前] - = 0.002 * grads[名前]
  8.          
  9. # さて、予測してみましょう:
  10. 線形回帰(params['w'], params['b']).predict(42.)
  11. 46.0
  12. 42.47003
  13. 38.940002
  14. 35.410034
  15. 31.880066
  16. 28.350098
  17. 24.820068
  18. 21.2901
  19. 17.760132
  20. 14.230164
  21. 10.700165
  22. 7.170166
  23. 3.6401978
  24. 0.110198975
  25. 3.4197998
  26. デバイス配列(500.1102、 dtype = float32 )

より多くの JAX ヘルパーを使用して自分自身を更新できるようになったことに注意してください。パラメーターとグラデーションには共通の (ツリーのような) 構造があるため、これらを一番上に配置し、次のように、どこでも値が 2 つのツリーの「組み合わせ」である新しいツリーを作成することが考えられます。

  1. update_combiner を定義します(param, grad, lr = 0.002 ):
  2. 戻り値パラメータ - lr * grad
  3.      
  4. パラメータ= jax.tree_multimap (update_combiner、パラメータ、グラデーション)
  5. # の代わりに:
  6. # params.keys() 内の名前:
  7. # params[名前] - = 0.1 * grads[名前]

参考リンク: https://sjmielke.com/jax-purify.htm

[この記事は51CTOコラム「Machine Heart」、WeChatパブリックアカウント「Machine Heart(id:almosthuman2014)」によるオリジナル翻訳です]

この著者の他の記事を読むにはここをクリックしてください

<<:  たった2時間で7元以下で3Dロボットが作れます

>>:  機械学習におけるモデルドリフト

ブログ    
ブログ    
ブログ    
ブログ    
ブログ    

推薦する

...

...

物理学者が67年前に予測した「悪魔」がネイチャー誌に登場:「偽の」高温超伝導体で偶然発見

この記事はAI新メディアQuantum Bit(公開アカウントID:QbitAI)より許可を得て転載...

AIを使って人の心を理解する?感情科学の専門家:表情から感情を識別するのは信頼できない

AIは人間の感情を認識できるでしょうか?原理的には、AIは音声認識、視覚認識、テキスト認識、表情認識...

Pythonを使用して機械学習モデルを作成する方法

導入新しいモデルをトレーニングしたときに、Flask コード (Python Web フレームワーク...

Facebook が人工知能を活用する 6 つの方法 (予想外のものもいくつかある)

[51CTO.com クイック翻訳] Facebook は人工知能を使用してポルノを識別し、マーク...

人工知能倫理ガバナンスは早急に実践段階へ移行する必要がある

今日の社会では、デジタル工業化と産業のデジタル化により、デジタル世界と物理世界の深い融合と発展が促進...

マスク氏:プログラマーの62%が人工知能が武器化されると考えている

常に人工知能の脅威論を支持してきたシリコンバレーの「鉄人」マスク氏は、今回、プログラマーたちの間で支...

...

rsyncのコアアルゴリズム

Rsync は、Unix/Linux でファイルを同期するための効率的なアルゴリズムです。2 台のコ...

トレンドにおける危険とチャンス: 生成 AI の黄金期をどう捉えるか?

ChatGPTは今年9月末に音声チャットと画像認識機能を追加しました。テキスト駆動型と比較して、C...

...

マスク氏「高度なAIの開発は非常にリスクが高い。OpenAIはアルトマン氏を解雇した理由を明らかにすべき」

11月20日、テスラのCEOイーロン・マスク氏は、高度な人工知能(AI)技術の開発には大きな潜在的...

2022年には大学卒業者数が1000万人を超えるが、AI関連の仕事の月給はたったの2万4000円?

2022年、伝説の「黄金の3月と銀の4月」がやって来ます... 「青銅三・鉄四」に変身しました… ...