PyTorch チームが「すべてを分割」モデルを書き直し、元の実装より 8 倍高速化

PyTorch チームが「すべてを分割」モデルを書き直し、元の実装より 8 倍高速化

今年初めから現在に至るまで、生成AIは急速に発展してきました。しかし、多くの場合、特に PyTorch を使用する場合、生成 AI のトレーニングと推論をどのように高速化するかという難しい問題に直面します。

この記事では、PyTorch チームの研究者が解決策を提示しています。この記事では、純粋なネイティブ PyTorch を使用して生成 AI モデルを高速化する方法に焦点を当てています。さらに、この記事では、新しい PyTorch 機能と、これらの機能を組み合わせる実用的な例を紹介します。

結果はどうなりましたか? PyTorch チームは、Meta の「Segment Everything」(SAM) モデルを書き直し、ネイティブの PyTorch 最適化を使用して、精度を損なうことなくコードを元の実装より 8 倍高速化したと述べています。

ブログアドレス: https://pytorch.org/blog/accelerating-generative-ai/

この記事を読むと、次のことがわかります。

  • Torch.compile: PyTorch モデル コンパイラ。PyTorch 2.0 では、torch.compile() という新しい関数が追加され、1 行のコードで既存のモデルを高速化できます。
  • GPU 量子化: 計算精度を下げることでモデルを高速化します。
  • SDPA (Scaled Dot Product Attention): メモリ効率の高いアテンション実装。
  • 半構造化 (2:4) スパース性: GPU 向けに最適化されたスパース メモリ形式。
  • ネストされたテンソル: ネストされたテンソルは、{テンソル、マスク} をまとめて、異なるサイズの画像など、サイズが均一でないデータを 1 つのテンソルにまとめます。
  • Triton カスタム Ops: Triton Python DSL を使用して GPU 操作を記述し、カスタム演算子の登録を通じて PyTorch のさまざまなコンポーネントに簡単に統合します。

PyTorch ネイティブ機能により、スループットが向上し、メモリ オーバーヘッドが削減されます。

SAM は Meta によって提案されました。この研究の詳細については、「CV はもう存在しない? Meta が「すべてを分割する」AI モデルをリリース、CV が GPT-3 の瞬間を告げるかもしれない」を参照してください。

次に、この記事では、パフォーマンス分析、ボトルネックの特定、およびこれらの新機能を PyTorch に統合して SAM が直面する問題を解決する方法など、SAM の最適化プロセスを紹介します。さらに、この記事では、PyTorch のいくつかの新機能、torch.compile、SDPA、Triton カーネル、ネストされたテンソル、半構造化スパース性についても紹介します。

この記事は、さらに深く掘り下げています。記事の最後では、SAM の Express バージョンを紹介します。ご興味のあるパートナーは、GitHub からダウンロードできます。また、この記事では、Perfetto UI を通じてデータを視覚化し、PyTorch の各機能のアプリケーション価値を示しています。

GitHub アドレス: https://github.com/pytorch-labs/segment-anything-fast

Split Everything Model SAM の書き換え

調査によると、この記事で使用されている SAM ベースライン データ型は float32 dtype であり、バッチ サイズは 1 です。PyTorch Profiler を使用してカーネル トレースを表示した結果は次のとおりです。

この記事では、SAM を 2 つの場所で最適化できることがわかりました。

1 つ目は、aten::index への長い呼び出しです。これは、テンソルのインデックス操作 ([] など) によって行われた基礎となる呼び出しによって発生します。ただし、実際には、aten::index は 2 つのカーネルを起動するプロセスにあり、その間に cudaStreamSynchronize がブロックされるため、GPU が aten::index に費やす時間は比較的短くなります。つまり、CPU は 2 番目のカーネルが起動されるまで GPU の処理が完了するのを待機します。したがって、SAM を最適化するには、アイドル時間の原因となるブロッキング GPU 同期を排除する努力をすべきだと本論文では主張します。

2 つ目は、SAM が行列乗算 (上図の濃い緑色) に多くの GPU 時間を費やしていることです。これは、Transformer では一般的です。 SAM モデルが行列乗算に費やす GPU 時間を削減できれば、SAM を大幅に高速化できます。

次に、この論文では、SAM のスループット (img/s) とメモリ オーバーヘッド (GiB) を使用してベースラインを確立します。次に最適化プロセスが行われます。

Bfloat16 半精度 (GPU 同期とバッチ処理も含む)

上記の問題を解決するために、つまり行列の乗算にかかる時間を短縮するために、この論文では bfloat16 を使用します。 Bfloat16 は、各パラメータとアクティベーションの精度を下げることで計算時間とメモリを大幅に節約できる、よく使用される半精度型です。


パディングタイプをbfloat16に置き換える

さらに、GPU 同期を削除するために、この論文では最適化できる 2 つの場所を見つけました。


具体的には(上の図を参照するとわかりやすいですが、登場する変数名はすべてコード内にあります)、SAM の画像エンコーダーには、座標スケーラーとして機能する変数 q_coords と k_coords があり、CPU 上で割り当てられ、処理されることがわかりました。ただし、これらの変数が rel_pos_resized でインデックス付けに使用されると、これらのインデックス付け操作によってこれらの変数は自動的に GPU に移動され、このコピーによって GPU が同期されます。上記の問題を解決するために、この部分は、上に示すように torch.where を使用して書き直すことで問題を解決できることが研究で明らかになりました。

カーネルトレース

これらの変更を適用した後、特にバッチ サイズが小さい場合 (ここでは 1)、個々のカーネル呼び出し間に大きなギャップがあることに気付きました。この現象をより深く理解するために、この記事ではバッチ サイズ 8 での SAM 推論のパフォーマンス分析から始めます。

各カーネルに費やされた時間を見ると、SAM の GPU 時間の大部分が要素ごとのカーネルとソフトマックス演算に費やされていることがわかります。

これで、行列乗算の相対的なコストがはるかに小さくなることがわかります。

GPU 同期と bfloat16 最適化を組み合わせることで、SAM パフォーマンスが 3 倍向上しました。

Torch.compile (+グラフブレークとCUDAグラフ)

この論文では、SAM の詳細な研究の過程で多くの小さな操作があることがわかりました。コンパイラを使用して操作を融合すると大きなメリットがあると考えられているため、PyTorch は torch.compile に次の最適化を加えました。

  • nn.LayerNorm や nn.GELU などの一連の操作を単一の GPU カーネルに統合します。
  • 行列乗算カーネルの直後の演算を融合して、GPU カーネル呼び出しの数を減らします。

これらの最適化により、GPU グローバル メモリのラウンドトリップ回数が削減され、推論が高速化されました。これで、SAM の画像エンコーダーで torch.compile を試すことができます。パフォーマンスを最大化するために、この記事ではいくつかの高度なコンパイル手法を使用します。

カーネルトレース

結果は、torch.compile が適切に機能していることを示しています。

ソフトマックスが時間の大部分を占め、その後にさまざまな GEMM バリアントが続くことがわかります。以下は、バッチ サイズが 8 以上の変更を測定します。

SDPA: スケールド・ドット・プロダクト・アテンション

次に、本論文では注意機構に焦点を当ててSDPA(scaled_dot_product_attention)に関する実験を行った。一般に、ネイティブの注意メカニズムは、時間とメモリの両方において、シーケンスの長さに比例して増加します。 PyTorch の SDPA 操作は、Flash Attention、FlashAttentionV2、xFormer のメモリ効率の高いアテンション原則に基づいて構築されており、GPU アテンションを大幅に高速化できます。 torch.compile と組み合わせると、この操作により、MultiheadAttention のバリアント間で共通のパターンを表現し、融合できるようになります。小さな変更を加えると、モデルは scaled_dot_product_attention を使用できるようになります。

カーネルトレース

ここで、メモリ効率の高いアテンション カーネルが GPU 上で多くの計算時間を費やしていることがわかります。

PyTorch のネイティブ scaled_dot_product_attention を使用すると、バッチ サイズを大幅に増やすことができます。次の図は、バッチ サイズ 32 以上の場合の変更を示しています。

その後、この研究では、Triton、NestedTensor、バッチPredict_torch、int8量子化、半構造化(2:4)スパース性などの操作を試しました。

たとえば、この論文ではカスタムの位置 Triton カーネルを使用し、バッチ サイズ 32 で測定結果を観察します。

Nested Tensor を使用したバッチ サイズ 32 以上のバリエーション。

量子化を追加した後、バッチ サイズが 32 以上になると測定値が変化します。

この記事は半構造化スパース性で終わります。研究では、行列の乗算は依然として対処が必要なボトルネックであると述べている。解決策は、スパース化を使用して行列乗算を近似することです。行列をスパースにする(つまり、値をゼロにする)ことで、重みテンソルと活性化テンソルを格納するために使用するビット数を減らすことができます。この研究では、テンソルの重みをゼロに設定するプロセスを「プルーニング」と呼んでいます。小さな重みを削減すると、精度を大幅に損なうことなくモデルのサイズを縮小できる可能性があります。

剪定には、完全に非構造化のものから高度に構造化されたものまで、さまざまなアプローチがあります。非構造化プルーニングは理論的には精度に最小限の影響しか与えませんが、GPU は大規模な密行列の乗算には非常に効率的ですが、スパースの場合はパフォーマンスが大幅に低下する可能性があります。 PyTorch で最近サポートされているプルーニング手法は、半構造化 (または 2:4) スパース性と呼ばれるバランスを見つけようとします。このスパース ストレージは、元のテンソルを 50% 削減しながら、密なテンソル出力を生成します。詳細については下の図を参照してください。

このスパース ストレージ形式と関連する高速カーネルを使用するには、次に重みを削減する必要があります。この論文では、2:4 のスパース性で剪定を行うために、最小の 2 つの重みを選択します。デフォルトの PyTorch (「ストライド」) レイアウトからこの新しい半構造化スパース レイアウトに重みを変更するのは簡単です。 apply_sparse(model) を実装するには、32 行の Python コードだけが必要です。

スパース性が 2:4 の場合、vit_b とバッチ サイズ 32 の SAM パフォーマンスがピークに達します。

最後に、この記事を一言でまとめると、この記事では、これまでで最速の PyTorch での Segment Anything 実装を紹介します。この記事では、公式にリリースされた一連の新機能を利用して、精度を損なうことなく、元の SAM を純粋な PyTorch で書き直しました。

興味のある読者は、詳細については元のブログをご覧ください。

<<:  空軍の最高データ・AI責任者がAIを通じて戦略的優位性を獲得する方法について語る

>>: 

ブログ    

推薦する

...

...

テスラのヒューマノイドロボットが再び進化:視覚のみに基づいて物体を自律的に分類し、ヨガができる

数ヶ月沈黙していたテスラのヒューマノイドロボット、オプティマスプライムがついに新たな展開を見せた。私...

機械学習の卒業生は就職に不安を感じ始めています!卒業生と企業のどちらがより厳しいでしょうか?

機械学習を専攻する学生も就職について不安を感じ始めているのでしょうか?昨日、あるネットユーザーがRe...

COVID-19は非接触アクセス制御の新時代を加速させる

現在、新型コロナウイルス感染症のパンデミックが世界的に拡大し、私たちの知る世界は大きく変化しています...

...

将来を見据えたデータセキュリティのためのAIソリューション

今日、ビジネスリーダーは急速に進化するデジタル世界における多数のデータセキュリティの脅威に対処してい...

医療における人工知能:医師よりも正確

[[339138]]新しい医療用人工知能システムは、医師と同じように患者を診察することができます。画...

マイクロソフトとフェイスブックが共同で人工知能ソフトウェアを開発し、グーグルの主導的地位に挑戦

マイクロソフトはすでにオープンソースの人工知能ソフトウェアを持っています。しかしここ数カ月、マイクロ...

...

...

...

...

2020年におすすめの優れた人工知能システム

優れた AI システムは、企業に大きな競争上の優位性をもたらすことができます。理論的には、AI と機...