2022年にJAXを使うべきでしょうか? GitHubには16,000個のスターがあるが、この若いツールは完璧ではない

2022年にJAXを使うべきでしょうか? GitHubには16,000個のスターがあるが、この若いツールは完璧ではない

2018 年後半の発売以来、JAX の人気は着実に高まっています。 2020年、DeepMindは研究を加速するためにJAXを使用することを発表しました。 Google Brain やその他の機関のプロジェクトでも JAX が使用されるケースが増えています。

現在、JAX GitHub プロジェクトのホームページでは、スターの数が 16.3k に達しています。​

プロジェクト アドレス: https://github.com/google/jaxJAX は非常に有望なプロジェクトであり、そのユーザー数は着実に増加しています。 JAX は、ディープラーニング、ロボット工学/制御システム、ベイズ法、科学的シミュレーションなど、多くの分野で広く使用されています。

これは、JAX も次世代の大規模ディープラーニング フレームワークになることを意味しますか?最近、AssemblyAI ブログに掲載された記事「2022 年に JAX を使用する必要がある (または使用すべきでない) 理由」で、著者の Ryan O'Connor 氏が JAX の概念、JAX を使用する理由、JAX を使用する必要があるかどうかについて詳細な解釈を示しました。

JAX の紹介

JAX はディープラーニングのフレームワークやライブラリではなく、またそのように設計されているわけでもありません。簡単に言えば、JAX は構成可能な関数変換で構成される数値計算ライブラリです。ご覧のとおり、ディープラーニングは JAX の機能のほんの一部にすぎません。

JAX は、科学計算と関数変換の相互統合であり、ディープラーニング モデルのトレーニング以外にも、次のようなさまざまな機能を備えています。

  • ジャストインタイムコンパイル
  • 自動並列化
  • 自動ベクトル化
  • 自動微分化

JAX を使用する理由は何ですか?

一言で言えば、スピードです。これは、あらゆるユースケースに関連する JAX の一般的な機能です。 NumPy と JAX を使用して、行列の最初の 3 つの累乗を (要素ごとに) 合計してみましょう。

まず最初は NumPy 実装です。この計算には約 851 ミリ秒かかることがわかりました。​

 

その後、計算は JAX を使用して実装されました。JAX はわずか 5.54 ミリ秒で計算を実行し、NumPy よりも 150 倍以上高速でした。

​JAX は NumPy よりも桁違いに高速です。 JAX は TPU を使用し、NumPy は CPU を使用することに注意してください。これは、JAX の速度制限が NumPy よりもはるかに高いことを強調するためです。

著者は、JAX を使用する理由として次の 6 つを挙げています。

  • NumPy アクセラレータ。 NumPy は Python による科学計算のための基本的なパッケージの 1 つですが、CPU とのみ互換性があります。 JAX は、GPU および TPU 上で非常に簡単に実行できる NumPy の実装 (ほぼ同じ API を使用) を提供します。多くのユーザーにとって、この機能だけで JAX の使用を正当化するのに十分です。
  • XL. XLA (Accelerated Linear Algebra) は、線形代数専用に設計されたプログラム全体の最適化コンパイラです。 JAX は XLA 上に構築されており、計算速度の上限を大幅に向上させます。
  • ジット。 JAX を使用すると、ユーザーは XLA を使用して関数をジャストインタイム (JIT) バージョンに変換できます。つまり、計算関数に単純な関数デコレータを追加することで、計算速度を数桁向上させることができます。
  • 自動微分化。 JAX は、Autograd (ネイティブ Python コードと NumPy コードを自動的に区別する) と XLA を組み合わせます。XLA の自動区別機能は、科学計算の多くの分野で非常に重要です。 JAX はいくつかの強力な自動微分化ツールを提供します。
  • ディープラーニング。 JAX 自体はディープラーニング フレームワークではありませんが、ディープラーニングの優れた基盤を提供します。 Flax、Haiku、Elegy など、ディープラーニング機能を提供することを目的とした、JAX 上に構築されたライブラリは数多くあります。最近の PyTorch や TensorFlow の記事でも、JAX は注目に値する「フレームワーク」として強調されており、TPU ベースのディープラーニング研究に推奨されています。 JAX のヘッセ行列の効率的な計算は、高次の最適化手法をより実現可能にするため、ディープラーニングにも関連しています。
  • 一般的な微分可能プログラミングパラダイム。 JAX を使用してディープラーニング モデルを構築およびトレーニングできるだけでなく、一般的な微分可能プログラミングのフレームワークも提供します。つまり、JAX はモデルベースの機械学習アプローチを使用して問題を解決し、数十年にわたる研究から蓄積された特定のドメインの事前知識を活用できます。​

JAX 変換

これまで、XLA と、XLA によって JAX がアクセラレータ上で NumPy を実装できる仕組みについて説明してきましたが、これは JAX 定義の半分にすぎないことに注意してください。 JAX は、強力な科学計算のためのツールだけでなく、構成可能な機能変換のためのツールも提供します。

たとえば、勾配関数変換をスカラー値関数 f(x) に適用すると、f(x) のドメイン内の任意の点における関数の勾配を与えるベクトル値関数 f'(x) が得られます。​

関数にgrad()を使用すると、ドメイン内の任意の点の勾配を取得できます。

JAX には、このような関数変換を次の 4 つの一般的な方法で実装するための拡張可能なシステムが含まれています。

  • Grad() は自動微分を実行します。
  • Vmap() 自動ベクトル化;
  • Pmap() は計算を並列化します。
  • Jit() は関数をジャストインタイムコンパイルバージョンに変換します。

grad() を使用した自動微分

機械学習モデルをトレーニングするにはバックプロパゲーションが必要です。 JAX では、Autograd と同様に、ユーザーは grad() 関数を使用して勾配を計算できます。

たとえば、関数 f(x) = abs(x^3) の微分は次のようになります。関数とその導関数を x=2 および x=-3 で評価すると、期待どおりの結果が得られることが分かります。​

では、grad() はどの程度まで微分化できるのでしょうか? JAX は、grad() を繰り返し適用することで微分化を容易にします。次のプログラムでわかるように、出力関数の 3 次導関数は、f'''(x)=6 という一定の期待出力を与えます。

grad() はどのような場面で使用できるのかと疑問に思う人もいるかもしれません。スカラー値関数: grad() はスカラー値関数の勾配を取り、スカラー/ベクトルをスカラー関数にマッピングします。ベクトル値関数もあります。ベクトルをベクトルにマッピングするベクトル値関数の場合、勾配の類似物はヤコビ行列です。 jacfwd() および jacrev() を使用すると、JAX は、ドメイン内のポイントで評価されるとヤコビ行列を生成する関数を返します。

ディープラーニングの観点から見ると、JAX はヘッセ行列の計算を非常にシンプルかつ効率的にします。 XLA のおかげで、JAX は PyTorch よりも高速にヘッセ行列を計算でき、AdaHessian などの高次最適化の実装が大幅に高速化されます。

次のコードは、PyTorch での単純な入力合計のヘッセ行列です。​

ご覧のとおり、上記の計算には約 16.3 ミリ秒かかります。同じ計算を JAX で試してみましょう。

JAX を使用すると、計算には 1.55 ミリ秒しかかからず、PyTorch よりも 10 倍以上高速です。JAXはヘッセ行列を非常に高速に計算できるため、高次の最適化がより実現可能になります。

vmap() を使用した自動ベクトル化

JAX の API には、vmap() 自動ベクトル化という別の変換があります。以下はベクトル化されたベクトル加算のデモンストレーションです:​

pmap() を使用した自動並列化

分散コンピューティングは、特にディープラーニングにおいてますます重要になってきており、下の図に示すように、SOTA モデルが非常に大規模に成長しています。

​JAX は XLA のおかげでアクセラレータ上で簡単に計算できますが、複数のアクセラレータを使用して簡単に計算することもできます。つまり、pmap() という単一のコマンドで SPMD プログラムの分散トレーニングを実行できます。

ベクトル行列乗算を例に挙げてみましょう。以下は非並列ベクトル行列乗算です。


JAX を使用すると、操作を pmap() でラップするだけで、これらの計算を 4 つの TPU に簡単に分散できます。これにより、ユーザーは各 TPU で同時にドット積を実行できるようになり、計算速度が大幅に向上します (大規模な計算の場合)。


jit() を使用して関数を高速化する

JIT コンパイルは、解釈と AoT (事前) コンパイルの中間にあるコード実行方法です。重要なのは、JIT コンパイラは、最初の実行が遅くなることを犠牲にして、実行時にコードを高速な実行可能ファイルにコンパイルすることです。

JIT は、操作を 1 つずつ GPU カーネルにディスパッチする代わりに、XLA を使用して一連の操作を単一のカーネルにコンパイルし、エンドツーエンドでコンパイルされた関数の効率的な XLA 実装を提供します。

たとえば、次の図では、コードは 5000 x 5000 行列を 3 つの方法で計算する関数を定義しています。1 回は NumPy を使用し、1 回は JAX を使用し、もう 1 回は関数の JIT コンパイル バージョンで JAX を使用します。まずCPUで実験を行います。

JAX は、特に JIT を使用する場合、要素ごとの計算が大幅に高速化します。

JAX は NumPy よりも 2.3 倍以上高速であることがわかります。また、関数を JIT すると、JAX は NumPy よりも 30 倍高速になります。これらの結果はすでに印象的ですが、さらに進んで、JAX に TPU 上で計算を実行させてみましょう。

JAX が同じ計算を TPU で実行すると、相対的なパフォーマンスがさらに向上します (NumPy 計算は TPU 計算をサポートしていないため、CPU 上で実行されます)。この場合、JAX は NumPy よりも驚異的に 13 倍高速であることがわかります。関数と計算の両方を TPU で JIT すると、JAX は NumPy よりも 80 倍高速であることがわかります。

もちろん、この大幅な速度向上には代償が伴います。 JAX は JIT で許可される関数に制限を設けていますが、一般的には上記の NumPy 操作を含む関数のみが許可されます。さらに、Python 制御フローによる JIT にはいくつかの制限があるため、関数を記述するときにはそれを念頭に置いてください。

2022年ですが、JAXを使うべきでしょうか?

残念ながら、この質問に対する答えはやはり「それは状況による」です。 JAX に移行するかどうかは、状況と目標によって異なります。 2022 年に JAX を使用する必要があるかどうか (または使用すべきでないかどうか) を詳しく確認できるように、以下のフローチャートには推奨事項がまとめられており、関心領域ごとに異なるチャートが用意されています。​

科学計算

一般的なコンピューティングのための JAX に興味がある場合、最初に尋ねるべき質問は、アクセラレータで NumPy を実行しようとしているだけなのかということです。答えが「はい」であれば、当然 JAX への移行を開始する必要があります。

単に数字を計算するだけでなく、動的な計算モデリングに携わっている場合、JAX を使用するかどうかは具体的な使用例によって異なります。作業のほとんどがカスタム コードを多く使用して Python で行われる場合、ワークフローを強化するために JAX の学習を開始する価値があるかもしれません。

作業のほとんどが Python ではなく、何らかのハイブリッド モデルベース/ニューラル ネットワーク システムを構築したい場合は、JAX を使用する価値があるかもしれません。

ほとんどの作業を Python で行わない場合、または研究用の特殊なソフトウェア (熱力学、半導体など) を使用している場合は、それらのプログラムからデータをエクスポートしてカスタム計算を実行する場合を除き、JAX は適切なツールではない可能性があります。あなたの関心分野が物理学/数学に近く、計算手法(力学システム、微分幾何学、統計物理学)に関係し、仕事のほとんどが Mathematica などで行われている場合、特に既に大規模なカスタム コード ベースがある場合は、現在のツールを使い続ける価値があるかもしれません。​

ディープラーニング

JAX はディープラーニング用に構築された汎用フレームワークではないことを強調しましたが、JAX は高速で自動微分化機能を備えているため、ディープラーニングに JAX を使用するとどのような感じになるのか疑問に思われるかもしれません。

TPU でトレーニングを行う場合、特に現在 PyTorch を使用している場合は、JAX の使用を開始する必要があります。 PyTorch-XLA も存在しますが、TPU トレーニングには JAX を使用する方が間違いなく優れています。 SDE-Nets などの「非標準」アーキテクチャ/モデリングに取り組んでいる場合は、ぜひ JAX も試してみてください。また、高度な最適化技術を活用したい場合は、JAX を試してみるとよいでしょう。

特別なアーキテクチャを構築しておらず、GPU 上で一般的なアーキテクチャをトレーニングしているだけの場合は、今のところ PyTorch または TensorFlow を使い続けるほうがよいでしょう。ただし、このアドバイスは今後 1 ~ 2 年で急速に変化する可能性があります。 PyTorch は依然として研究分野で主流を占めていますが、JAX を使用した論文の数は着実に増加しています。 DeepMind や Google などの大手企業が JAX 向けの高度なディープラーニング API の開発を継続しているため、JAX は数年のうちに爆発的な成長率を達成する可能性があります。

つまり、特に研究者であれば、JAX についてある程度は理解しておく必要があります。​

初心者のためのディープラーニング


しかし、もし私が初心者だったらどうでしょうか?状況は少し変わります。

ディープラーニングについて学び、いくつかのアイデアを実装することに興味がある場合は、JAX または PyTorch を使用する必要があります。ディープラーニングを徹底的に学びたい場合、または Python ソフトウェアの経験がある場合は、PyTorch から始める必要があります。ディープラーニングを基礎から学びたい場合、または数学のバックグラウンドがある場合は、JAX が直感的にわかるかもしれません。この場合、大規模なプロジェクトに着手する前に、JAX の使用方法を理解しておく必要があります。

ディープラーニングに興味があり、関連する職種に転職したい場合は、PyTorch または TensorFlow を使用する必要があります。両方のフレームワークに精通することが最善ですが、さまざまなフレームワークの求人数からもわかるように、TensorFlow は「業界」フレームワークとして広く認識されていることを知っておく必要があります。

数学やソフトウェアの知識がないままディープラーニングを学びたい初心者であれば、JAX は使用しないでください。代わりに、Keras の方が良い選択です。

JAX を使うべきではない 4 つの理由

JAX の利点の多くは上で説明しましたが、JAX にはユーザー アプリケーションのパフォーマンスを大幅に向上させる可能性があります。しかし、著者は JAX を使用すべきでない理由として、次の 4 つも挙げています。

  • ​JAX はまだ公式には実験的なフレームワークとみなされています。 JAX は比較的「新しい」プロジェクトです。現在、JAX は本格的な Google 製品ではなく研究プロジェクトとみなされているため、ユーザーは JAX への移行を検討している場合、この点に留意する必要があります。
  • JAX を慎重に使用してください。デバッグにかかる​​時間コスト、さらには追跡されない副作用のリスクにより、関数型プログラミングをしっかりと理解していないユーザーは JAX を使用できない可能性があります。本格的なプロジェクトで JAX を使い始める前に、JAX の使用における一般的な落とし穴を理解しておく必要があります。
  • JAX は CPU 計算用に最適化されていません。 JAX は「アクセラレータ ファースト」方式で開発されているため、各操作のディスパッチは JAX に対して完全に最適化されていません。場合によっては、特に小さなプログラムの場合、JAX によって生じるオーバーヘッドのため、NumPy は実際に JAX よりも高速になることがあります。
  • JAX は Windows と互換性がありません。 JAX は現在 Windows ではサポートされていません。 Windows を使用していても JAX を試してみたい場合は、Colab を使用するか、仮想マシン (VM) にインストールすることができます。​

<<:  顔認識技術: スマートシティのためのスマートなソリューション

>>:  言葉を発することなくSiriに命令しましょう!清華大学の卒業生が「無言言語認識」ネックレスを開発

ブログ    
ブログ    

推薦する

...

Google のロボット工学プログラムは度重なる失敗からどのような教訓を得たのでしょうか?

Google は再びロボットの製造を開始する予定です。 。 。このニュースを伝えたとき、私は Go...

ディープラーニングの簡単な歴史: TF と PyTorch の独占、次の 10 年間の黄金時代

過去 10 年間で、機械学習 (特にディープラーニング) の分野では多数のアルゴリズムとアプリケーシ...

「AI+教育」の試行錯誤に誰がお金を払うのか?

「AI+教育」の導入は簡単? 2016年はAI(人工知能)元年と言われています。この年、Alpha...

Googleはプライバシーポリシーを更新し、インターネット上の公開データをAIの訓練に利用していることを明確にした。

7月6日、Googleはプライバシーポリシーを更新し、BardやCloud AIなどのさまざまな人...

人工知能開発の現状と将来動向の分析

人工知能、またはよく「AI」(英語の正式名称:Artificial Intelligence)と呼ば...

...

Langchain、ChromaDB、GPT 3.5 に基づく検索強化型生成

翻訳者|朱 仙中レビュー | Chonglou概要:このブログでは、検索拡張生成と呼ばれるプロンプト...

AIは金融犯罪者と戦う技術である

犯罪の手法がより巧妙になるにつれ、マネーロンダリングとの戦いは世界中のすべての金融機関にとって大きな...

清華大学人工知能開発報告:中国は過去10年間のAI特許出願で世界第1位

ザ・ペーパー記者 張偉最新の報告書によると、中国の人工知能特許出願件数は過去10年間で世界第1位であ...

何年も救助ステーションに取り残されていた彼らは、顔認識技術によって愛する人を見つけることができた。

2年前、アンディ・ラウとジン・ボーランが主演した映画「恋の迷宮」は、数え切れないほどのファンを映画...

...

...

MITの中国人博士課程学生がChatGPTをJupyterに移行し、自然言語プログラミングをワンストップソリューションに

自然言語プログラミングは Jupyter で直接実行できます。 MIT の中国人博士課程の学生によっ...

...