「アンティーク」GPUでもDeepSeekと同じGRPOを実行できます。ビデオメモリは1/10しか必要とせず、コンテキストは10倍に増加します

「アンティーク」GPUでもDeepSeekと同じGRPOを実行できます。ビデオメモリは1/10しか必要とせず、コンテキストは10倍に増加します

オープンソースの微調整ツール Unsloth が新しいテクノロジーを携えて戻ってきました。前回のアップデートでは、GRPO に必要なメモリが 7GB に増加しました。今回は、独自の推論モデル Qwen2.5 (1.5B) をトレーニングするために必要な VRAM は 5GB のみで、前回より 2GB 少なくなっています。

この度、推論モデルトレーニング用のビデオメモリが完全に削除されました!

今回は、GRPO トレーニング推論モデルのコンテキストが 10 倍長くなり、必要なビデオ メモリが 90% 削減されました。

最新の Unsloth を使用すると、わずか 5GB のビデオ メモリで独自の推論モデルをトレーニングでき、Qwen2.5-1.5B の精度は低下しません。

5GB のビデオメモリはどういう意味ですか?

GTX 1060 など、2016 年以降にリリースされた GPU には 8GB のビデオ メモリが搭載されています。 2016 年の GTX 1060 は電子骨董品になりました。

現在、より長いコンテキストを実現することが、GRPO が直面している最大の課題の 1 つです。

他の GRPO LoRA/QLoRA 実装、さらには Flash Attention 2 (FA2) に基づく実装と比較しても、Unsloth の新しい効率的な GRPO アルゴリズムは、VRAM の 10% のみを使用しながらコンテキストの長さが 10 倍に増加します。

TRL+FA2 を使用した GRPO 設定では、Llama 3.1 (8B) は 20K コンテキスト長でのトレーニングに 510.8GB の VRAM を必要とします。

Unsloth は VRAM を 90% 削減し、わずか 54.3 GB にまで減らします。

ロングコンテキストVRAMを90%削減

Unsloth はさまざまなトリックを使用して、Flash Attention 2 を使用した標準実装と比較して、GRPO の VRAM 使用量を 90% 以上巧みに削減します。

コンテキストの長さが 20K で、各キューが 8 回生成される場合、Unsloth は Llama-3.1-8B モデルで 54.3GB の VRAM のみを使用しますが、標準実装では 510.8GB が必要です (Unsloth の場合は 90% の削減)。これはすべて、次の 3 つのブレークスルーのおかげです。

  1. 新しく設計されたメモリ効率の高い線形アルゴリズム: GRPO のメモリ使用量を 8 倍以上削減し、68.5 GB のメモリを節約します。 torch.compile では、num_generatinotallow=8 および 20K コンテキスト長で実際に高速になります。
  2. Unsloth が公開したスマート グラデーション チェックポイント アルゴリズムを活用します。中間アクティベーション値をシステム RAM に非同期的にオフロードするため、速度はわずか 1% 低下します。 num_generatinotallow=8 が必要なので、最大 372GB の VRAM が節約されます。中間勾配を蓄積することで、メモリ使用量をさらに削減できます。
  3. 他のパッケージの実装とは異なり、基盤となる推論エンジン (vLLM) と同じ GPU/CUDA メモリ空間を共有します。これにより、さらに 16GB の VRAM が節約されます。

Unsloth と Flash Attention 2 (FA2) に基づく標準実装のメモリ比較

一般的な GRPO 標準実装では、GRPO 損失を計算するために、サイズ (8、20K) の 2 つのロジットを作成する必要があります。これには、2*2 バイト*8(世代数)*20K(コンテキスト長)*128256(語彙サイズ) = 78.3GB の VRAM が必要です。

Unsloth は、長いコンテキストの GRPO のメモリ使用量を 8 分の 1 に削減するため、コンテキスト長が 20K の場合、必要な VRAM は 9.8 GB のみ追加されます。

KV キャッシュも 16 ビット形式で保存する必要があります。 Llama3.18B には 32 層あり、K と V のサイズは両方とも 1024 です。したがって、コンテキスト長が 20K の場合、メモリ使用量 = 2 * 2 バイト * 32 レイヤー * 20K コンテキスト長 * 1024 = バッチあたり 2.5 GB になります。

vLLM のバッチ サイズは 8 に設定できますが、VRAM を節約するために、計算では 1 に維持されます。それ以外の場合は、KV キャッシュを保存するのに 20 GB が必要です。

数学の原理

グループ相対ポリシー最適化 (GRPO) は、DeepSeek が昨年発表した論文から生まれました。

生涯で DeepSeek の論文を 1 つしか読めないとしたら、ネットユーザーは GRPO を最初に提案した DeepSeekMath の論文を選ぶことを推奨しています。

論文リンク: https://arxiv.org/abs/2402.03300

その後、DeepSeek の論文では、GRPO アルゴリズムを使用して DeepSeek-R1 が作成されました。

問題が見つかりました

ここでは、Hugging Face の TRL GRPO 実装を使用します。

TRL 実装の式は次のようになります。

ここでは、逆 KL ダイバージェンスが使用されます (順方向 KL ダイバージェンスの代わりに)。 β は 0.04 に設定されたスケーリング係数であり、A はすべての報酬関数を考慮した後の利点の値です。 q は新しくトレーニングされたモデルであり、P は元の参照モデルです。

次に、実装では逆 KL ダイバージェンスを次のように計算することに注意してください。

しかし、これは本当に正しいのでしょうか?

まず、類似の用語を導き出して整理してみます。

それはどういう意味ですか?私の実装では、q (新しい分布項) との乗算が抜けているのでしょうか?

しかし、GRPO が DeepSeek-Math 論文の 14 ページで初めて紹介されたときと同様に、それは正しいようです。

DeepSeek-Math 論文 14 ページ: 損失関数に KL ダイバージェンスを追加して GRPO アルゴリズムを正規化する

同様に、John Schulman のブログでも、逆 KL 項の不偏推定には実際には追加の q 項は必要ないと述べています。

リンクアドレス: http://joschu.net/blog/kl-approx.html

ブログでご覧ください:

興味深い現象も発見されました。

 torch.exp(qq.detach()) * advantages.unsqueeze(1)

これは 1 になるはずですよね?

Hugging FaceによるTRL GRPO実装

実際には、これが必要であることがわかりました。autograd エンジンが勾配を正しく伝播していない可能性があるようです。

そのため、4つの実験が実施されました。

  1. リファレンス実装を使用した従来のGRPO(赤線)
  2. 切り離しコード(青線)を削除します
  3. 前述の完全な逆KLに続いて、追加の項(黄色の線)を追加します。
  4. 代わりに前方KLダイバージェンスを使用する(緑の線)

全体的に、デタッチを削除するとトレーニングが中断されることは明らかなので、デタッチは保持する必要があります。これにはおそらくさらなる調査が必要になるでしょう。他の実装も同様のようですね?さまざまな効果を観察するには、モデルをより長い期間実行する必要がある場合があります。

すべての実装では、logsumexp トリックも利用されます。

効率的なGRPOアルゴリズム

しかし、中国のエンジニア Horace He による線形クロスエントロピーの実装が unsloth にインスピレーションを与え、GRPO にうまく適用されるとは思っていませんでした。

Meta で PyTorch に取り組んでいる Horace He 氏

実際、unsloth はいくつか驚くべき点を発見しました。

1 GRPO リファレンス実装では、順方向 KL ダイバージェンスではなく逆方向 KL ダイバージェンスを使用します。

2 正しく処理されない場合、自動混合精度スケーリングを使用して float16 混合精度 (および float8) に線形クロスエントロピーを直接実装すると、クラッシュが発生する可能性があります。

3 GRPO 損失の実装において、主に逆 KL ダイバージェンスの定式化において、いくつかの奇妙な点が見つかりました。

線形交差商リンク: https://gist.github.com/Chillee/22cd93e11b887db1f596ab754d60a899

その他の機能

GRPOの完全なログ記録

以前は、unsloth は合計集計報酬関数自体のみを表示していましたが、この新しいバージョンでは、すべての報酬関数の完全なログ詳細が提供されます。

GRPO をパッチするために関数を呼び出す必要はもうありません。つまり、新しいバージョンではこれを自動的に処理し、次のコードを削除できます。

 from unsloth import PatchFastRL PatchFastRL("GRPO", FastLanguageModel)

vLLM推論オプション

FP8 KV キャッシュが vLLM でも利用できるようになりました。これにより、新しい GPU (RTX 3090、A100 以降) で KV キャッシュ スペースの使用量を 2 倍削減できます。

 model, tokenizer = FastLanguageModel.from_pretrained( model_name = "meta-llama/meta-Llama-3.1-8B-Instruct", max_seq_length = max_seq_length, load_in_4bit = True, fast_inference = True, max_lora_rank = lora_rank, gpu_memory_utilization = 0.6, float8_kv_cache = True, )

vLLM で min_p=0.1 またはその他のサンプリング パラメータを使用する場合は、vLLM の SamplingParams パラメータに何かを渡すこともサポートされています。

 max_prompt_length = 256 from trl import GRPOConfig, GRPOTrainer from unsloth import vLLMSamplingParams vllm_sampling_params = vLLMSamplingParams( min_p = 0.1, seed = 3407, ... ) training_args = GRPOConfig( ... vllm_sampling_params = vllm_sampling_params, temperature = 1.5, )


<<:  ソフトウェア業界における破壊的革命: AIはすべてのものを食べるだけでなく、すべてそのものになる

>>:  具現化された知能の新時代! VLAは、UIナビゲーションとロボット操作を備えた最強の基本モデルMagmaを歓迎します

ブログ    
ブログ    
ブログ    
ブログ    
ブログ    

推薦する

...

ChatGPTはカスタムコマンドを起動します。一度言って覚えておけば、話すたびにそれに従います。

「私は小学校の理科の先生です。科学的な概念について説明していただきたいです。例や類推などのテクニッ...

ChatGPT が突然大きなバグを発見しました!フル機能のGPT-4は無料で使用でき、ネットユーザーは大喜びしている

11月15日、OpenAIは突然、ChatGPT Plusの新規ユーザー登録を停止すると発表しました...

プロフェッショナルスキルを向上させる: 10のNLPテクニックを理解して習得する

1. 感情分析感情分析とは、ツイート、製品レビュー、顧客からのフィードバックなどのテキストの背後にあ...

ニューヨーク・タイムズは、自社のニュース記事をAIモデルの訓練に利用することを禁止し、OpenAIを訴えることを検討している。

NPRによると、OpenAIは、自社の人工知能(AI)モデルのトレーニングにニューヨーク・タイムズ...

メタ:メタバース製品は引き続き顔認識技術を使用する

[[433492]] 11月5日、海外メディアの報道によると、フェイスブックは今週、同社のプラットフ...

...

ロボットは人間と機械の協働チームの「リーダー」になれるでしょうか?どのように機能しますか?

ロボット技術の発展により、ロボットは実生活においてますます重要な役割を果たすようになるでしょう。人間...

...

科学技術の時代におけるあらゆる産業の発展を可能にするAIIA2020人工知能開発者会議が開幕

人工知能は科学技術革命を牽引する重要な原動力として、国家戦略計画や産業界の注目の的となり、オープンソ...

ビッグデータと人工知能の関係

[[342758]]人工知能教育は最も美しい新しいインフラです人工知能のアルゴリズムの中にはデータ...

...

AIの「不確実な時代」にどう向き合うか

AIの拡大する影響私たちの日常生活における AI の影響はますます明らかになってきています。 AI ...

...

...