FlashAttention v2 は標準の Attention より 5 ~ 9 倍高速です。大規模なモデルで使用されます。

FlashAttention v2 は標準の Attention より 5 ~ 9 倍高速です。大規模なモデルで使用されます。

最近、GPT-4(コンテキスト長32k)、MosaicMLのMPT(コンテキスト長65k)、AnthropicのClaude(コンテキスト長100k)など、いくつかの長コンテキスト言語モデルがリリースされました。長いドキュメントのクエリやストーリーの作成などの新たなユースケースでは、言語モデルのコンテキスト ウィンドウを拡張する必要があることが示されています。

しかし、Transformer のコンテキスト長を拡大することは、そのコア アテンション レイヤーの時間と空間の複雑さが入力シーケンス長の 2 乗に比例するため、課題となります。

1年前、スタンフォード大学とニューヨーク州立大学バッファロー校の研究者が共同で、高速でメモリ効率に優れたアテンションアルゴリズム「FlashAttention」を提案しました。このアルゴリズムは、近似なしで注目度を高速化し、メモリ使用量を削減します。現在、多くの機関や研究室がトレーニングと推論を加速するために FlashAttention を採用しています

FlashAttention の概略図。

FlashAttention は最適化されたベースラインよりもすでに 2 ~ 4 倍高速ですが、まだかなりの改善の余地があります。 FlashAttention はまだ最適化された行列乗算 (GEMM) 演算ほど高速ではなく、理論上の最大 FLOPs/s の 25 ~ 40% にしか達しません。

現在、研究チームはFlashAttention-2の立ち上げを発表しました。 FlashAttention-2 は、Nvidia の CUTLASS 3.x とそのコア ライブラリ CuTe のプリミティブを使用して、完全にゼロから書き直されました。

写真

FlashAttention-2 は Tri Dao によって開発されました。彼はスタンフォード大学の博士課程の学生であり、Together.AIの主任科学者であり、2024年9月からプリンストン大学のコンピューターサイエンスの助教授として勤務する予定です。

FlashAttention-2 は FlashAttention の 2 倍の速度で、A100 GPU で 230 TFLOPs/s に達します。 GPT のような言語モデルをエンドツーエンドでトレーニングする場合、FlashAttention-2 は最大 225 TFLOP/秒 (モデル FLOP 使用率 72%) のトレーニング速度を達成できます。

FlashAttention-2 は、既存モデルのトレーニング、微調整、推論を加速します。つまり、同じコストでコンテキストの長さが 2 倍の言語モデルをトレーニングできるということです。これにより、言語モデルは長い本やレポート、高解像度の画像、音声、ビデオを理解できるようになります。

写真

  • プロジェクトアドレス: https://github.com/Dao-AILab/flash-attention
  • 技術レポート: https://tridao.me/publications/flash2/flash2.pdf

FlashAttentionとは何ですか?

FlashAttention は、アテンション計算を並べ替えるアルゴリズムです。タイリングや再計算などの従来の手法を使用して、計算速度を大幅に向上させ、シーケンス長を 2 次から 1 次まで減らし、メモリ使用量を削減します。タイリングとは、HBM (GPU メモリ) から SRAM (高速キャッシュ) に入力ブロックをロードし、そのブロックに対してアテンション操作を実行して、HBM の出力を更新することを意味します。

さらに、大きな中間アテンション マトリックスを HBM に書き込まないことで、メモリの読み取りと書き込みが削減され、クロック時間が 2 ~ 4 倍高速化されます。

次の図は、FlashAttention のフォワード パス ダイアグラムを示しています。研究者は、タイリングとソフトマックス リスケーリングを通じてブロック単位で操作し、HBM からの読み取り/書き込みを回避しながら、近似演算なしで正しい出力を取得しています。

写真

ただし、FlashAttention では、GPU 上の異なるスレッド ブロックとワープ間での作業の分割が最適ではないため、依然として非効率性が残っています。その結果、占有率が低下したり、共有メモリの読み取りと書き込みが不必要になったりします。

フラッシュアテンション-2

より優れたアルゴリズム、並列化、作業分割


非行列乗算フロップの減少

研究者らは、FlashAttention アルゴリズムを調整して、非行列乗算 (non-matmul) フロップの数を減らしました。これが重要なのは、最新の GPU には行列乗算を大幅に高速化する特殊な計算ユニット (Nvidia GPU のテンソル コアなど) があるためです。

たとえば、A100 GPU の理論上の最大スループットは、FP16/BF16 行列乗算では 312 TFLOPs/s ですが、非行列乗算 FP32 では 19.5 TFLOPs/s しかありません。

別の考え方としては、各非行列乗算 FLOP は行列乗算 FLOP よりも 16 倍高価であるということです。スループットを高く保つために、研究者は行列乗算 FLOP にできるだけ多くの時間を費やしたいと考えています。したがって、 FlashAttention で使用されるオンライン ソフトマックス トリックを書き直して、出力を変更せずに、再スケーリング操作、境界チェック、および因果マスキング操作の数を減らします

より優れた並列化

FlashAttention v1 はバッチ サイズとヘッド数に基づいて並列化されます。研究者は、1 つのスレッド ブロックを使用して 1 つのアテンション ヘッドを処理し、合計で (バッチ サイズ * ヘッド数) のスレッド ブロックを使用しました。各スレッド ブロックはストリーミング マルチプロセッサ (SM) 上で実行されるようにスケジュールされます。たとえば、A100 GPU にはこのような SM が 108 個あります。この数値が非常に大きい場合 (例: >= 80)、このスケジューリングは効果的であり、GPU 上のほぼすべてのコンピューティング リソースを効率的に使用できます。

長いシーケンスの場合(通常はバッチが小さいかヘッドの数が少ないことを意味します)、 GPU 上のマルチプロセッサをより有効に活用するために、研究者は現在、シーケンス長の次元にわたってさらに並列化を行い、メカニズムを大幅に高速化しています

より良い作業分割

各スレッド ブロック内でも、研究者は異なるワープ (連携して動作する 32 個のスレッドのグループ) 間で作業をどのように分割するかを決定する必要があります。通常、各スレッド ブロックは 4 つまたは 8 つのワープを使用し、パーティション分割スキームは次の図に示されています。

研究者らは、FlashAttention-2 のこのパーティショニングを改善し、異なるワープ間の同期と通信を削減し、共有メモリの読み取りと書き込みを削減しました

写真

FlashAttention は、各ブロックについて、K と V を 4 つのワープに分割し、Q をすべてのワープからアクセスできるようにします。これは「スライス K」方式と呼ばれます。ただし、この方式は、すべてのワープが中間結果を共有メモリに書き込み、同期してから中間結果を追加する必要があるため、非効率的です。これらの共有メモリの読み取りと書き込みにより、FlashAttention のフォワード パスが遅くなります。

FlashAttention-2 では、研究者はQ を 4 つのワープに分割し、K と V をすべてのワープからアクセスできるようにしました。各ワープは行列乗算を実行して QK^T のスライスを取得し、次にそれを V の共有スライスと単純に乗算して対応する出力スライスを取得します。ワープ間の通信は必要ありません。共有メモリの読み取りと書き込みを減らすことでも速度が向上します。

新機能: 最大256のヘッドディメンション、マルチクエリアテンション

FlashAttention はヘッド寸法が最大 128 までしかサポートしていないことが分かっています。これはほとんどのモデルに適していますが、一部のモデルは除外されます。

そのため、 FlashAttention-2 は最大 256 のヘッド次元をサポートしており、GPT-J、CodeGen および CodeGen2、StableDiffusion 1.x などのモデルは FlashAttention-2 を使用して高速化を実現し、メモリを節約できます

さらに、FlashAttention-2 は、マルチクエリアテンション (MQA) とグループクエリアテンション (GQA)もサポートします。これらは、複数のクエリ ヘッドが同じキーと値のヘッドに注意を向けることで推論中の KV キャッシュのサイズを削減し、推論スループットを大幅に向上させることができるアテンションのバリエーションです。

注目度ベンチマーク結果

研究者らは、A100 80GB SXM4 GPU 上で、さまざまな設定 (因果マスクなし/あり、ヘッド次元 64 または 128) でのさまざまなアテンション メソッドの実行時間を測定しました。

FlashAttention-2 は FlashAttention (および xformers ライブラリと Triton の他の実装) よりも 2 倍高速であることがわかりました。 PyTorch の標準的なアテンション実装と比較すると、FlashAttention-2 は最大 9 倍高速です。

A100 GPU での前進 + 後進の速度に注目してください。

さらに、研究者らは、同じ実装を H100 GPU で実行するだけで (TMA や第 4 世代 Tensor コアなどの新しいハードウェア機能を活用するための特別な命令を使用せずに)、最大 335 TFLOP/秒を達成しました。

H100 GPU での前進 + 後進の速度に注目してください。

エンドツーエンドの GPT のようなモデル トレーニングに使用すると、FlashAttention-2 は A100 GPU で最大 225 TFLOP/秒 (モデル FLOP 使用率 72%) を達成するのに役立ちます。十分に最適化された FlashAttention モデルと比較すると、エンドツーエンドで 1.3 倍の高速化が達成されます。

ここでのベースラインは FlashAttention のない Megatron-LM ですが、現在は FlashAttention を使用するオプションも備わっています。近い将来、 FlashAttention-2 も Megatron-LM に統合される予定です

研究チームは次のように述べています。「次のステップは、新しいハードウェア機能を使用するために、H100 GPU 用に FlashAttention-2 を最適化することです。」

<<:  ChatGPTコードインタープリターとJupyter Notebookを組み合わせてコーディング機能を強化

>>:  トランスフォーマー後継モデル! MSRA が新しい大規模モデル インフラストラクチャを提案: 推論速度が 8 倍に向上し、メモリ使用量が 70% 削減

ブログ    
ブログ    
ブログ    

推薦する

人工知能によって破壊される可能性のある7つの業界

[[417720]]人工知能は最先端の技術から人々の日常生活に組み込まれる技術へと急速に進化していま...

AIと建物の運用: 人、データ、信頼の基盤の構築

最近では、人工知能とそのサブセットである機械学習が注目のキーワードになっています。ディープフェイク、...

ハンズフリーロボットがゴミ分別の問題解決に役立つ

地球は私たちの共通の家であり、地球環境を保護するために私たちは協力しなければなりません。したがって、...

人工知能が防犯カメラの機能を強化している

今日、セキュリティという言葉を聞くと、それは通常、サイバーセキュリティ、特に人工知能に関するものにな...

...

人工知能に関する究極の議論: 私たちは AI なのか?

有名な科学者ホーキング博士の死からわずか半年後に、世界で最も聡明な科学者たちが歴史的な議論を始めると...

携帯電話の顔認識は本当に安全ですか?

​​​ [51CTO.com クイック翻訳]顔認識は、セキュリティメカニズムとして、ますます多くの携...

OpenAI、ChatGPTのトレーニングで何百万ものユーザー情報を盗んだとして訴訟

有名モデルChatGPTの進路に、ちょっとした紆余曲折が訪れ始めた。カリフォルニアに拠点を置く法律事...

企業におけるAIの応用は成熟段階に入ったのでしょうか?

マッキンゼーは、AI が多くの業務活動を自動化するという見通しに楽観的である一方で、あらゆる規模の自...

2022年の人工知能の7つのトレンド

近い将来に大きな価値を生み出す可能性のある技術の予測となると、人工知能は間違いなくリストのトップに位...

人工知能時代の罠を回避し、実装を実現する方法

つい最近、カリフォルニア大学バークレー校で活躍している、インターネットで有名な無人食品配達車「Kiw...

...

2か月でAIをゼロから学んだ方法とは?

編集者注: 人工知能は「電気」のようなものになりつつあり、その将来の発展に関心を持つ人は誰でもそれに...

AI企業は米国政府に安全性テストを報告することが義務付けられる

バイデン政権は、すべての主要なAIシステムの開発者にセキュリティテストの結果を政府に開示することを義...

人工知能は偏見の岐路に立っている

企業がより多くの機械学習や人工知能モデルを本番環境に導入するにつれて、システム内の偏りに対する認識が...