スタンフォード大学の博士が独力で注意力を​​9倍に加速! FlashAttention はビデオメモリを消費し、Transformer のコンテキストの長さが劇的に増加します

スタンフォード大学の博士が独力で注意力を​​9倍に加速! FlashAttention はビデオメモリを消費し、Transformer のコンテキストの長さが劇的に増加します

超高速かつメモリを節約するアテンション アルゴリズム FlashAttention の人気を受けて、アップグレードされたバージョン 2 が登場しました。

FlashAttention-2 は、近似値を使用せずに注意を高速化し、メモリ フットプリントを削減するためにゼロから作成されたアルゴリズムです。

第一世代と比較すると、FlashAttention-2 は 2 倍高速です。

実際、PyTorch の標準アテンションよりも最大 9 倍高速に実行されます。

1年前、スタンフォードAIラボの博士Tri Daoは、注目度を2~4倍高速化するFlashAttentionをリリースしました。現在、FlashAttentionは多くの企業や研究室に採用されており、ほとんどのLLMライブラリで広く使用されています。

現在、長いドキュメントクエリやストーリーライティングなどの新しいユースケースのニーズにより、大規模言語モデルのコンテキストは以前よりもはるかに長くなっています。GPT-4 のコンテキスト長は 32k、MosaicML の MPT のコンテキスト長は 65k、Anthropic の Claude のコンテキスト長は 100k です。

ただし、Transformer のコンテキスト長を拡大することは、その中心にある注意層の実行時間とメモリ要件が入力シーケンスの長さの 2 乗であるため、非常に困難です。

Tri Dao は FlashAttention-2 に取り組んできました。これは v1 より 2 倍、標準 Attention より 5 ~ 9 倍高速で、A100 で 225 TFLOP/s のトレーニング速度を達成しました。

写真

論文アドレス: https://tridao.me/publications/flash2/flash2.pdf

プロジェクトアドレス: https://github.com/Dao-AILab/flash-attention

FlashAttention-2: より優れたアルゴリズム、並列処理、作業分割

最大 225 TFLOP/s での GPT モデルのエンドツーエンドのトレーニング

FlashAttention はリリース時点で最適化されたベースラインよりも 2 ~ 4 倍高速ですが、まだかなりの改善の余地があります。

たとえば、FlashAttention はまだ最適化された行列乗算 (GEMM) 演算ほど高速ではなく、理論上の最大 FLOP/秒 (たとえば、A100 GPU では 124 TFLOP/秒) の 25 ~ 40% にしか達しません。

写真

GEMMが畳み込みにどのように使用されるか

研究者たちは過去数か月間、第 1 世代よりも強力なパフォーマンス メトリックを備えた FlashAttention-2 を開発してきました。

研究者らは、第 2 世代は NVIDIA の CUTLASS 3.x とそのコア ライブラリ CuTe を使用して、ゼロから完全に書き直すことに相当すると述べています。速度の面では、FlashAttention-2 は以前のバージョンより 2 倍高速で、A100 GPU で最大 230 TFLOPs/s に達します。

GPT のような言語モデルをエンドツーエンドでトレーニングする場合、研究者は最大 225 TFLOP/秒 (モデルの FLOP 使用率 72%) のトレーニング速度を達成しました。

注目度の計算順序の変更

FlashAttention は、タイリングと再計算を使用してアテンション計算を並べ替え、計算を大幅に高速化し、シーケンス長のメモリ使用量を 2 次から 1 次まで削減するアルゴリズムであることがわかっています。

写真

研究者は、入力ブロックを HBM (GPU メモリ) から SRAM (高速キャッシュ) にロードし、ブロックに対してアテンションを実行し、HBM の出力を更新します。

大きな中間アテンション マトリックスは HBM に書き込まれないため、メモリの読み取り/書き込みが削減され、実行時間が 2 ~ 4 倍高速化されます。

次の図は、FlashAttention のフォワード パス ダイアグラムです。研究者は、タイリングとソフトマックス リスケーリングを通じてモジュール ベースで操作し、HBM からの読み取りや書き込みを回避しながら、近似なしで正しい出力を取得します。

写真

ただし、FlashAttention では、異なるスレッド ブロックと GPU 上のワープ間での作業の分割が最適ではないため、依然として非効率性の問題があり、その結果、占有率が低下したり、共有メモリの読み取りと書き込みが不必要になったりします。

非matmul FLOPの減少

研究者らは、FlashAttention アルゴリズムを調整することで、非 matmul FLOP の数を削減しました。これは非常に重要です。なぜなら、最新の GPU には、matmul を大幅に高速化する特殊なコンピューティング ユニット (Nvidia GPU のテンソル コアなど) が搭載されているからです。

たとえば、A100 GPU FP16/BF16 matmul の最大理論スループットは 312 TFLOPs/s ですが、非 matmul FP32 の理論スループットは 19.5 TFLOPs/s にすぎません。

さらに、非 matmul FLOP は matmul FLOP より 16 倍高価です。

したがって、スループットを高く維持するために、研究者は matmul FLOP にできるだけ多くの時間を費やしたいと考えています。

研究者らはまた、FlashAttention で使用されるオンライン ソフトマックス トリックを書き直し、出力を変えずに、再スケーリング操作、境界チェック、因果マスキング操作の回数を減らしました。

より優れた並列処理

FlashAttention v1 はバッチ サイズとパーツ数に基づいて処理を並列化します。研究者は、1 つのスレッド ブロックを使用して 1 つのアテンション ヘッドを処理し、合計で (batch_size * ヘッド数) 個のスレッド ブロックを使用しました。

写真

順方向処理(左の画像)では、研究者はワーカー(スレッド ブロック)を並列化し、各ワーカーがアテンション マトリックスの行ブロックの処理を担当するようにしました。逆方向プロセス(右)では、各ワーカーが注目行列の列ブロックを処理する。

各スレッド ブロックはストリーミング マルチプロセッサ (SM) 上で実行されます。たとえば、A100 GPU には 108 個の SM があります。このスケジューリングは、この数が大きい場合 (例: ≥ 80) に効果的です。この場合、GPU 上のほぼすべての計算リソースを効果的に使用できるためです。

シーケンスが長い場合(通常はバッチが小さいかヘッドが少ないことを意味します)、GPU 上のマルチプロセッサをより有効に活用するために、研究者はシーケンスの長さの次元に沿ってさらに並列化を行い、メカニズムの大幅な高速化を実現しました。

より良い作業分割

各スレッド ブロック内でも、研究者は異なるワープ (連携して動作する 32 個のスレッドのグループ) 間で作業をどのように分割するかを決定する必要があります。研究者は通常、スレッド ブロックごとに 4 本または 8 本のワープを使用します。パーティション分割スキームは次の図のようになります。

研究者らは FlashAttention-2 のこのパーティショニングを改良し、異なるワープ間の同期と通信の量を減らし、共有メモリの読み取り/書き込みを減らしました。

写真

各ブロックでは、FlashAttention は K と V を 4 つのワープに分割し、Q をすべてのワープからアクセスできるようにします。これは「スライス K」方式と呼ばれます。

ただし、すべてのワープが中間結果を共有メモリに書き込み、同期してから、中間結果を合計する必要があるため、これはあまり効率的ではありません。

これらの共有メモリの読み取り/書き込みにより、FlashAttention での前方伝播が遅くなります。

FlashAttention-2 では、研究者は Q を 4 つのワープに分割し、K と V はすべてのワープからアクセス可能な状態を維持しました。

各ワープは行列乗算を実行して QK^T のスライスを取得した後、それを共有 V スライスと乗算して対応する出力スライスを取得します。

この方法では、ワープが相互に通信する必要がありません。共有メモリの読み取りと書き込みの回数を減らすと、速度が向上します。

新機能: 最大256のヘッドディメンション、マルチクエリアテンション

FlashAttention は最大ヘッド寸法 128 のみをサポートします。ほとんどのモデルに適していますが、一部のモデルは除外されます。

FlashAttention-2 はヘッド次元 256 をサポートするようになりました。つまり、GPT-J、CodeGen、CodeGen2、Stable Diffusion 1.x などのモデルは FlashAttention-2 を使用して加速を実現し、メモリを節約できます。

v2 では、マルチクエリ アテンション (MQA) とグループ化クエリ アテンション (GQA) もサポートされています。

写真

GQA は、クエリ ヘッドのセットごとに単一のキーと値のヘッダーを共有し、マルチ ヘッドとマルチ クエリ アテンションの間を補間します。

これらはすべてアテンションのバリエーションであり、複数のクエリ ヘッドがキーと値の両方で同じヘッドを指すことで、推論中の KV キャッシュのサイズを削減し、推論スループットを大幅に向上させることができます。

注目度ベンチマーク


研究者らは、A100 80GB SXM4 GPU 上で、さまざまな設定 (因果マスクの有無、ヘッド次元 64 または 128) でのさまざまな注意方法の実行時間を測定しました。

写真

研究者らは、FlashAttention-2 が第 1 世代 (xformers ライブラリと Triton の他の実装を含む) よりも約 2 倍高速であることを発見しました。

FlashAttention-2 は、PyTorch の標準アテンション実装よりも最大 9 倍高速です。

写真

A100 GPU での前進 + 後進速度

研究者たちは、同じ実装を H100 GPU で実行するだけで (TMA や第 4 世代 Tensor Core などの新しいハードウェア機能を活用するための特別な命令を使用せずに)、最大 335 TFLOP/秒の速度を達成することができました。

写真

H100 GPU の前進+後進速度

GPT のようなモデルのエンドツーエンドのトレーニングに使用すると、FlashAttention-2 は A100 GPU で最大 225TFLOPs/s の速度を達成できます (モデルの FLOP 使用率は 72%)。

すでに高度に最適化された FlashAttention モデルと比較すると、エンドツーエンドの高速化がさらに 1.3 倍向上します。

写真

今後の仕事

2 倍高速になるということは、研究者が 8k コンテキスト モデルをトレーニングするのと同じコストで、16k コンテキスト長のモデルをトレーニングできることを意味します。これらのモデルは、長い本やレポート、高解像度の画像、音声、ビデオを理解できます。

同時に、FlashAttention-2 は既存モデルのトレーニング、微調整、推論も加速します。

研究者らは近い将来、協力関係を拡大し、FlashAttention をさまざまな種類のデバイス (H100 GPU、AMD GPU など) や新しいデータ タイプ (fp8 など) に幅広く適用できるようにする予定です。

次に、研究者らは、新しいハードウェア機能 (TMA、第 4 世代 Tensor コア、fp8 など) を使用できるように、H100 GPU 向けに FlashAttention-2 をさらに最適化する予定です。

FlashAttention-2 の低レベルの最適化と、ローカル、拡張、ブロックスパース アテンションなどの高レベルのアルゴリズムの変更を組み合わせることで、研究者はより長いコンテキストで AI モデルをトレーニングできるようになります。

研究者たちは、コンパイラー研究者と協力して、これらの最適化手法をプログラミングにさらに活用できることにも興奮しています。

著者について

Tri Dao 氏はスタンフォード大学でコンピューターサイエンスの博士号を取得しました。指導教官は Christopher Ré 氏と Stefano Ermon 氏でした。

ホームページによると、彼は2024年9月からプリンストン大学のコンピューターサイエンスの助教授に就任する予定だ。

写真

Tri Dao の研究対象は機械学習とシステムであり、効率的なトレーニングと長期的な環境に重点を置いています。

- 効率的なトランスフォーマーのトレーニングと推論 - 長距離メモリを備えたシーケンス モデル - コンパクトなディープラーニング モデルのための構造化スパース性。

トリ・ダオ氏が本日、生成型 AI スタートアップ企業 Together AI の主任科学者に正式に就任したことは特筆に値します。

写真

参考文献:

https://princeton-nlp.github.io/flash-atttention-2/

<<:  トランスフォーマー後継モデル! MSRA が新しい大規模モデル インフラストラクチャを提案: 推論速度が 8 倍に向上し、メモリ使用量が 70% 削減

>>:  非常に少ないデータで大規模なモデルを微調整するにはどうすればよいでしょうか?

ブログ    
ブログ    

推薦する

Baidu Brain CVサービスでは、100~1000元のクーポンを提供しています。

覚えていますか? 「小都」はかつて「The Brain」の舞台でエネルギー溢れる出場者たちと競い合い...

...

...

どこにでもAI?小売業における 10 のエキサイティングな AI アプリケーション

[[311856]]小売業における当社の中核的な経験は、近年ほとんど変わっていません。店舗(またはオ...

PyTorchBigGraph を使用して超大規模グラフ モデルをトレーニングする方法は?

Facebook は、数十億のノードと数兆のエッジを持つグラフ モデルを効率的にトレーニングできる...

Google は NeRF を使用して、自動運転用の仮想世界でサンフランシスコを再現します

自動運転システムのトレーニングには、高精度のマップ、膨大な量のデータ、仮想環境が必要です。この方向で...

...

2024年の人工知能の6つの主要な発展トレンド

テクノロジーが支配する急速に進化する世界では、人間の創造性と人工知能 (AI) の魅力的な融合が中心...

GoogleはAIモデルのトレーニングのためだけに「アメリカ版Tieba」のデータを購入するのに6000万ドルを費やした!アルトマンは第3位の株主である

事件は解決しました!先週、Redditは、匿名の企業が同社のユーザーコンテンツにアクセスしてAIモデ...

...

コンピュータビジョンが日常生活をどう改善するか

機械学習の力を活用して日常のさまざまなタスクを処理するテクノロジーである人工知能は、すでに私たちの仕...

AIが起こした恐ろしいことは何ですか?

人工知能(AI)について話すとき、いつも恐怖を感じる人がいます。一体何を恐れているのですか?何か証拠...

金融業界がAI自動化を採用すべき理由

ガートナーによると、「ロボティック・プロセス・オートメーション(RPA)ソフトウェア市場は2020年...

人工知能プロジェクト: 注目すべき 7 つのポイント

最近、業界調査会社ガートナーは、AI プロジェクトの 85% は CIO に引き渡されないという大胆...