スタンフォード大学のコンピュータサイエンス博士による新しい研究: 新しいアテンションは 2 ~ 4 倍高速化、BERT シングルノードトレーニングは最速

スタンフォード大学のコンピュータサイエンス博士による新しい研究: 新しいアテンションは 2 ~ 4 倍高速化、BERT シングルノードトレーニングは最速

高速でメモリ効率に優れたアテンション アルゴリズム、FlashAttention がここにあります。 GPU メモリの読み取り/書き込みを削減することで、FlashAttention は PyTorch 標準アテンションよりも 2 ~ 4 倍高速に実行され、必要なメモリは 5 ~ 20 倍少なくなります。


この研究はスタンフォード大学とニューヨーク州立大学バッファロー校の研究者によって実施された。共同筆頭著者は、スタンフォード大学のコンピュータサイエンス博士課程の学生である Tri Dao 氏と Dan Fu 氏です。

以下に論文の具体的な内容を紹介します。

フラッシュアテンション

Transformer は、自然言語処理や画像分類などのアプリケーションで最も広く使用されているアーキテクチャになりました。研究が進むにつれて、Transformer のサイズはより大きく、より深くなりましたが、Transformer のコアとなる自己注意モジュールの時間計算量とメモリ計算量はシーケンスの長さの 2 乗であるため、Transformer に長いコンテキストを装備することは依然として困難です。

一部の研究者は、注意計算とメモリ要件を削減するための近似注意方法をいくつか提案しています。これらの方法には、スパース近似、低ランク近似、およびそれらの組み合わせが含まれます。これらの方法は、シーケンスの長さに関して計算を線形またはほぼ線形に削減できますが、標準的なアテンションに比べてウォールクロックの高速化が見られないため、広く使用されていません。その主な理由の 1 つは、これらの研究が FLOP (実時間速度とは関係ない可能性がある) の削減に重点を置いており、メモリ アクセス (IO) によるオーバーヘッドを無視する傾向があることです。

この論文では、アテンション アルゴリズムは IO 対応、つまりビデオ メモリ レベル間の読み取りと書き込みを考慮して作成する必要があると主張しています。最新の GPU はメモリ速度を超える計算速度を備えており、トランスフォーマーのほとんどの操作はメモリアクセスによってブロックされます。 IO 対応アルゴリズムは、データベース結合、画像処理、数値線形代数など、データの読み取りと書き込みが実行時間の大部分を占める同様のメモリバインド操作にとって重要です。ただし、PyTorch や Tensorflow などのディープラーニング用の一般的な Python インターフェースでは、メモリアクセスを細かく制御することはできません。

論文アドレス: https://arxiv.org/pdf/2205.14135.pdfGitHub アドレス: https://github.com/HazyResearch/flash-attention

この研究では、より少ないメモリアクセスで正確な注意を計算できる新しい注意アルゴリズム、FlashAttention を提案します。 FlashAttention は、HBM (高帯域幅メモリ) からのアテンション マトリックスの読み取りと書き込みを回避することを目的としています。これには、(i) ソフトマックス削減が入力全体にアクセスせずに計算できること、および (ii) 中間注意行列が後方伝播中に保存できないことが必要です。

この研究では、これらの課題に対処するために、実証済みの 2 つの手法を使用しました。

(i) 入力をチャンクに分割し、入力チャンクに対して複数のパスを作成することでアテンション計算を再編成し、ソフトマックス削減(タイリングとも呼ばれる)を段階的に実行します。 (ii) フォワードパスからのソフトマックス正規化係数を保存し、バックワードパス中にオンチップでアテンションをすばやく再計算します。これは、HBMから中間アテンションマトリックスを読み取る標準的な方法よりも高速です。

この研究では、CUDA に FlashAttention を実装して、メモリ アクセスのきめ細かな制御を実現し、すべてのアテンション操作を単一の GPU カーネルに統合します。再計算により FLOP は増加しますが、HBM アクセス数が大幅に減少したことにより、実行速度が高速化 (GPT-2 で最大 7.6 倍、図 1 右) し、メモリ使用量も減少 (シーケンス長に比例) します。

この研究では、FlashAttentionのIO複雑度を分析し、𝑂(𝑁^2𝑑^2^𝑀−1)HBMアクセスが必要であることを証明します。ここで、𝑑はヘッド次元、𝑀はSRAMのサイズです。一方、標準的なアテンションでは、Ω(𝑁𝑑 + 𝑁^2)HBMアクセスが必要です。 𝑑 と 𝑀 の典型的な値の場合、FlashAttention では標準的なアテンションよりも HBM アクセスが大幅に少なくなります (図 2 に示すように、最大​​ 9 倍少なくなります)。さらに、この研究では、正確なアテンション アルゴリズムではすべての SRAM サイズに対して HBM アクセス数を漸近的に改善できないことを示す下限値を示しています。

この研究では、FlashAttention は、メモリ アクセスのオーバーヘッドの問題を克服することで、近似アテンション アルゴリズムを実装するためのプリミティブとして使用できることも示されました。概念実証として、この研究では、FlashAttention よりも 2 ~ 4 倍高速で、64k のシーケンス長まで拡張可能なスパース アテンション アルゴリズムである Block Sparse FlashAttention を実装しました。この調査では、Block-Sparse FlashAttention の方が FlashAttention よりも IO 複雑度が優れていることが実証されています。

この研究では FlashAttention もオープンソース化されたことは特筆に値します。

実験結果

BERT: FlashAttention は、単一ノードの BERT トレーニング速度で最速を実現します。この研究では、Wikipedia で FlashAttention を使用して BERT-large モデルをトレーニングしました。表 1 は、FlashAttention のトレーニング時間と Nvidia MLPerf 1.1 を比較したもので、FlashAttention のトレーニングの方が 15% 高速であることがわかります。

GPT-2: 表2は、FlashAttentionがHuggingFaceと比較して最大3倍、Megatron-LMと比較して最大1.7倍のエンドツーエンドの高速化を達成できることを示しています。

長距離アリーナ: この研究では、長距離アリーナ (LRA) ベンチマークで実験を行い、精度、スループット、トレーニング時間を測定しました。各タスクのシーケンスの長さは 1024 ~ 4096 の範囲で異なります。さらに、実験は Tay と Xiong らによる実験設定に従います。表 3 は、FlashAttention が標準の注意より 2.4 倍高速であることを示しています。ブロックスパース FlashAttention は、すべての近似アテンション メソッドよりも高速です。

長いコンテキストを持つ言語モデル: FlashAttention のランタイムとメモリ効率により、Megatron-LM よりも高速に実行しながら、GPT-2 のコンテキスト長を 4 倍に増やすことができます。表 4 からわかるように、コンテキスト長が 4K の FlashAttention GPT-2 は、コンテキスト長が 1K の Megatron の GPT-2 よりも 30% 高速であり、パープレキシティは 0.7 改善されています。

表 5 は、MIMIC では、シーケンス長 16K のパフォーマンスが長さ 512 のパフォーマンスよりも 4.3 ポイント高いのに対し、ECtHR では、シーケンス長 8K のパフォーマンスが長さ 512 のパフォーマンスよりも 8.5 ポイント高いことを示しています。

表 6 は、Transformer モデルが Path-X 問題と Path-256 問題を解決できることを示しています。この研究では、Path-64 でトランスフォーマーを事前トレーニングし、空間補間位置埋め込みを通じて Path-X に移行しました。 FlashAttention は Path-X で 61.4% の精度を達成します。さらに、ブロックスパース FlashAttention により、Transformer は 64K シーケンスに拡張でき、Path-256 で 63.1% の精度を達成できます。

図 3 (左) は、ベースラインと比較した FlashAttention および Block-Sparse FlashAttention の順方向 + 逆方向伝播の実行時間をミリ秒単位で報告しています。また、図 3 (右) は、さまざまな正確な、近似した、およびスパースなアテンション ベースラインと比較した FlashAttention および Block-Sparse FlashAttention のメモリ使用量を示しています。

<<:  ジェフ・ディーンらの新しい研究:言語モデルを別の視点から見る:規模が十分でなければ発見されない

>>:  2022 RPA認定ランキング

ブログ    
ブログ    

推薦する

AIが宇宙飛行士の健康を宇宙で監視する方法

[[286902]] ▲ 火星探査機ロゼッタが光学スペクトル赤外線リモートイメージングシステム(OS...

人工知能は人々の日常の職業生活をどのように変えているのでしょうか?

[[280560]]世界が急速に発展する中、専門家は生産性と仕事の効率性の向上に努めなければなりま...

左手にビッグデータ、右手に人工知能。これらのプログラマーは、パンデミック中に何をしたのでしょうか?

今年初めの流行は、特にCOVID-19の非常に感染力が強い性質により、適切な免疫ワクチンがない中で原...

面接で使えるEslintのFix機能に隠されたアルゴリズムの質問

[[422353]] eslint が修正をサポートしていることはわかっています。--fix パラメ...

AIは自メディア記事の質を知っている。これがWeChatの自動評価アルゴリズムだ

セルフメディアの時代において、すべてのパブリックアカウントは、自分の記事をより多くの人に見てもらえる...

大規模機械学習のためのプログラミング手法、計算モデル、Xgboost および MXNet の事例

[[191977]]現在、機械学習のトレンドは、従来の方法のシンプルなモデル + 少量データ (手動...

...

FenyintaのCTO、張明氏:観光産業を深く掘り下げ、AI技術を使って異言語コミュニケーションの問題を解決する

[51CTO.comからのオリジナル記事] 1930年代初頭、フランスの科学者GBアルチュニは翻訳に...

予測分析の 4 つの業界における用途

[[436125]]画像ソース: https://pixabay.com/images/id-602...

米国は、中国のAIチップ量子の3つの主要分野への投資を制限する最新の大統領令に署名しました。大手メーカーが50億ドル相当のA800を緊急発注

水曜日、ホワイトハウスは大統領令に署名した。米国は、中国の半導体設計ソフトウェアや製造ハードウェアへ...

AI は旅行体験をどのように向上させることができるのでしょうか?

AI を活用した休暇は旅行の未来であり、かつては考えられなかったパーソナライズされた没入型の体験を...

...

マスクのロボットが進化した!新たなスキルが解き放たれ、エンドツーエンドのニューラルネットワークが実現

マスク氏のロボットの大いなる進化。 1年前に初めて舞台に立ったときは動きが少しぎこちなかったが、今で...

AIがあなたの仕事を奪わないと決めつけないでください。

すでに、いくつかの日常的または退屈な作業がロボットや自動化によって置き換えられていますが、それによっ...

パンデミックの中で、これらの16の業界は技術のアップグレードを緊急に必要としている

パンデミックはビジネスを混乱させ、場合によっては世界を停止させ、ほぼすべての業界が事業運営方法を再考...