Githubには13,000個のスターがある。JAXの急速な発展はTensorFlowやPyTorchに匹敵する

Githubには13,000個のスターがある。JAXの急速な発展はTensorFlowやPyTorchに匹敵する

  [[416349]]

機械学習の分野では、TensorFlow と PyTorch は誰もがよく知っているかもしれませんが、これら 2 つのフレームワークに加えて、Google が立ち上げた JAX という新たな勢力も見逃せません。多くの研究者は、TensorFlow などの多くの機械学習フレームワークを置き換えることができると期待し、大きな期待を寄せています。

JAX はもともと、Google Brain チームの Matt Johnson、Roy Frostig、Dougal Maclaurin、Chris Leary によって開始されました。

現在、JAX は GitHub で 13.7K 個のスターを獲得しています。

プロジェクトアドレス: https://github.com/google/jax

JAXの急速な発展

JAX の前身は Autograd です。Autograd の更新版の助けを借りて、XLA と組み合わせることで、Python プログラムと NumPy 操作の自動微分を実行し、ループ、分岐、再帰、クロージャ関数の導出、および 3 次導関数をサポートできます。XLA に依存することで、JAX は GPU と TPU で NumPy プログラムをコンパイルして実行できます。grad を通じて、自動モードのバックプロパゲーションとフォワードプロパゲーションをサポートでき、2 つを任意の順序で組み合わせることができます。

JAX 開発の出発点は何でしたか?これについて言えば、NumPy について触れなければなりません。 NumPy は Python の基本的な数値計算ライブラリであり、広く使用されています。ただし、NumPy は GPU やその他のハードウェア アクセラレータをサポートしておらず、バックプロパゲーションのサポートも組み込まれていません。さらに、Python 自体の速度制限により NumPy の使用が妨げられるため、NumPy を直接使用してディープラーニング モデルを実稼働環境でトレーニングまたは展開する研究者はほとんどいません。

このような状況の中で、PyTorch、TensorFlow など、数多くのディープラーニング フレームワークが登場しました。ただし、numpy には柔軟性、デバッグの容易さ、安定した API などの独自の利点があります。 JAX の主な出発点は、numpy の上記の利点とハードウェア アクセラレーションを組み合わせることです。

現在、JAX をベースにした優れたオープンソース プロジェクトが数多く存在します。たとえば、Google のニューラル ネットワーク ライブラリ チームは、Jax 用のディープラーニング コード ライブラリである Haiku を開発しました。Haiku を通じて、ユーザーは Jax 上でオブジェクト指向開発を行うことができます。もう 1 つの例は、Jax をベースにした強化学習ライブラリである RLax です。ユーザーは RLax を使用して Q 学習モデルを構築およびトレーニングできます。さらに、1 行のコードで計算グラフを定義し、GPU アクセラレーションを実行できる JAX ベースのディープラーニング ライブラリ JAXnet もあります。ここ数年、JAXはディープラーニング研究に旋風を巻き起こし、科学研究の急速な発展を促進してきたと言えます。

JAX のインストール

JAX の使い方は?まず、Python 環境または Google Colab に JAX をインストールする必要があります。pip を使用してインストールします。

  1. $ pip インストール --upgrade jax jaxlib

上記のインストール方法は、CPU 上での実行のみをサポートしていることに注意してください。プログラムを GPU 上で実行する場合は、まず CUDA と cuDNN が必要で、その後次のコマンドを実行します (jaxlib バージョンを CUDA バージョンにマッピングするようにしてください)。

  1. $ pip インストール --upgrade jax jaxlib == 0.1.61 +cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

次に、Numpy とともに JAX をインポートします。

  1. jaxをインポートする
  2. jax.numpyをjnpとしてインポートする
  3. numpyをnpとしてインポートする

JAXの機能

grad() 関数を使用した自動微分: これはバックプロパゲーションの実行を容易にするため、ディープラーニング アプリケーションに非常に役立ちます。以下は、単純な 2 次関数の例で、ポイント 1.0 で導関数を取得します。

  1. jaxインポートgrad から
  2. 定義f(x):
  3. 3 *x** 2 + 2 *x + 5を返す
  4. f_prime(x)を定義します:
  5. 6 *x + 2を返す
  6. 卒業率( 1.0 )
  7. # デバイス配列( 8. , dtype=float32)
  8. f_prime( 1.0 )
  9. # 8.0

jit (ジャストインタイム): XLA のパワーを活用するには、コードを XLA カーネルにコンパイルする必要があります。ここで JIT が役立ちます。 XLA と jit を使用するには、ユーザーは jit() 関数または @jit アノテーションを使用できます。

  1. jaxからjitをインポート
  2. x = np.random.rand( 1000 , 1000 )
  3. y = jnp.array(x)
  4. 定義f(x):
  5. _ が範囲( 10 )内にある場合:
  6. x = 0.5 * x + 0.1 * jnp.sin(x)
  7. xを返す
  8. g = jit(f)
  9. %timeit -n 5 -r 5 f(y).block_until_ready()
  10. # 5ループ、ベスト5 : ループあたり10.8ミリ秒
  11. %timeit -n 5 -r 5 g(y).block_until_ready()
  12. # 5ループ、ベスト5 : ループあたり341 µs

pmap: 現在のすべてのデバイスに計算を自動的に分散し、それらの間のすべての通信を処理します。 JAX は pmap 変換を通じて大規模なデータ並列処理をサポートし、単一のプロセッサでは処理できない大規模なデータを処理します。利用可能なデバイスを確認するには、jax.devices() を実行します。

  1. jaxからpmapをインポート
  2. 定義f(x):
  3. jnp.sin(x) + x** 2を返す
  4. f(np.arange( 4 )) は、
  5. #デバイス配列([ 0 . , 1.841471 , 4.9092975 , 9.14112 ], dtype=float32)
  6. pmap(f)(np.arange( 4 ))
  7. #ShardedDeviceArray([ 0 . , 1.841471 , 4.9092975 , 9.14112 ], dtype=float32)

vmap: 関数変換です。JAX は vmap 変換による自動ベクトル化アルゴリズムを提供します。これにより、このタイプの計算が大幅に簡素化され、研究者はバッチの問題に悩まされることなく新しいアルゴリズムを扱えるようになります。次に例を示します。

  1. jaxからvmapをインポート
  2. 定義f(x):
  3. jnp.square(x)を返す
  4. f(jnp.arange( 10 ))
  5. #デバイス配列([ 0 , 1 , 4 , 9 , 16 , 25 , 36 , 49 , 64 , 81 ], dtype=int32)
  6. vmap(f)(jnp.arange( 10 ))
  7. #デバイス配列([ 0 , 1 , 4 , 9 , 16 , 25 , 36 , 49 , 64 , 81 ], dtype=int32)

TensorFlow 対 PyTorch 対 Jax

ディープラーニングの分野には巨大企業がいくつもあり、彼らが提案するフレームワークは多くの研究者に利用されています。たとえば、Google の TensorFlow、Facebook の PyTorch、Microsoft の CNTK、Amazon AWS の MXnet などです。

各フレームワークには長所と短所があり、自分のニーズに応じて選択する必要があります。

Python の 3 つの主要なディープラーニング フレームワーク (TensorFlow、PyTorch、Jax) を比較します。これらのフレームワークは異なりますが、共通点が 2 つあります。

  • それらはオープンソースです。つまり、ライブラリにバグがある場合、ユーザーは GitHub で問題を報告して修正してもらうことができ、また、独自の機能をライブラリに追加することもできます。
  • Python は、グローバル インタープリタ ロックが原因で内部的に遅く実行されます。したがって、これらのフレームワークは、すべての計算と並列プロセスを処理するために、バックエンドとして C/C++ を使用します。

では、どのような点が異なるのでしょうか?次の表は、TensorFlow、PyTorch、JAX の 3 つのフレームワークの比較を示しています。

テンソルフロー

TensorFlow は Google によって開発され、その最初のバージョンは 2015 年のオープンソースの TensorFlow0.1 にまで遡ります。それ以来、着実に発展し、強力なユーザーベースを持ち、最も人気のあるディープラーニング フレームワークになりました。しかし、使用してみると、API の安定性が不十分であったり、静的計算グラフ プログラミングが複雑であったりするなど、TensorFlow の欠点も明らかになりました。そのため、TensorFlow 2.0 バージョンでは、Google が Keras を組み込み、tf.keras になりました。

TensorFlow の主な機能は次のとおりです。

  • これは非常にユーザーフレンドリーなフレームワークです。高レベル API-Keras が利用できるため、モデル レイヤーの定義、損失関数、モデルの作成が非常に簡単になります。
  • TensorFlow 2.0 には Eager Execution が付属しており、これによりライブラリがよりユーザーフレンドリーになり、以前のバージョンから大幅にアップグレードされています。
  • この高レベル インターフェースには、いくつかの欠点があります。TensorFlow は、エンド ユーザーの利便性のためだけに、多くの基礎となるメカニズムを抽象化しているため、研究者はモデルを処理する自由度が低くなります。
  • Tensorflow は、実際には Tensorflow 視覚化ツールキットである TensorBoard を提供します。これにより、研究者は損失関数、モデルグラフ、モデル分析などを視覚化できます。

パイトーチ

PyTorch (Python-Torch) は、Facebook の機械学習ライブラリです。 TensorFlow か PyTorch か? 1 年前、この質問には異論はなく、ほとんどの研究者が TensorFlow を選択しました。しかし、今では状況は大きく変わり、PyTorch を使用する研究者が増えています。 PyTorch の最も重要な機能には次のようなものがあります。

  • TensorFlow とは異なり、PyTorch は動的型グラフを使用します。つまり、実行グラフはオンザフライで作成されます。いつでもグラフの内部構造を変更したり検査したりすることができます。
  • PyTorch には、ユーザーフレンドリーな高レベル API に加えて、機械学習モデルをより細かく制御できるように慎重に構築された低レベル API も含まれています。トレーニング中に、モデルの前方パスと後方パスの両方の出力を検査および変更できます。これは、グラデーション クリッピングとニューラル スタイル転送に非常に効果的であることが示されています。
  • PyTorch を使用すると、ユーザーはコードを拡張して、新しい損失関数やユーザー定義のレイヤーを簡単に追加できます。 PyTorch の Autograd モジュールは、ディープラーニング アルゴリズムにバックプロパゲーション微分を実装します。Tensor クラスのすべての操作に対して、Autograd は微分を自動的に提供し、手動で微分を計算する複雑なプロセスを簡素化します。
  • PyTorch は、データ並列処理と GPU の使用を幅広くサポートしています。
  • PyTorch は TensorFlow よりも Python 的です。 PyTorch は Python エコシステムにうまく適合し、Python のようなデバッガー ツールを使用して PyTorch コードをデバッグできます。

ジャックス

JAX は、Google の比較的新しい機械学習ライブラリです。これは、ネイティブ Python と NumPy コードを区別できる autograd ライブラリのようなものです。 JAX の主な機能は次のとおりです。

  • 公式ウェブサイトに記載されているように、JAX は Python + NumPy プログラムの構成可能な変換 (ベクトル化、JIT から GPU/TPU など) を実行できます。
  • PyTorch と比較した JAX の最も重要な側面は、勾配の計算方法です。 Torch では、グラフはフォワード パス中に作成され、勾配はバックワード パス中に計算されますが、一方、JAX では計算は関数として表現されます。関数に grad() を使用すると、指定された入力に対する関数の勾配を直接計算する勾配関数が返されます。
  • JAX は自動グレード ツールであり、単独での使用は推奨されません。 JAX ベースの機械学習ライブラリはさまざまありますが、その中でも注目すべきものとしては ObJax、Flax、Elegy などがあります。これらはすべて同じコアを使用し、インターフェースは JAX ライブラリのラッパーにすぎないため、同じ括弧内に配置できます。
  • Flax はもともと PyTorch エコシステムの下で開発され、使用の柔軟性に重点を置いていました。一方、Elegy は Keras からインスピレーションを受けています。 ObJAX は、シンプルさとわかりやすさを重視し、主に研究指向の目的で設計されています。

<<:  人工知能の「想像力」を実現する

>>:  人工知能に関する世界インターネット会議の8つの視点のレビュー

ブログ    
ブログ    
ブログ    

推薦する

住宅建設はよりスマートになる

スマートホーム革命はここしばらく本格的に始まっています。住宅所有者はデータと IoT テクノロジーを...

...

テレンス・タオが、60 年前のもう一つの幾何学の問題に取り組みます。周期的タイル分割問題における新たなブレークスルー

テレンス・タオ氏が研究してきた周期的モザイク化問題に新たな進歩がありました。 9月18日、Teren...

AIによって次に職を奪われるのは字幕作成者でしょうか?

2016年頃から、多くのメディアが「どの仕事がAIに置き換えられるか」を予測し始めたとぼんやりと記...

...

ディープマインドAIは人間に対して84%の勝率を誇り、ウエスタンアーミーチェスで初めて人間の専門家のレベルに到達した。

DeepMind はゲーム AI の分野で新たな成果を上げました。今回はチェスです。 AI ゲーム...

...

ビジネスに適したRPAソフトウェアの選び方

[[407278]] RPA(ロボティック・プロセス・オートメーション)は、ビジネスユーザーを退屈で...

機械学習モデルの品質を保証し、その有効性を評価する方法

[[396139]]近年、機械学習モデルアルゴリズムは、ますます多くの産業実践に実装されるようになり...

5G自動運転車が景勝地でデビュー、商用利用のシナリオも間もなく登場

[[264714​​]]最近、5G携帯電話や5G商用利用に関するニュースが多く出ています。国内外の多...

トピックモデルに適した定量評価指標を見つけるにはどうすればよいでしょうか?これは人気のある方法の要約です

LDA (潜在的ディリクレ分布) や Biterm などの統計トピック モデルを適用することで、大量...

...

柯潔はなぜ「負けてカッとなった」と言ったのか!人間対機械の第一ラウンドを説明する8つの質問

4時間以上の対局の末、柯潔はAlphaGoに0.25ポイント差で負けた。対局後、アルファ碁の指導に参...