大きなモデルもスライスできます。Microsoft SliceGPTはLLAMA-2の計算効率を大幅に向上させます。

大きなモデルもスライスできます。Microsoft SliceGPTはLLAMA-2の計算効率を大幅に向上させます。

大規模言語モデル (LLM) には通常、数十億のパラメータがあり、数兆のトークンのデータを使用してトレーニングされます。このようなモデルのトレーニングと展開のコストは非常に高くなります。したがって、計算要件を削減するために、さまざまなモデル圧縮技術がよく使用されます。

一般的に、これらのモデル圧縮手法は、蒸留、テンソル分解 (低ランク因数分解を含む)、プルーニング、量子化の 4 つのカテゴリに分類できます。その中で、プルーニング手法は以前から存在していましたが、多くの手法では、パフォーマンスを維持するためにプルーニング後にリカバリの微調整 (RFT) が必要となり、プロセス全体のコストがかかり、拡張が困難になります。

この問題を解決するために、ETH チューリッヒとマイクロソフトの研究者は SliceGPT と呼ばれる方法を提案しました。 SliceGPT の中心的なアイデアは、重み行列の行と列を削除して、モデルのパフォーマンスを維持しながらネットワークの埋め込み次元を減らすことです。

研究者らは、SliceGPT を使用すると、単一の GPU を使用してわずか数時間で大規模なモデルを圧縮でき、RFT がなくても生成および下流のタスクで競争力のあるパフォーマンスを維持できると述べています。現在、この論文は ICLR 2024 に採択されています。


  • 論文タイトル: SLICEGPT: 行と列を削除して大規模言語モデルを圧縮する
  • 論文リンク: https://arxiv.org/pdf/2401.15024.pdf

プルーニング方法は、LLM 内の重み行列の特定の要素をゼロに設定し、(オプションで)行列の周囲の要素を更新して補正することによって機能します。結果はスパース パターンであり、ニューラル ネットワークの順方向パスに必要な行列乗算中に一部の浮動小数点演算をスキップできることを意味します。

計算速度の相対的な向上は、スパース性レベルとスパース性パターンに依存します。より構造化されたスパース性パターンは、より大きな計算ゲインをもたらします。他のプルーニング方法とは異なり、SliceGPT は重みマトリックスの行または列全体をプルーニング (切り取り) します。チョッピングの前に、ネットワーク上で変換を実行し、予測は変更せずにチョッピング プロセスでわずかな変更を許可します。

その結果、重み行列が小さくなり、ニューラル ネットワーク ブロック間で渡される信号が小さくなり、ニューラル ネットワークの埋め込み次元が減少します。

下の図 1 は、SliceGPT 法と既存のスパース法を比較したものです。

著者らは、広範囲にわたる実験を通じて、SliceGPT は LLAMA-2 70B、OPT 66B、および Phi-2 モデルのモデル パラメーター (埋め込みを含む) を最大 25% 削除しながら、それぞれ高密度モデルのゼロ ショット タスク パフォーマンスの 99%、99%、および 90% を維持できることを発見しました。

SliceGPT で処理されたモデルは、追加のコード最適化なしで、より少ない GPU でより高速に実行できます。24 GB のコンシューマー グレードの GPU では、著者らは LLAMA-2 70B の合計推論計算を高密度モデルの 64% に削減しました。40 GB の A100 GPU では、66% に削減しました。

さらに、彼らは、SliceGPTを可能にする、Transformer ネットワークにおける計算不変性という新しい概念を提案しました。

SliceGPT の詳しい説明

SliceGPT アプローチは、Transformer アーキテクチャに固有の計算不変性に依存します。つまり、1 つのコンポーネントの出力に直交変換を適用し、次のコンポーネントでそれを元に戻すことができます。著者らは、ネットワーク ブロック間で実行される RMSNorm 操作は変換に影響を与えないことに注目しています。これらの操作は可換です。

この論文では、まず RMSNorm 接続を持つ Transformer ネットワークで不変性を実現する方法を紹介し、次に LayerNorm 接続でトレーニングされたネットワークを RMSNorm に変換する方法について説明します。次に、主成分分析 (PCA) を使用して各レイヤーの変換を計算し、パッチ間の信号を主成分に投影する方法を導入しました。最後に、マイナー主成分を削除すると、ネットワークの行または列が切り取られることを示します。

Transformer ネットワークの計算不変性

Q が直交行列を表すとします。


  • ベクトル x に Q を掛けてもベクトルのノルムは変化しないことに注意してください。これは、この作業では Q の次元が常に変換器 D の埋め込み次元と一致するためです。

X_ℓがトランスフォーマーブロックの出力であると仮定します。RMSNormによって処理された後、RMSNorm(X_ℓ)の形式で次のブロックに入力されます。 RMSNormの前に直交行列Qを持つ線形層を挿入し、RMSNormの後にQ^⊤を挿入した場合、信号行列の各行はQで乗算され、正規化され、Q^⊤で乗算されるため、ネットワークは変更されません。以下です:

ここで、ネットワーク内の各アテンションまたは FFN ブロックは入力と出力に対して線形演算を実行するため、追加演算 Q はモジュールの線形層に吸収されます。ネットワークには残差接続が含まれているため、Q は以前のすべてのレイヤー (埋め込みまで) と後続のすべてのレイヤー (LM ヘッドまで) の出力にも適用する必要があります。

不変関数とは、入力が変更されても出力が変化しない関数です。この例では、結果を変更せずに任意の直交変換 Q をトランスフォーマーの重みに適用できるため、計算は任意の変換状態で実行できます。著者らはこれを計算不変性と呼び、以下の定理で定義しています。

定理1:およびをRMSNorm接続トランスフォーマーネットワークのl番目の線形層の重み行列とし、およびを対応するバイアス(存在する場合)、W_embdおよびW_headを埋め込み行列およびヘッド行列とします。 Q を次元 D の直交行列とすると、次のネットワークは元のトランスフォーマー ネットワークと同等になります。


入力バイアスとヘッダーバイアスをコピーします。

アルゴリズム 1 により、変換されたネットワークによって計算された結果が元のネットワークと同じであることが証明されます。

LayerNorm TransformerはRMSNormに変換できます

Transformer ネットワークの計算不変性は、RMSNorm 接続ネットワークにのみ適用されます。 LayerNorm を使用してネットワークを処理する前に、著者らはまず LayerNorm の線形ブロックを隣接するブロックに吸収し、ネットワークを RMSNorm に変換します。

図 3 は、Transformer ネットワーク (図 2 を参照) のこの変換を示しています。各ブロックでは、出力行列 W_out に平均減算行列 M を乗算します。この行列は、後続の LayerNorm での平均減算を考慮します。入力行列 W_in は、前の LayerNorm ブロックのスケールで事前に乗算されます。埋め込み行列 W_embd は平均減算する必要があり、W_head は最後の LayerNorm の比率に再スケーリングする必要があります。これは単に操作の順序を変更するだけであり、ネットワーク出力には影響しません。

各ブロックの変形

トランスフォーマー内の各 LayerNorm が RMSNorm に変換されたので、任意の Q を選択してモデルを変更できます。著者らの当初の計画は、モデルから信号を収集し、これらの信号を使用して直交行列を構築し、ネットワークの一部を削除することでした。彼らはすぐに、ネットワーク内の異なるブロックからの信号が揃っていないことを発見し、各ブロックに異なる直交行列 Q_ℓ を適用する必要があることに気付きました。

各ブロックで使用される直交行列が異なっていても、モデルは変わりません。証明は、アルゴリズム 1 の 5 行目を除いて、定理 1 と同じです。ここで、残差接続とブロックの出力は同じ回転を持つ必要があることがわかります。この問題に対処するために、著者らは残差に対して線形変換を実行することによって残差接続を修正します。

図 4 は、残差接続に対して追加の線形演算を実行することによって、異なるブロックに異なる回転を適用する方法を示しています。重み行列の変更とは異なり、これらの追加操作は事前に計算できず、モデルに小さな (D × D) オーバーヘッドが追加されます。それでも、モデルを切り取るためにこれらの操作は依然として必要であり、全体的な速度が確かに速くなることがわかります。

行列Q_ℓを計算するために、著者らはPCAを使用しました。トレーニング セットからキャリブレーション データセットを選択し、それをモデルで実行し (LayerNorm 操作を RMSNorm に変換した後)、そのレイヤーの直交行列を抽出します。より正確には、変換されたネットワークの出力を使用して、次のレイヤーの直交行列を計算する場合です。具体的には、キャリブレーションデータセットのi番目のシーケンスに対するℓ番目のRMSNormモジュールの出力がである場合、次を計算します。


Q_ℓをC_ℓの固有ベクトルとし、固有値の降順で並べます。

切除

主成分分析の目的は通常、データ行列 X を取り、低次元表現 Z と近似再構成を計算することです。

ここで、Q は の固有ベクトル、D は行列の左側のいくつかの列を削除するために使用される D × D の小さな削除行列(D × D コロケーション行列の D の小さな列を含む)です。再構成は、QD が を最小化する線形マップであるという意味で L_2 最適です。

ブロック間信号行列XにPCAを適用する際、著者らはN×D信号行列を具体化することはなく、代わりにこの行列を構築する前後の操作に削除行列Dを適用しました。上記の操作では、この行列に Q が乗算されています。著者はW_inの行とW_outおよびW_embdの列を削除しました。また、残差接続に挿入された行列の行と列も削除しました (図 4 を参照)。

実験結果

タスクを生成する

著者らは、WikiText-2 データセットで SliceGPT と SparseGPT による剪定を行った後、さまざまなサイズの OPT および LLAMA-2 モデル ファミリのパフォーマンスを評価しました。表 1 は、さまざまなレベルの剪定後にモデルが保持する複雑さを示しています。 LLAMA-2 モデルと比較すると、SliceGPT は OPT モデルに適用した場合に優れたパフォーマンスを示し、これはモデル スペクトルの分析に基づく著者の推測と一致しています。

SliceGPT のパフォーマンスは、モデルのサイズが大きくなるにつれて向上します。 SparseGPT 2:4 モードは、すべての LLAMA-2 ファミリ モデルで 25% 削減すると、SliceGPT よりもパフォーマンスが低下します。 OPT の場合、2.7B モデルを除くすべてのモデルにおいて、30% 切除率のモデルのスパース性が 2:4 のモデルよりも優れていることがわかります。

ゼロショットタスク

著者らは、PIQA、WinoGrande、HellaSwag、ARC-e、ARCc の 5 つのタスクを使用して、ゼロショット タスクにおける SliceGPT のパフォーマンスを評価しました。評価では、LM Evaluation Harness をデフォルトのパラメータとして使用しました。

図 5 は、上記のタスクにおける剪定済みモデルの平均スコアを示しています。図の上段はWikiText-2におけるSliceGPTの平均精度を示しており、下段はAlpacaにおけるSliceGPTの平均精度を示しています。結果からは、生成タスクの場合と同様の結論が観察されます。つまり、OPT モデルは LLAMA-2 モデルよりも圧縮に適応しやすく、モデルが大きくなるほど、剪定後の精度の低下は目立たなくなります。

著者らは、Phi-2 などの小規模モデルで SliceGPT の有効性をテストしました。カスタマイズされた Phi-2 モデルのパフォーマンスは、カスタマイズされた LLAMA-2 7B モデルと同等でした。最大の OPT および LLAMA-2 モデルは効果的に圧縮でき、SliceGPT は 66B OPT モデルから 30% を削除するときに数パーセント ポイントの損失のみで圧縮できます。

著者らは修復微調整(RFT)実験も実施した。 LoRA を使用して、トリミングされた LLAMA-2 および Phi-2 モデルに対して少量の RFT を実行しました。

実験結果を図6に示します。 WikiText-2 データセットと Alpaca データセットの RFT の結果には大きな違いがあり、モデルは Alpaca データセットでより優れたパフォーマンスを示していることがわかります。著者らは、この違いの理由は、Alpaca データセットのタスクがベンチマーク タスクに近いためだと考えています。

最大の LLAMA-2 70B モデルの場合、30% 削減してから RFT を実行した後、Alpaca データセットの最終的な平均精度は 74.3% になり、元の高密度モデルの精度は 76.6% になりました。削減されたモデル LLAMA-2 70B は約 516 億個のパラメータを保持し、スループットが大幅に向上します。

また著者らは、Phi-2 は WikiText-2 データセットの剪定モデルから元の精度を回復できなかったが、Alpaca データセットでは数パーセントの精度を回復できたことも発見しました。 Alpaca データセットで 25% 削減され RFT が適用された Phi-2 の平均精度は 65.2% ですが、元の密なモデルの精度は 72.2% です。削減されたモデルは 2.2B のパラメータを保持し、2.8B モデルの精度の 90.3% を保持します。これは、小さな言語モデルでも効果的に剪定できることを示しています。

ベンチマークスループット

従来のプルーニング手法とは異なり、SliceGPT は行列 X に (構造化された) スパース性を導入します。つまり、X の列全体が切り捨てられ、埋め込み次元が削減されます。このアプローチは、SliceGPT 圧縮モデルの計算の複雑さ (浮動小数点演算の数) を高めるだけでなく、データ転送効率も向上させます。

80GB H100 GPU では、シーケンス長を 128 に設定し、GPU メモリが使い果たされるかスループットが低下するまで、シーケンス長をバッチで 2 倍にして最大スループットを見つけます。著者らは、25% および 50% 削減されたモデルのスループットを、H100 GPU 上の 80 GB の元の高密度モデルのスループットと比較しました。 25% 削減されたモデルでは、スループットが最大 1.55 倍向上しました。

50% のプルーニングにより、最大モデルでは 1 つの GPU を使用した場合にスループットが 3.13 倍と 1.87 倍に大幅に向上します。これは、GPU の数が固定されている場合、プルーニングされたモデルのスループットが、元の高密度モデルのそれぞれ 6.26 倍と 3.75 倍に達することを示しています。

50% のプルーニング後、WikiText2 に保存された SliceGPT の複雑さは SparseGPT 2:4 よりも劣りますが、そのスループットは SparseGPT をはるかに上回ります。サイズが 13B のモデルの場合、メモリが少ないコンシューマー GPU 上の小型モデルでもスループットが向上する可能性があります。

推論の時間

著者らは、SliceGPT を使用して圧縮されたモデルのエンドツーエンドの実行時間も研究しました。表 2 は、Quadro RTX6000 および A100 GPU 上の OPT 66B モデルと LLAMA-2 70B モデルの単一トークンを生成するのに必要な時間を比較しています。 RTX6000 GPU では、モデルを 25% 削減した後、推論速度が 16 ~ 17% 向上し、A100 GPU では速度が 11 ~ 13% 向上したことがわかります。オリジナルの高密度モデルと比較すると、LLAMA-2 70B の場合、RTX6000 GPU を使用するために必要な計算量は 64% 削減されます。著者らは、この改善は、元の重み行列をより小さなものに置き換え、密なカーネルを使用するという SliceGPT のアプローチによるものだと考えています。これは他のプルーニング方式では不可能です。

著者らは、執筆時点ではベースラインの SparseGPT 2:4 ではエンドツーエンドのパフォーマンス向上を達成できないと述べています。代わりに、トランスフォーマー層の各操作の相対時間を比較することで、SliceGPT と SparseGPT 2:4 を比較しました。大規模なモデルの場合、SliceGPT (25%) は、スピードアップと複雑度の点で SparseGPT (2:4) と競争力があることが分かりました。

コストを計算する

すべての LLAMA-2、OPT、Phi-2 モデルは、単一の GPU で 1 ~ 3 時間でシャーディングできます。表3に示すように、微調整を復元することで、すべてのLMを1〜5時間以内に圧縮できます。

詳細については、原文論文を参照してください。

<<:  画像を外国語として扱うKuaishouと北京大学のマルチモーダル大規模モデルはDALLE-3に匹敵する

>>:  マスク氏:ニューラリンクが初めて人体にチップを埋め込み、製品化へ

ブログ    
ブログ    

推薦する

「新しいインフラ」に求められるAI人材のギャップをどう埋めるか

「新インフラ」がホットワードとなり、その重要な構成要素として人工知能に大きな期待が寄せられている。 ...

エレクトロニック・アーツは、人工知能によってゲームキャラクターがよりリアルになると述べている

どのビデオゲームでも、キャラクターが予想外の行動をとって没入感を壊してしまう瞬間が必ずあります。もし...

インメモリコンピューティング技術に基づく人工知能チップが利用可能:パフォーマンスは数十から数百倍高速

[[249742]]人工知能システム用の新しいコンピュータチップが利用可能になりました。プリンストン...

集める! 2017 年の主要な AI イベントを総ざらい!(動画付き)

[[219484]] 2017 年に 1 年間眠っていたのに、突然目が覚めて、今年世界で最も誇るべ...

メタバースはヘリコプターの飛行に役立ちますか? ALIASシステムはブラックホークを30分間フル稼働させる

無人ヘリコプター自体は目新しいものではないが、現在市販されている無人ヘリコプターは、第一に誰かが遠隔...

海外メディア:ロボットは人間の生活を変え、雇用や結婚のパターンに影響を与える

[[442070]]レファレンス・ニュース・ネットワークは12月26日、ドイツのフランクフルター・ア...

世界をリセットし、すべてをつなげる5Gは人工知能にどんな機会と課題をもたらすのか

[[274397]] 5G時代は人工知能にどのような新たな機会をもたらすのでしょうか?人工知能と5G...

Sitechiのスマートオペレーションプラットフォームは、スマートシティが4.0時代に入ることを支援します

現在、中国ではデジタル経済の波が高まっています。情報技術を都市計画や建設とどのように融合させ、都市情...

AIが疫病と戦う:百度がマスク顔検出・分類モデルをオープンソース化

仕事に戻るにあたり、各地域はどのように流行を予防すべきでしょうか?人工知能技術は、新型コロナウイルス...

少なくとも 8 つのトップカンファレンス論文! NvidiaのLLM研究科学者の求人数は非常に多く、元Google Brainの科学者を驚かせるほどである。

機械学習の分野で仕事を見つけるのはどれくらい難しいですか? NVIDIA の大規模モデル研究科学者の...

AI、機械学習、RPA業界への期待

毎年、IT 業界メディアの eWEEK では、新製品、革新的なサービス、開発動向など、IT 業界の今...

顔認識防止技術の登場により、顔をスキャンするのはまだ安全でしょうか?

現在、より成熟し、広く使用されているインテリジェント テクノロジーにはどのようなものがありますか? ...

「スカイアイ」が駐車問題を解決し、人工知能が都市統治を強化

新華網、北京、3月4日、タイトル:「スカイアイ」が駐車の難しさを解決し、人工知能が都市統治を強化新華...