モデルの並列処理により、ビジョンタスクのパフォーマンスが向上します。しかし、現在のところ、混合精度などの他の SOTA 手法と同じくらい簡単にモデル並列処理を採用できる標準ライブラリはありません。 最近、メリーランド大学カレッジパーク校のコンピュータサイエンス学部の研究者である Kaiyu Yue 氏が、PyTorch テンソルを並列シャードにスライスするための軽量エンジンである TorchShard ツールをオープンソース化しました。 TorchShard は、モデルに多数の線形レイヤー (BERT、GPT など) や多数のクラス (数百万) がある場合に、GPU メモリを削減し、トレーニングをスケールできます。PyTorch と同じ API 設計になっています。 プロジェクトアドレス: https://github.com/KaiyuYue/torchshard BERT や GPT などの非常に大規模なモデルは、NLP 分野のアプリケーションでトレンドになりつつあります。しかし、このような大規模なモデルをトレーニングするには、メモリ制限の問題に直面します。この問題を解決するために、研究者は Megatron-LM と PyTorch-Lightning モデルの並列処理を使用してトレーニングを拡張しました。このうち、Megatron-LM は大規模なトレーニング言語モデルにのみ焦点を当てていますが、PyTorch-Lightning は DeepSpeed などのシャード化されたオプティマイザー状態と勾配のみに基づいています。 コンピューター ビジョン タスクでは、Transformer ベースのモデル、MLP モデル、または数百万のクラスのトレーニング モデルをトレーニングするときに同じ問題が発生します。 TorchShard の目標は次のとおりです。
TorchShard は、Megatron-LM の中心にあるモデル並列ユニット (MPU) を完全に書き直したものです。最も重要なのは、TorchShard は PyTorch と同じ API 設計になっていることです。つまり、すべてのサブクラスとサブ関数は PyTorch と同じままです。たとえば、元の線形レイヤー torch.nn.Linear を並列にしたい場合は、次のように torch を ts に変換し、サブクラス nn.ParallelLinear を dim パラメータで呼び出します。
これに加えて、TorchShard は DDP と併用すると、シャード チェックポイントの保存と読み込み、シャード パラメータの初期化、複数のマシンと GPU にわたるテンソルの処理など、さまざまな機能をサポートします。詳細は以下の通りです。
TorchShard を使い始めるにはどうすればいいですか?インストール要件: Python バージョン 3.6 以上 (含む) および PyTorch バージョン 1.9.0 以上 (含む)。 pip 経由で TorchShard ライブラリをインストールします。
ここでは、ImageNet での ResNet-50 のトレーニングを例として、わずか数行のコードでプロジェクトで TorchShard を使用する方法を示します。通常、ResNet-50 の設計パラダイムは、下の図 1 に示すように、畳み込みブロックと完全接続層の 2 つの部分で構成されます。一般に、データセットに応じてクラスの数が多いため、最終線形層には畳み込みブロックよりも多くのパラメーターがあります。そこで、最後の線形レイヤーをスライスして、その最大サイズを確認します。 図 1: DDP および DDP + TorchShard フォワード トレーニング フロー。 上の図 1 では、従来の DDP トレーニング パラダイムが左側に示されています。 2 つのクラスがあると仮定すると、DDP は各クラスに重複したモデル パラメータを強制的に設定させます。ただし、TorchShard はレイヤー パラメータをさまざまなレベルに分割するため、全体的な GPU メモリが削減されます。ここで、ImageNet の公式トレーニング スクリプトにいくつかのコードを追加すると、修正されたバージョンが TorchShard プロジェクトの一部になります。 まず、torchshard をインポートします。
次に、DDP プロセス グループを初期化するのと同じ方法で、モデル並列プロセス グループを初期化する必要があります。ターゲット レイヤーからスライスするシャードの数を torchshard に指示する関数パラメータを設定するだけで済みます。
次に、モデルは並列バージョンに変換され、特別な処理なしでモデル全体を変換ヘルパー関数に直接入力できるようになります。
また、入力テンソルに応じて元の PyTorch バージョンと並列バージョンを切り替えることができる損失関数 torchshard.nn.ParallelCrossEntropy も忘れないでください。たとえば、入力テンソルが torchshard 並列レイヤーによって生成される場合、torchshard.nn.ParallelCrossEntropy は損失値を並列で計算します。
モデル並列モード (TorchShard) とデータ並列モード (DDP) が連携して動作する場合、並列レイヤーの入力を処理する必要があります。パラメータとトレーニングデータはレベルごとに異なります。したがって、ResNet の並列線形レイヤーの前に入力テンソルを収集します。
同様に、損失を計算する前にターゲット テンソルを収集します。
最後に、TorchShard 関数を使用すると、チェックポイントの保存と読み込みが非常に簡単になります。 TorchShard は、チェックポイントを保存するための torchshard.collect_state_dict という基本関数と、チェックポイントを読み込むための torchshard.relocate_state_dict という基本関数を提供します。 チェックポイントを保存します:
チェックポイントをロードします:
ImageNet でのシャード トレーニング用のコードの追加が完了したので、クラス数、つまり最後の線形レイヤーの出力特徴次元を増やすことでスケールアップできます。トレーニング スクリプトは torchshard/project/imagenet にあります。次の図は、クラス数が 1,000,000 以下の 8 個の NVIDIA TITAN-XP (12196 MiB) GPU と、クラス数が 2,000,000 の 16 個の GPU での ResNet-50 トレーニングのスケーラビリティを示しています。 図 2: さまざまな並列化戦略で標準の ResNet トレーニング設定 (入力サイズ 224、バッチ サイズ 256) を使用した場合の GPU メモリ コスト。 ZeROでAMPを使用するTorchShard は、Automatic Mixed Precision AMP や ZeRO などの他の技術と、シンプルで自然な PyTorch の方法で組み合わせることができます。
図 3: 標準の ResNet トレーニング設定 (入力サイズ 224、バッチ サイズ 256) を使用したさまざまな並列戦略と AMP での GPU メモリの使用コスト。 ZeRO は DeepSpeed のコアであり、PyTorch >= 1.9.0 で使用されます。関数をテストする場合は、スクリプトの最新バージョンをインストールして実行してください。コードは次のとおりです。
図 4: さまざまな並列化戦略と ZeRO オプティマイザーを使用した標準 ResNet トレーニング セットアップ (入力サイズ 224、バッチ サイズ 256) の GPU メモリ コスト。 さらに、TorchShard は、カスタム並列レイヤーの実装を簡素化するための基本的な Python API と対応するテンプレート ファイルも提供します。 研究者たちは TorchShard の開発を継続します。たとえば、TorchShard の次の機能は、torch.utils.data.DistributedSampler の命名に続く新しいデータ サンプラー torchshard.utils.data.DistributedGroupSampler です。このサンプラーは、ユーザーが M 方向のデータ並列処理と N 方向のモデル並列処理を構築できるように設計されており、DDP の DistributedSampler と同じくらいシンプルです。ユーザーが行う必要があるのは、モデル並列グループ番号を設定することだけです。そうすると、DistributedGroupSampler によって、同じモデル並列グループ内のモジュールに同じトレーニング データが含まれるようになります。 |
<<: ニッチから人気へ: 世界的な AI イノベーションが「ソフト」になった理由
>>: 二重あごをなくすコツがある。浙江大学の2000年代生まれの大学生が、ACM SIGGRAPHで発表した新しい美容アルゴリズムを開発
AI が人間の活動に取って代わるかどうかについての議論が激化するにつれ、データ サイエンティストは ...
[51CTO.comより引用] モバイルインターネット、モノのインターネット、ビッグデータ、人工知能...
世の中に不思議なことは何もありません。 「ボリューム」という言葉が最も重要視されるこの時代に、これま...
キャピタル グループは、1931 年、大恐慌の真っ只中にカリフォルニア州ロサンゼルスで設立され、現在...
FastTextは、Facebookが2016年にオープンソース化した単語ベクトル計算およびテキスト...
【51CTO.com クイック翻訳】1. はじめにテキスト要約は、自然言語処理 (NLP) の分野に...
【51CTO.comオリジナル記事】 1. 背景テキスト マッチングは、自然言語処理における中核的な...
[[406628]]仮想環境 (ALE、MuJoCo、OpenAI Gym) は、エージェントの制御...
人工知能や機械学習 (AI/ML) をトレーニングするために現実世界のデータを収集することは、時間が...
このテストでは合計20台の携帯電話が選ばれ、そのうち1台は海外製、残りの19台は国内トップ5の携帯電...
アルゴリズムとデータ構造は、常にプログラマーの基本的なスキルでした。データ構造の基本インフラストラク...
導入産業革命は一度しか起こらないが、私たちは今、人工知能 (AI) 革命という大きな革命の過程にある...
今の時代、どんな製品の開発にも実は学習プロセスが必要です。人工知能技術が急速に進歩したのは、まさに各...
背景Baiduは昨年11月にカスタマイズされた画像トレーニングサービスを開始しました(https:/...