モデルの並列処理により、ビジョンタスクのパフォーマンスが向上します。しかし、現在のところ、混合精度などの他の SOTA 手法と同じくらい簡単にモデル並列処理を採用できる標準ライブラリはありません。 最近、メリーランド大学カレッジパーク校のコンピュータサイエンス学部の研究者である Kaiyu Yue 氏が、PyTorch テンソルを並列シャードにスライスするための軽量エンジンである TorchShard ツールをオープンソース化しました。 TorchShard は、モデルに多数の線形レイヤー (BERT、GPT など) や多数のクラス (数百万) がある場合に、GPU メモリを削減し、トレーニングをスケールできます。PyTorch と同じ API 設計になっています。 プロジェクトアドレス: https://github.com/KaiyuYue/torchshard BERT や GPT などの非常に大規模なモデルは、NLP 分野のアプリケーションでトレンドになりつつあります。しかし、このような大規模なモデルをトレーニングするには、メモリ制限の問題に直面します。この問題を解決するために、研究者は Megatron-LM と PyTorch-Lightning モデルの並列処理を使用してトレーニングを拡張しました。このうち、Megatron-LM は大規模なトレーニング言語モデルにのみ焦点を当てていますが、PyTorch-Lightning は DeepSpeed などのシャード化されたオプティマイザー状態と勾配のみに基づいています。 コンピューター ビジョン タスクでは、Transformer ベースのモデル、MLP モデル、または数百万のクラスのトレーニング モデルをトレーニングするときに同じ問題が発生します。 TorchShard の目標は次のとおりです。
TorchShard は、Megatron-LM の中心にあるモデル並列ユニット (MPU) を完全に書き直したものです。最も重要なのは、TorchShard は PyTorch と同じ API 設計になっていることです。つまり、すべてのサブクラスとサブ関数は PyTorch と同じままです。たとえば、元の線形レイヤー torch.nn.Linear を並列にしたい場合は、次のように torch を ts に変換し、サブクラス nn.ParallelLinear を dim パラメータで呼び出します。
これに加えて、TorchShard は DDP と併用すると、シャード チェックポイントの保存と読み込み、シャード パラメータの初期化、複数のマシンと GPU にわたるテンソルの処理など、さまざまな機能をサポートします。詳細は以下の通りです。
TorchShard を使い始めるにはどうすればいいですか?インストール要件: Python バージョン 3.6 以上 (含む) および PyTorch バージョン 1.9.0 以上 (含む)。 pip 経由で TorchShard ライブラリをインストールします。
ここでは、ImageNet での ResNet-50 のトレーニングを例として、わずか数行のコードでプロジェクトで TorchShard を使用する方法を示します。通常、ResNet-50 の設計パラダイムは、下の図 1 に示すように、畳み込みブロックと完全接続層の 2 つの部分で構成されます。一般に、データセットに応じてクラスの数が多いため、最終線形層には畳み込みブロックよりも多くのパラメーターがあります。そこで、最後の線形レイヤーをスライスして、その最大サイズを確認します。 図 1: DDP および DDP + TorchShard フォワード トレーニング フロー。 上の図 1 では、従来の DDP トレーニング パラダイムが左側に示されています。 2 つのクラスがあると仮定すると、DDP は各クラスに重複したモデル パラメータを強制的に設定させます。ただし、TorchShard はレイヤー パラメータをさまざまなレベルに分割するため、全体的な GPU メモリが削減されます。ここで、ImageNet の公式トレーニング スクリプトにいくつかのコードを追加すると、修正されたバージョンが TorchShard プロジェクトの一部になります。 まず、torchshard をインポートします。
次に、DDP プロセス グループを初期化するのと同じ方法で、モデル並列プロセス グループを初期化する必要があります。ターゲット レイヤーからスライスするシャードの数を torchshard に指示する関数パラメータを設定するだけで済みます。
次に、モデルは並列バージョンに変換され、特別な処理なしでモデル全体を変換ヘルパー関数に直接入力できるようになります。
また、入力テンソルに応じて元の PyTorch バージョンと並列バージョンを切り替えることができる損失関数 torchshard.nn.ParallelCrossEntropy も忘れないでください。たとえば、入力テンソルが torchshard 並列レイヤーによって生成される場合、torchshard.nn.ParallelCrossEntropy は損失値を並列で計算します。
モデル並列モード (TorchShard) とデータ並列モード (DDP) が連携して動作する場合、並列レイヤーの入力を処理する必要があります。パラメータとトレーニングデータはレベルごとに異なります。したがって、ResNet の並列線形レイヤーの前に入力テンソルを収集します。
同様に、損失を計算する前にターゲット テンソルを収集します。
最後に、TorchShard 関数を使用すると、チェックポイントの保存と読み込みが非常に簡単になります。 TorchShard は、チェックポイントを保存するための torchshard.collect_state_dict という基本関数と、チェックポイントを読み込むための torchshard.relocate_state_dict という基本関数を提供します。 チェックポイントを保存します:
チェックポイントをロードします:
ImageNet でのシャード トレーニング用のコードの追加が完了したので、クラス数、つまり最後の線形レイヤーの出力特徴次元を増やすことでスケールアップできます。トレーニング スクリプトは torchshard/project/imagenet にあります。次の図は、クラス数が 1,000,000 以下の 8 個の NVIDIA TITAN-XP (12196 MiB) GPU と、クラス数が 2,000,000 の 16 個の GPU での ResNet-50 トレーニングのスケーラビリティを示しています。 図 2: さまざまな並列化戦略で標準の ResNet トレーニング設定 (入力サイズ 224、バッチ サイズ 256) を使用した場合の GPU メモリ コスト。 ZeROでAMPを使用するTorchShard は、Automatic Mixed Precision AMP や ZeRO などの他の技術と、シンプルで自然な PyTorch の方法で組み合わせることができます。
図 3: 標準の ResNet トレーニング設定 (入力サイズ 224、バッチ サイズ 256) を使用したさまざまな並列戦略と AMP での GPU メモリの使用コスト。 ZeRO は DeepSpeed のコアであり、PyTorch >= 1.9.0 で使用されます。関数をテストする場合は、スクリプトの最新バージョンをインストールして実行してください。コードは次のとおりです。
図 4: さまざまな並列化戦略と ZeRO オプティマイザーを使用した標準 ResNet トレーニング セットアップ (入力サイズ 224、バッチ サイズ 256) の GPU メモリ コスト。 さらに、TorchShard は、カスタム並列レイヤーの実装を簡素化するための基本的な Python API と対応するテンプレート ファイルも提供します。 研究者たちは TorchShard の開発を継続します。たとえば、TorchShard の次の機能は、torch.utils.data.DistributedSampler の命名に続く新しいデータ サンプラー torchshard.utils.data.DistributedGroupSampler です。このサンプラーは、ユーザーが M 方向のデータ並列処理と N 方向のモデル並列処理を構築できるように設計されており、DDP の DistributedSampler と同じくらいシンプルです。ユーザーが行う必要があるのは、モデル並列グループ番号を設定することだけです。そうすると、DistributedGroupSampler によって、同じモデル並列グループ内のモジュールに同じトレーニング データが含まれるようになります。 |
<<: ニッチから人気へ: 世界的な AI イノベーションが「ソフト」になった理由
>>: 二重あごをなくすコツがある。浙江大学の2000年代生まれの大学生が、ACM SIGGRAPHで発表した新しい美容アルゴリズムを開発
著者 | ブライアン・マクマホン、翻訳者 | bluemin、編集者 | 陳彩仙1950年代にDNA...
この時代の変化のスピードは想像を絶します!次から次へと生み出される想像力豊かな革新は、目を見張るほど...
[51CTO.com クイック翻訳] 今日言及された事故のほとんどはAI自体と直接関係はありませんが...
Logreduce は、大量のログ データから異常を検出することでデバッグ時間を節約できます。継続的...
AI 競争が始まっており、世界中の企業が AI ベースのイノベーションにおける世界的優位性を求めて競...
この記事はAI新メディアQuantum Bit(公開アカウントID:QbitAI)より許可を得て転載...
私は週末に AI で遊んでいて、個人的な知識ベースをローカルに展開しています。基本的には OpenA...
私たちの生活、仕事、交流の仕方に革命をもたらす技術の進歩によって、未来は常に形を変えています。今後 ...
ZKの紹介ZK = 動物園の飼育係ZK は、マイクロサービス ソリューションにおけるサービス登録と検...
6月13日にリリースされたChatGPTの関数呼び出し機能は、自然言語の世界と既存のプログラミング言...
Github Copilot のような人工知能コードアシスタントは、開発者の開発効率と生産性を大幅に...
最近、「被験者 3」について多かれ少なかれ耳にしたことがあるかもしれません。握手、軽く捻挫した足、リ...