スタンフォード大学のコンピュータサイエンス博士による新しい研究: 新しいアテンションは 2 ~ 4 倍高速化、BERT シングルノードトレーニングは最速

スタンフォード大学のコンピュータサイエンス博士による新しい研究: 新しいアテンションは 2 ~ 4 倍高速化、BERT シングルノードトレーニングは最速

高速でメモリ効率に優れたアテンション アルゴリズム、FlashAttention がここにあります。 GPU メモリの読み取り/書き込みを削減することで、FlashAttention は PyTorch 標準アテンションよりも 2 ~ 4 倍高速に実行され、必要なメモリは 5 ~ 20 倍少なくなります。


この研究はスタンフォード大学とニューヨーク州立大学バッファロー校の研究者によって実施された。共同筆頭著者は、スタンフォード大学のコンピュータサイエンス博士課程の学生である Tri Dao 氏と Dan Fu 氏です。

以下に論文の具体的な内容を紹介します。

フラッシュアテンション

Transformer は、自然言語処理や画像分類などのアプリケーションで最も広く使用されているアーキテクチャになりました。研究が進むにつれて、Transformer のサイズはより大きく、より深くなりましたが、Transformer のコアとなる自己注意モジュールの時間計算量とメモリ計算量はシーケンスの長さの 2 乗であるため、Transformer に長いコンテキストを装備することは依然として困難です。

一部の研究者は、注意計算とメモリ要件を削減するための近似注意方法をいくつか提案しています。これらの方法には、スパース近似、低ランク近似、およびそれらの組み合わせが含まれます。これらの方法は、シーケンスの長さに関して計算を線形またはほぼ線形に削減できますが、標準的なアテンションに比べてウォールクロックの高速化が見られないため、広く使用されていません。その主な理由の 1 つは、これらの研究が FLOP (実時間速度とは関係ない可能性がある) の削減に重点を置いており、メモリ アクセス (IO) によるオーバーヘッドを無視する傾向があることです。

この論文では、アテンション アルゴリズムは IO 対応、つまりビデオ メモリ レベル間の読み取りと書き込みを考慮して作成する必要があると主張しています。最新の GPU はメモリ速度を超える計算速度を備えており、トランスフォーマーのほとんどの操作はメモリアクセスによってブロックされます。 IO 対応アルゴリズムは、データベース結合、画像処理、数値線形代数など、データの読み取りと書き込みが実行時間の大部分を占める同様のメモリバインド操作にとって重要です。ただし、PyTorch や Tensorflow などのディープラーニング用の一般的な Python インターフェースでは、メモリアクセスを細かく制御することはできません。

論文アドレス: https://arxiv.org/pdf/2205.14135.pdfGitHub アドレス: https://github.com/HazyResearch/flash-attention

この研究では、より少ないメモリアクセスで正確な注意を計算できる新しい注意アルゴリズム、FlashAttention を提案します。 FlashAttention は、HBM (高帯域幅メモリ) からのアテンション マトリックスの読み取りと書き込みを回避することを目的としています。これには、(i) ソフトマックス削減が入力全体にアクセスせずに計算できること、および (ii) 中間注意行列が後方伝播中に保存できないことが必要です。

この研究では、これらの課題に対処するために、実証済みの 2 つの手法を使用しました。

(i) 入力をチャンクに分割し、入力チャンクに対して複数のパスを作成することでアテンション計算を再編成し、ソフトマックス削減(タイリングとも呼ばれる)を段階的に実行します。 (ii) フォワードパスからのソフトマックス正規化係数を保存し、バックワードパス中にオンチップでアテンションをすばやく再計算します。これは、HBMから中間アテンションマトリックスを読み取る標準的な方法よりも高速です。

この研究では、CUDA に FlashAttention を実装して、メモリ アクセスのきめ細かな制御を実現し、すべてのアテンション操作を単一の GPU カーネルに統合します。再計算により FLOP は増加しますが、HBM アクセス数が大幅に減少したことにより、実行速度が高速化 (GPT-2 で最大 7.6 倍、図 1 右) し、メモリ使用量も減少 (シーケンス長に比例) します。

この研究では、FlashAttentionのIO複雑度を分析し、𝑂(𝑁^2𝑑^2^𝑀−1)HBMアクセスが必要であることを証明します。ここで、𝑑はヘッド次元、𝑀はSRAMのサイズです。一方、標準的なアテンションでは、Ω(𝑁𝑑 + 𝑁^2)HBMアクセスが必要です。 𝑑 と 𝑀 の典型的な値の場合、FlashAttention では標準的なアテンションよりも HBM アクセスが大幅に少なくなります (図 2 に示すように、最大​​ 9 倍少なくなります)。さらに、この研究では、正確なアテンション アルゴリズムではすべての SRAM サイズに対して HBM アクセス数を漸近的に改善できないことを示す下限値を示しています。

この研究では、FlashAttention は、メモリ アクセスのオーバーヘッドの問題を克服することで、近似アテンション アルゴリズムを実装するためのプリミティブとして使用できることも示されました。概念実証として、この研究では、FlashAttention よりも 2 ~ 4 倍高速で、64k のシーケンス長まで拡張可能なスパース アテンション アルゴリズムである Block Sparse FlashAttention を実装しました。この調査では、Block-Sparse FlashAttention の方が FlashAttention よりも IO 複雑度が優れていることが実証されています。

この研究では FlashAttention もオープンソース化されたことは特筆に値します。

実験結果

BERT: FlashAttention は、単一ノードの BERT トレーニング速度で最速を実現します。この研究では、Wikipedia で FlashAttention を使用して BERT-large モデルをトレーニングしました。表 1 は、FlashAttention のトレーニング時間と Nvidia MLPerf 1.1 を比較したもので、FlashAttention のトレーニングの方が 15% 高速であることがわかります。

GPT-2: 表2は、FlashAttentionがHuggingFaceと比較して最大3倍、Megatron-LMと比較して最大1.7倍のエンドツーエンドの高速化を達成できることを示しています。

長距離アリーナ: この研究では、長距離アリーナ (LRA) ベンチマークで実験を行い、精度、スループット、トレーニング時間を測定しました。各タスクのシーケンスの長さは 1024 ~ 4096 の範囲で異なります。さらに、実験は Tay と Xiong らによる実験設定に従います。表 3 は、FlashAttention が標準の注意より 2.4 倍高速であることを示しています。ブロックスパース FlashAttention は、すべての近似アテンション メソッドよりも高速です。

長いコンテキストを持つ言語モデル: FlashAttention のランタイムとメモリ効率により、Megatron-LM よりも高速に実行しながら、GPT-2 のコンテキスト長を 4 倍に増やすことができます。表 4 からわかるように、コンテキスト長が 4K の FlashAttention GPT-2 は、コンテキスト長が 1K の Megatron の GPT-2 よりも 30% 高速であり、パープレキシティは 0.7 改善されています。

表 5 は、MIMIC では、シーケンス長 16K のパフォーマンスが長さ 512 のパフォーマンスよりも 4.3 ポイント高いのに対し、ECtHR では、シーケンス長 8K のパフォーマンスが長さ 512 のパフォーマンスよりも 8.5 ポイント高いことを示しています。

表 6 は、Transformer モデルが Path-X 問題と Path-256 問題を解決できることを示しています。この研究では、Path-64 でトランスフォーマーを事前トレーニングし、空間補間位置埋め込みを通じて Path-X に移行しました。 FlashAttention は Path-X で 61.4% の精度を達成します。さらに、ブロックスパース FlashAttention により、Transformer は 64K シーケンスに拡張でき、Path-256 で 63.1% の精度を達成できます。

図 3 (左) は、ベースラインと比較した FlashAttention および Block-Sparse FlashAttention の順方向 + 逆方向伝播の実行時間をミリ秒単位で報告しています。また、図 3 (右) は、さまざまな正確な、近似した、およびスパースなアテンション ベースラインと比較した FlashAttention および Block-Sparse FlashAttention のメモリ使用量を示しています。

<<:  ジェフ・ディーンらの新しい研究:言語モデルを別の視点から見る:規模が十分でなければ発見されない

>>:  2022 RPA認定ランキング

ブログ    
ブログ    
ブログ    

推薦する

AIが顧客体験を変革する10の方法

今日、消費者はオンライン小売業者に対して非常に高い期待を抱いています。多くの場合、顧客のショッピング...

...

再トレーニングなしでモデルを6倍圧縮:数学者チームが新しい量子化法を提案

RUDN大学の数学者チームは、再トレーニングに余分なリソースを費やすことなく、ニューラルネットワーク...

人工知能が遠隔患者ケアに革命を起こす

パンデミックにより、遠隔患者ケアのための人工知能(AI)の進歩が加速した。医師は、デジタル患者モニタ...

AIとビッグデータ2017「成長痛」

2017 年、人工知能とビッグデータの開発では次の 10 の成長痛が発生しました。 [[21567...

2018 年に人工知能を変える 5 つのビッグデータ トレンド

ビッグデータや人工知能の広範な導入を通じて、これらの新興技術の大きな影響が世界経済に浸透するにつれ、...

AI は従業員トレーニングにどのような革命をもたらすのでしょうか?

[[395608]]スキルギャップを埋めるプレッシャーの下、多くの組織が人工知能テクノロジーを導入...

毎日のアルゴリズム: 二分木のレベルトラバーサル

[[423982]]バイナリ ツリーが与えられた場合、そのノード値のボトムアップ レベルのトラバーサ...

...

サイバーセキュリティのための AI: セキュリティ戦略への AI の組み込み

人工知能は、生産性の向上、売上の増加、ユーザーエクスペリエンスの向上など、さまざまな状況で使用されて...

AlphaGoの仕組み:マルチエージェント強化学習の詳細な説明

このレビュー記事では、著者はマルチインテリジェンス強化学習の理論的基礎を詳細に紹介し、さまざまなマル...

行動バイオメトリクスと機械学習が顧客関係を改善する方法

行動バイオメトリクスは、トラブルのない認証を実現し、世界中の消費者の体験に革命をもたらす画期的なテク...

カリフォルニア工科大学、プロペラアームを使って滑空する二足歩行ロボットを開発

LEONARDO は、カリフォルニア工科大学の航空宇宙ロボット工学および制御研究所の言語の天才たちの...

ブリッジで人間の世界チャンピオン8人が全員AIに負ける

最近、人工知能(AI)が再び人間に勝利しました。今回、人工知能はチェッカーやチェス、囲碁をプレイせず...