RTX 4090が制限されている時代に、大規模モデルにRLHFを使用するより効率的な方法が登場

RTX 4090が制限されている時代に、大規模モデルにRLHFを使用するより効率的な方法が登場

  • 論文リンク: https://arxiv.org/abs/2310.10505
  • 著者: Li Ziniu、Xu Tian、​​Zhang Yushun、Yu Yang、Sun Ruoyu、Luo Zhiquan
  • 機関: 香港中文大学、深圳、深圳ビッグデータ研究所、南京大学、南京仙学院
  • オープンソースコード: https://github.com/liziniu/ReMax

特に記載がない限り、すべての画像は新聞からのものです。

背景

今年は、ChatGPT が主導する大規模言語モデル (LLM) があらゆる面で注目を集め、学術界やビジネス界で GPU などのコンピューティング リソースの需要が急増しました。

左の写真はDALL・E3、右の写真はDALL・E3

たとえば、Llama2-7B モデルの教師あり微調整 (SFT) には 80 GB を超えるメモリが必要です。しかし、多くの場合、これだけでは十分ではありません。人間と一致するためには、大規模な言語モデルも RLHF (人間からのフィードバックによる強化学習) でトレーニングする必要があります。 RLHF の GPU 消費量は SFT の 2 倍以上になることが多く、トレーニング時間は 6 倍以上になることがあります。

最近、米国政府は、H100やH800などのNvidia GPU製品の中国市場への参入を制限すると発表しました。この規定は間違いなく、中国の大規模言語モデル(LLM)と人工知能の開発に大きな抵抗を加えることになるだろう。 RLHF のトレーニング コスト (GPU 消費量とトレーニング時間) を削減することは、LLM の開発にとって非常に重要です。

モチベーション

RLHF は次の 3 つの段階から構成されます。

1. 教師あり微調整 (SFT)

2. 比較データから報酬モデルを学習します。

3. 強化学習 (RL) アルゴリズムを使用して報酬を最大化します。

画像出典: InstructGPT 論文

RLHF の主な計算オーバーヘッドは、第 3 段階 (報酬の最大化) から発生することがわかります。 DeepSpeed-Chat レポートから、第 3 ステージのトレーニング時間は最初の 2 つのステージの合計時間の 4 倍以上であることがわかります。さらに、私たちの経験によれば、第 3 ステージの GPU 消費量は、最初の 2 ステージの 2 倍以上になります。

DeepSpeed-Chat テクニカルレポートからの画像

現在、RLHF フェーズ 3 の主な計算上のボトルネックは何ですか?

この段階での計算ボトルネックの主な原因は、現在使用されている RL アルゴリズム、つまり PPO アルゴリズムであることがわかりました。 PPO アルゴリズムは、普遍的な RL 問題を解決するための最も人気のあるアルゴリズムの 1 つであり、成功例も数多くあります。ここでは PPO の技術的な詳細は省略し、PPO の主要コンポーネントである価値モデルに焦点を当てます。価値モデルは、特定の戦略の期待される長期リターンを効果的に推定するためにトレーニングする必要があるニューラル ネットワークです。価値モデルは PPO に優れたパフォーマンスをもたらしますが、RLHF タスクに大きな計算オーバーヘッドも生じます。たとえば、人間の好みに合わせるために、PPO の価値モデルは通常 LLM とサイズが似ており、ストレージ要件が 2 倍になります。さらに、価値モデルをトレーニングするには、その勾配、アクティベーション、およびオプティマイザーの状態を保存する必要があり、これにより GPU ストレージ要件がさらに 4 倍近く増加します。要約すると、PPO とその価値モデル (およびそのトレーニング関連部分) は、RLHF の報酬最大化段階における主な計算上の障害となっています。

PPOと比較すると、ReMaxは軽量なアルゴリズムである。

アイデア

PPO よりも RLHF に適したアルゴリズムを見つけることは可能ですか?

私たちが出した答えは「はい」です。これは、PPO と価値モデルが、RLHF のような特定の問題ではなく、一般的な RL 問題向けに設計されているためです (RLHF は RL 問題のサブクラスにすぎません)。興味深いことに、RLHF には PPO では使用されていない 3 つの重要な構造があることがわかりました。

1. 高速シミュレーション: 軌跡 (つまり、LLM での応答全体) は、時間のオーバーヘッドをほとんどかけずに、非常に短時間 (1 秒未満) で実行できます。

2. 決定論的遷移: コンテキストは過去のトークンと現在生成されているトークンに決定論的に依存します。

3. 軌道レベルの報酬: 報酬モデルは、応答が完了した場合にのみ報酬値を提供します。

これら 3 つの観察から、RLHF 問題において価値モデルが「冗長」であることは容易にわかります。これは、価値モデル設計の本来の意図が、ランダム環境でのサンプル効率と、低速シミュレーション環境での計算効率を達成することにあるためです。ただし、RLHF ではこれは必要ありません。

ReMax は RLHF 用に設計されたアルゴリズムですが、PPO は一般的な RL 用に設計されたアルゴリズムです。

方法

リマックス

ReMax アルゴリズムは、古いポリシー勾配アルゴリズム REINFORCE に基づいています。REINFORCE で使用されるポリシー勾配推定器を次の図に示します。

勾配推定器の強化

REINFORCE は、最適化に応答報酬を直接使用し、一般的な RL アルゴリズムのように中間ステップの報酬と価値関数を知る必要がないため、計算レベルで RLHF タスクの 3 つの特性を活用できます。ただし、戦略のランダム性により、REINFORCE 勾配推定器には高分散の問題があり (Richard Sutton の RL 書籍で指摘されています)、モデルトレーニングの有効性に影響します。そのため、以下の 2 つの図に示すように、REINFORCE は RLHF タスクでパフォーマンスが低下します。

REINFORCEは計算コストは​​低いがパフォーマンスは低い


REINFORCEの(ランダムな)勾配はReMaxの勾配よりもはるかに大きい。

この問題を解決するために、ReMax は貪欲応答の報酬をベースライン値として使用して勾配推定器を構築します。具体的な式は次のとおりです。

ReMax勾配推定器

貪欲な応答に対する報酬は、期待される報酬の良い近似値として見ることができることに注意してください。理想的なケース ( ) では、ランダム変数 に対してとなるため、推定値の分散は小さくなることが期待できます

下の図はReMaxのアルゴリズムフローを示しており、赤いボックスはコアアルゴリズムの変更を示しています。

ReMaxアルゴリズムプロセス

理論上の保証

ReMax で使用される勾配推定量は、依然として真のポリシー勾配の不偏推定量であることを示します。

詳細な理論的紹介については論文を参照してください。

アルゴリズムの利点

  • ReMax のコアは 6 行のコードで実装できます。対照的に、PPO では、重要度サンプリング、一般化利点推定 (GAE)、価値モデル学習などの追加モジュールが導入されています。
  • ReMax にはハイパーパラメータがほとんどありません。対照的に、PPO には、重要度サンプリング クリッピング比、GAE 係数、価値モデル学習率、オフポリシー トレーニング エポックなどの追加のハイパーパラメータがあります。これらのハイパーパラメータの調整には多くの時間が必要です。
  • ReMax は理論的にはメモリを約 50% 節約できます。 PPO と比較すると、ReMax は価値モデルに関連するすべてのコンポーネントを正常に削除し、メモリのオーバーヘッドを大幅に削減します。計算により、ReMax は PPO と比較して約 50% のメモリを節約できることがわかりました。

効果

効果

  • ReMaxはPPOと同様に効果的に報酬を最大化できます

OPT-1.3Bでは、ReMaxは効果的に報酬を最大化することができます

OPT-1.3BではReMaxトレーニングは非常に安定しています

  • GPT-4評価(LIMAテスト問題)では、ReMaxによって得られた戦略はSFTやPPOよりも優れている。

GPT4スコアリングでは、ReMaxによって得られたモデルの方が優れていることが示されています。

効率

  • ReMax は GPU メモリを約 50% 節約できます。 ReMax は、価値モデルとそのトレーニング部分 (勾配、オプティマイザー、アクティベーション) を削除するため、GPU メモリ要件が大幅に削減されます。 Llama2-7B を考慮すると、PPO は 8xA100-40GB マシンでは実行できませんが、ReMax は実行できます。

Llama2-7Bでは、ReMaxはGPUメモリを約50%節約できる

  • ReMax はトレーニングを 2 倍高速化できます。各ラウンドで、ReMax は 2 世代と 1 回のバックプロパゲーションを呼び出しますが、PPO は 1 世代と 2 回のバックプロパゲーションを使用します。大規模なモデルの場合、生成時間はバックプロパゲーション時間よりも短くなるため、ReMax は理論的にはトレーニングの約 2 倍の高速化を実現できます。

汎用性

RLHF タスクに加えて、RL アルゴリズムとしての ReMax は、従来の NLP タスクにも適用できます。この論文では、報酬モデルが比較データから学習されない GPT-2 上の映画レビュー継続タスクを検討します。実験的観察によると、ReMax は 2.2 倍のトレーニング高速化と 60% の GPU メモリ節約を実現できます。

従来の NLP タスク (テキスト継続) では、ReMax は PPO と比較して 2.2 倍の高速化を達成しました。

要約する

最後に、私たちの実験から得た PPO に対する ReMax の主な利点を簡単にまとめます。

  • よりシンプルな実装: ReMax のコアは 6 行のコードで実装できます。これは、PPO の多くの複雑なコード構成要素とはまったく対照的です。
  • メモリ オーバーヘッドの削減: 価値モデルとそのトレーニング コンポーネント全体が削除されたため、ReMax は PPO と比較して GPU メモリを約 50% 節約します。
  • ハイパーパラメータの削減: ReMax は、GAE 係数、価値モデルの学習率、重要度サンプリング エポック、ミニバッチ サイズなど、価値モデルのトレーニングに関連するすべてのハイパーパラメータを正常に削除します。これらのハイパーパラメータは、多くの場合、問題に敏感であり、調整が困難です。 ReMax は RLHF 研究者にとってより親しみやすいものであると考えています。
  • より高速なトレーニング速度: GPT2 (137M) の実験では、実際の実行時間に関して、ReMax は PPO と比較して 2.2 倍高速であることが確認されました。高速化は、各反復における ReMax の計算オーバーヘッドが低いことから生まれます。私たちの計算によると、この高速化の利点は、より大きなモデルでも維持されます (PPO が十分に大きなメモリに正常に展開できると仮定)。
  • 優れたパフォーマンス: 上記のように、ReMax は中規模の実験で PPO と同等のパフォーマンスを達成し、場合によっては PPO を上回るパフォーマンスを発揮します (おそらく、ReMax に適切なハイパーパラメータを見つけるのが簡単なためです)。この優れたパフォーマンスは、より大きなモデルにも拡張できると推測されます。

<<: 

>>:  OpenAIがついにオープン:DALL-E 3の論文が発表され、ChatGPTが開始、著者の半数が中国人

ブログ    
ブログ    
ブログ    

推薦する

中国の自動運転が新たなブレークスルーをもたらす:百度世界2020のCCTV生中継で完全無人運転を体験

中国の自動運転は新たな進歩を遂げ、無人運転の時代が到来した。 9月15日、百度はCCTVニュースと提...

ワンクリックで 2D GAN を「3D」化、CUHK が教師なし 3D 再構築の新しい方法を提案

CUHK の MMLab チームによるこの研究は、2 次元 GAN がオブジェクトの 3 次元構造を...

ネットワークにおける機械学習の実際の応用

インターネット接続の需要が急速に高まっているため、企業にはネットワーク インフラストラクチャ、パフォ...

グラフニューラルネットワークは CV の未来でしょうか?中国科学院ソフトウェア研究所は、ViTを上回る新しいCVモデルViGをリリースした。

コンピュータービジョンのネットワーク構造は新たな革命を迎えようとしているのでしょうか?畳み込みニュー...

...

アルゴリズムの力: プログラマーはデスクトップ コンピューターを使用して、スーパーコンピューターの世界記録を破ります

有名なフランス人プログラマー、ファブリス・ベラール氏は最近、普通のデスクトップコンピュータ(2,00...

機械学習における数学的意義

機械学習におけるパフォーマンスを主張するために使用される指標については、ほとんど議論されていません。...

15億パラメータのモデルを2日間でトレーニングし、国内オープンソースプロジェクトがNvidiaのMegatron-LMを上回った

AIの現在の動向において、その徹底的な発展に影響を与える矛盾は何でしょうか?一方では、大型モデルが大...

人工知能は寒い冬を迎え、自動運転車の開発は妨げられている

懐疑論者は、完全な自動運転の実現は業界が考えているよりもずっと先のことかもしれないと述べている。 [...

人工知能と機械学習でよく使われるアルゴリズムの概要と、よく使われる各アルゴリズムの精度の比較

[[319322]]この記事では、一般的に使用されている機械学習アルゴリズムの概要と、一般的に使用さ...

サーバーレス コンピューティングによる機械学習の解決策は何でしょうか?

1. 機械学習とサーバーレス学習1.1. 機械学習 (ML) はアプリケーション シナリオでどのよ...

人工知能に関しては 5 つの主要な考え方があります。あなたはどれを支持しますか?

将来の雇用状況は依然としてテクノロジー大手やCEOによって決定されますが、人工知能の将来は依然として...

人工知能(AI)時代に誰もが身につけるべき9つのソフトスキル

今日の人工知能、ビッグデータ、自動化の時代では、技術的なスキルとデータリテラシーが非常に重要です。し...

Telstra はディープラーニングを使用してネットワークの課題に取り組んでいます。

テルストラは、機器の故障を早期に予測し、音声やテキストによる詐欺に対抗する方法を見つけるために、ネッ...

スマート病院: 将来の医療技術のガイドラインとトレンド

スマート病院とは何ですか?最も伝統的な病院でさえ、人、プロセス、資産の広大なネットワークを持つ複雑な...