RNN の効率は Transformer に匹敵し、Google は 2 つの新しいアーキテクチャをリリース: 同じ規模では Mamba よりも強力

RNN の効率は Transformer に匹敵し、Google は 2 つの新しいアーキテクチャをリリース: 同じ規模では Mamba よりも強力

今回、Google DeepMindは基本モデルに新たな動きを見せた。

リカレント ニューラル ネットワーク (RNN) は、ディープラーニングと自然言語処理の研究の初期に中心的な役割を果たし、Google 初のエンドツーエンドの機械翻訳システムなど、多くのアプリケーションで実用的な成功を収めてきたことが知られています。しかし、近年のディープラーニングと NLP では、多層パーセプトロン (MLP) とマルチヘッドアテンション (MHA) を組み合わせた Transformer アーキテクチャが主流となっています。

トランスフォーマーは実際には RNN よりも優れたパフォーマンスを実現しており、最新のハードウェアを活用する上でも非常に効率的です。大規模な Transformer ベースの言語モデルは、インターネットから収集された膨大なデータセットをトレーニングすることで、目覚ましい成功を収めています。

Transformer アーキテクチャは大きな成功を収めましたが、それでも欠点はあります。たとえば、グローバル アテンションの 2 次複雑性により、Transformer が長いシーケンスに効果的に拡張することは困難です。さらに、キー値 (KV) キャッシュはシーケンスの長さに応じて直線的に増加するため、推論中に Transformer の速度が低下します。このとき、シーケンス全体を固定サイズの隠れ状態に圧縮し、反復的に更新できる再帰型言語モデルが代替手段となります。しかし、Transformer を置き換えるには、新しい RNN モデルは、スケーラビリティの点で同等のパフォーマンスを示すだけでなく、同様のハードウェア効率も達成する必要があります。

Google DeepMind の最近の論文では、研究者らが RG-LRU レイヤーという新しいゲート線形再帰レイヤーを提案し、その周囲にマルチクエリ アテンション (MQA) に代わる新しい再帰ブロックを設計しました。

彼らはリカレント ブロックを使用して 2 つの新しいモデルを構築しました。1つは MLP とリカレント ブロックを組み合わせた Hawk モデルでもう 1 つは MLP とリカレント ブロックおよびローカル アテンションを組み合わせた Griffin モデルです

  • 論文タイトル: Griffin: 効率的な言語モデルのためのゲート線形回帰とローカルアテンションの混合
  • 論文リンク: https://arxiv.org/pdf/2402.19427.pdf

研究者らは、Hawk と Griffin は、Transformers で以前に観察されたように、ホールドアウト損失とトレーニング FLOP の間で 70 億パラメータまでのべき乗則スケーリングを示すと述べています。その中でも、Griffin はすべてのモデル サイズで強力な Transformer ベースラインよりもわずかに低い保持損失を実現します。

研究者らは、さまざまなモデル サイズで 3000 億トークンを使用して Hawk と Griffin を過剰トレーニングし、トークンの数が半分しかないにもかかわらず、ダウンストリーム タスクで Hawk-3B が Mamba-3B よりも優れたパフォーマンスを発揮することを示しました。 Griffin-7B と Griffin-14B は、トークン数が 1/7 しか使用されていないにもかかわらず、Llama-2 と同等のパフォーマンスを達成しています。

さらに、Hawk と Griffin は、TPU-v3 上の Transformer と同等のトレーニング効率を実現します。対角 RNN 層はメモリに制約があるため、研究者はこれを実現するために RG-LRU 層のカーネルを使用しました。

同時に、推論中、Hawk と Griffin はどちらも MQA Transformer よりも高いスループットを実現し、長いシーケンスをサンプリングする際のレイテンシが低くなります。 Griffin は、トレーニング中に観察されたものよりも長いシーケンスで評価された場合、Transformer よりも優れたパフォーマンスを発揮し、トレーニング データからコピー タスクと検索タスクの両方を効果的に学習できます。ただし、事前トレーニング済みモデルを微調整なしでコピーおよび正確な検索タスクで評価すると、Hawk と Griffin のパフォーマンスは Transformers よりも悪くなります。

共同筆頭著者であり、DeepMind の研究科学者である Aleksandar Botev 氏は、ゲート線形ループとローカル アテンションを組み合わせたモデルである Griffin は、RNN の高効率の利点と Transformer の表現力をすべて保持し、最大 14B のパラメーター スケールまで拡張できると述べています。

出典: https://twitter.com/botev_mg/status/1763489634082795780

グリフィンモデルアーキテクチャ

すべてのグリフィンモデルには、(i)残差ブロック、(ii)MLPブロック、(iii)時間混合ブロックというコンポーネントが含まれています。 (i) と (ii) はすべてのモデルで同じですが、グローバルマルチクエリアテンション (MQA)、ローカル (スライディングウィンドウ) MQA、および提案された再帰ブロックの 3 つの時間ハイブリッドブロックがあります。研究者らは、リカレント ブロックの一部として、線形リカレント ユニットにヒントを得た新しいタイプのリカレント レイヤーである Realistic Gated Linear Recurrent Unit (RG-LRU) を使用しました。

図 2(a) に示すように、残差ブロックは、プレノルム トランスフォーマーにヒントを得た Griffin モデルのグローバル構造を定義します。研究者は入力シーケンスを埋め込んだ後、それを𝑁(𝑁はモデルの深さを表す)などのブロックに渡し、RMSNormを適用して最終的なアクティベーションを生成します。トークンの確率を計算するために、最終的な線形レイヤーとそれに続くソフトマックスが適用されます。このレイヤーの重みは入力埋め込みレイヤーと共有されます。

Transformerに匹敵するスケーリング効率を持つリカレントモデル

スケーリング研究は、モデルのハイパーパラメータを調整する方法と、スケーリング時のモデルの動作に関する重要な洞察を提供します。

本研究で評価するモデルを定義し、70 億パラメータまでのスケーリング曲線を提供し、下流タスクにおけるモデルのパフォーマンスを評価します。

彼らは3つのモデルファミリーを検討しました: (1) MQA-Transformerベースライン、(2) Hawk:純粋なRNNモデル、(3) Griffin:再帰ブロックとローカルアテンションを組み合わせたハイブリッドモデル。さまざまなモデル サイズの主要なモデル ハイパーパラメータは、付録 C で定義されています。

Hawk アーキテクチャは、Transformer ベースラインと同じ残差モデルと MLP ブロックを使用しますが、研究者は MQA の代わりに RG-LRU レイヤーを備えた再帰ブロックを時間的ハイブリッド ブロックとして使用しました。彼らは、リカレントブロックの幅を約4/3倍(つまり、𝐷_𝑅𝑁𝑁 ≈4𝐷/3)に拡大し、両方が同じモデル次元𝐷を使用する場合のMHAブロックのパラメータ数とほぼ一致するようにしました。

グリフィン。グローバル アテンションに対するリカレント ブロックの主な利点は、固定状態サイズを使用してシーケンスを要約するのに対し、MQA の KV キャッシュ サイズはシーケンスの長さに比例して増加することです。ローカル アテンションにも同じ特性があり、再帰ブロックとローカル アテンションを混合すると、この利点が維持されます。研究者たちは、ローカルアテンションが最近の過去を正確にモデル化でき、リカレントレイヤーが長いシーケンスにわたって情報を転送できるため、この組み合わせが非常に効果的であることを発見しました。

Griffin は、Transformer ベースラインと同じ残差モデルと MLP ブロックを使用します。しかし、MQA Transformer ベースラインや Hawk モデルとは異なり、Griffin はループ ブロックと MQA ブロックを組み合わせて使用​​します。具体的には、研究者らは、2 つの残差ブロックと、反復ブロック、そしてローカル (MQA) 注意ブロックを交互に配置する階層構造を採用しました。特に指定がない限り、ローカル アテンション ウィンドウのサイズは 1024 トークンに固定されます。

主なスケーリング結果を図1(a)に示します。 3 つのモデル ファミリはすべて、1 億から 70 億のパラメータのモデル サイズの範囲でトレーニングされましたが、Griffin には 140 億のパラメータ バージョンがあります。

下流タスクの評価結果を表1に示します。

ホークとグリフィンは両方とも非常に良いパフォーマンスを見せました。上記の表は、MMLU、HellaSwag、PIQA、ARC-E、および ARC-C の特徴正規化精度を報告し、また WinoGrande の絶対精度と部分スコアも報告します。モデルのサイズが大きくなるにつれて、Hawk のパフォーマンスも大幅に向上しました。Hawk-3B は、Mamba-3B の半分の数のトークンでトレーニングされているにもかかわらず、ダウンストリーム タスクでは Mamba-3B よりも優れたパフォーマンスを発揮します。 Griffin-3B は Mamba-3B よりも大幅に優れたパフォーマンスを発揮し、Griffin-7B と Griffin-14B は Llama-2 とほぼ 7 倍少ないトークンでトレーニングされているにもかかわらず、同等のパフォーマンスを発揮します。 Hawk は MQA Transformer ベースラインに匹敵しますが、Griffin はそれを上回ります。

クライアント側で再帰モデルを効率的にトレーニングする

研究者たちは、モデルの開発と拡張において、2つの大きな技術的課題に直面しました。まず、複数のデバイス間でモデルを効率的に分割する方法です。 2 番目は、TPU のトレーニング効率を最大化するために線形ループを効果的に実装する方法です。この論文では、これら 2 つの課題について説明し、Griffin と MQA ベースラインのトレーニング速度を経験的に比較します。

研究者らは、さまざまなモデル サイズとシーケンス長のトレーニング速度を比較し、トレーニング中のモデルの計算上の利点を調査しました。各モデル サイズでは、バッチあたりのトークンの合計数は固定されており、シーケンスの長さが長くなると、シーケンスの数は比例して減少します。

図 3 は、シーケンス長 2048 での Griffin モデルと MQA ベースライン モデルの相対的な実行時間をプロットしています。

推論速度

LLM の推論は 2 つの段階から構成されます。 「事前入力」段階では、プロンプトを受信して​​処理します。このステップは、実際にはモデルのフォワード パスです。プロンプトはシーケンス全体で並列処理できるため、この段階ではほとんどのモデル操作が計算に依存されます。したがって、事前入力フェーズでのトランスフォーマーと再帰モデルの相対速度は、トレーニング中に前述した速度と同様になると予想されます。

事前入力の後はデコード段階となり、研究者はモデルからトークンを自己回帰的に抽出します。以下に示すように、特にシーケンスの長さが長い場合、アテンションで使用されるキー値 (KV) キャッシュが大きくなり、リカレント モデルではデコード フェーズでのレイテンシが低くなり、スループットが高くなります。

推論速度を評価する際に考慮すべき主な指標が 2 つあります。 1 つ目はレイテンシです。これは、特定のバッチ サイズで指定された数のトークンを生成するのにかかる時間を測定します。 2 つ目はスループットです。これは、単一のデバイスで指定された数のトークンをサンプリングするときに、1 秒あたりに生成できるトークンの最大数を測定します。スループットは、サンプリングされたトークンの数とバッチ サイズを掛けてレイテンシで割った値で決まるため、レイテンシを減らすか、メモリ使用量を減らしてデバイスでより大きなバッチ サイズを使用することで、スループットを向上させることができます。高速な応答時間を必要とするリアルタイム アプリケーションの場合、レイテンシを考慮すると便利です。スループットは、特定のモデルから特定の時間にサンプリングできるトークンの最大数を示すため、考慮する価値があります。この特性は、人間のフィードバックによる強化学習 (RLHF) や言語モデル出力のスコアリング (AlphaCode で実行) などの他の言語アプリケーションを考慮すると魅力的であり、指定された時間内に大量のトークンを出力できることは魅力的な機能です。

ここで、研究者らはパラメータ 1B を使用したモデルの推論結果を研究しました。ベースラインに関しては、文献で一般的に使用されている標準の MHA トランスフォーマーよりも推論中に大幅に高速化する MQA トランスフォーマーと比較されます。研究者が比較したモデルは、i) MQA Transformer、ii) Hawk、iii) Griffin です。さまざまなモデルを比較するために、レイテンシとスループットを報告します。

図 4 に示すように、研究者は、バッチ サイズが 16、事前入力が空、および 4096 トークンが事前入力されたモデルのレイテンシを比較しました。

図1(b)は、空のプロンプトの後に512、1024、2048、4196トークンをサンプリングしたときの同じモデルの最大スループット(トークン/秒)を比較しています。

ロングコンテキストモデリング

この論文では、Hawk と Griffin がより長いコンテキストを使用して次のトークンの予測を改善する有効性についても調査し、推論中に外挿する能力を研究しています。また、コピーと検索の能力を必要とするタスクでの Griffin のパフォーマンスについても調査します。これは、モデルがそのようなタスクでトレーニングされた場合と、これらの能力が事前トレーニング済みの言語モデルを使用してテストされた場合の両方で行われます。

図 5 の左側の曲線グラフから、特定の最大長の範囲内で、Hawk と Griffin はどちらも、より長いコンテキストでの次のトークンの予測能力を向上させることができ、一般的にトレーニング中よりも長いシーケンスを推測できることがわかります (少なくとも 4 倍)。特に Griffin は、ローカル アテンション レイヤーで RoPE を使用する場合でも、推論のパフォーマンスが非常に優れています。

図 6 に示すように、選択的コピータスクでは、3 つのモデルすべてが完璧に機能します。このタスクでの学習速度を比較すると、Hawk は Transformer よりも大幅に遅くなります。これは、Mamba が同様のタスクで大幅に遅く学習することを発見した Jelassi ら (2024) の観察結果と同様です。興味深いことに、ローカル アテンション レイヤーを 1 つだけ使用しているにもかかわらず、Griffin の学習速度はほとんど低下せず、Transformer の学習速度に匹敵します。

詳細については、原著論文をお読みください。

<<:  シリコンバレーのアイアンマンがウルトラマンを訴える! GPT-4 がオープンソースになる見込みはありますか?

>>:  ビッグビデオモデルは世界モデルですか? DeepMind/UC Berkeley Chinese: 次のフレームを予測することで世界を変えることができる

ブログ    

推薦する

AIの世界は「データ」から「知識」へと移行している

人工知能(AI)革命は半世紀以上前に始まりました。過去 10 年間で、人工知能は学術科学の領域から私...

CTOは「大きな衝撃を受けた」:GPT-4Vの自動運転テストを5回連続で実施

この記事はAI新メディアQuantum Bit(公開アカウントID:QbitAI)より許可を得て転載...

PyTorch を学ぶには?簡単すぎる

多くの友人から、PyTorch の学習方法を尋ねられました。長期間の練習を経て、初心者が知っておく必...

MIT の FrameDiff ツールがリリースされ、AI を使用してタンパク質構造を設計し、医療開発の促進に役立てられるようになりました。

7月13日、 MITの研究者らは、医薬品開発の加速と遺伝子治療の改善を目的として、生成型人工知能を...

浅いモデルから深いモデルへ: 機械学習最適化アルゴリズムの概要

論文リンク: https://arxiv.org/abs/1706.10207概要: この論文では、...

...

工業情報化部:全国の指定規模以上の産業用ロボット製造企業の営業収入は531.7億元

最近、工業情報化省の公式ウェブサイトは、2020年1月から12月までのロボット産業の稼働状況を発表し...

Google、ユーザーの文章力向上を支援するAI文法チェッカーをリリース

8月8日、IT Homeの友人はGrammarlyツールが提供する「文法チェック」サービスを使用した...

DxRアルゴリズムのアイデアに基づいて設計されたルーティングアイテム配置構造の図

まず、タイトルには、検索構造ではなく、ルーティング項目の配置構造と書かれています。つまり、この構造を...

古代から皇帝の寿命は短かった。皇帝も負荷分散アルゴリズムを理解していたら...

[51CTO.com オリジナル記事] 古代の皇帝はハーレムに3000人の美女を抱えていたことは誰...

機械学習の問題を解決する一般的な方法があります!これを読んでください

編集者注: この記事は、WeChat パブリック アカウント「Big Data Digest」(ID...

人工知能の 10 大トレンドのうち、予想もしなかったものはどれですか?

[[237644]] 人工知能(AI)は、国家や企業が支配権を争う新たな技術の最前線です。マッキン...

機械学習アルゴリズムの長所と短所の比較と選択(要約)

この記事の目的は、現在の機械学習アルゴリズムの実用的かつ簡潔な一覧を提供することです。この記事の内容...

知識が求められるポストディープラーニング時代において、知識グラフをいかに効率的かつ自動的に構築できるのでしょうか?

日常生活では、情報を提示する次の 2 つの方法によく遭遇します。表示される情報量はどちらも同じですが...

...