より多用途で効果的なAntの自社開発オプティマイザーWSAMがKDDオーラルに採用されました

より多用途で効果的なAntの自社開発オプティマイザーWSAMがKDDオーラルに採用されました

ディープ ニューラル ネットワーク (DNN) の一般化能力は、極値点の平坦性と密接に関係しています。そのため、より平坦な極値点を見つけて一般化能力を向上させる Sharpness-Aware Minimization (SAM) アルゴリズムが登場しました。本論文では、SAM の損失関数を再検討し、平坦性を正則化項として使用することでトレーニングの極値点の平坦性を改善する、より一般的で効果的な方法である WSAM を提案します。さまざまな公開データセットでの実験では、WSAM は、元のオプティマイザー、SAM およびそのバリアントと比較して、ほとんどの場合でより優れた一般化パフォーマンスを達成することが示されています。 WSAM は、Ant 内でもデジタル決済やデジタル金融などのさまざまなシナリオに広く採用されており、目覚ましい成果を上げています。この論文はKDD '23に口頭発表として採択されました。



  • 論文アドレス: https://arxiv.org/pdf/2305.15817.pdf
  • コードアドレス: https://github.com/intelligent-machine-learning/dlrover/tree/master/atorch/atorch/optimizers

ディープラーニング技術の発展により、高度に過剰パラメータ化された DNN は、CV や NLP などのさまざまな機械学習シナリオで大きな成功を収めています。過剰にパラメータ化されたモデルはトレーニング データに過剰適合する傾向がありますが、通常は一般化能力が優れています。一般化の謎はますます注目を集めており、ディープラーニングの分野で注目されている研究テーマとなっています。

最近の研究では、一般化能力は極値点の平坦性と密接に関係していることが示されています。つまり、損失関数の「ランドスケープ」内の極値点が平坦であれば、一般化誤差が小さくなります。シャープネスを考慮した最小化(SAM)[1]は、より平坦な極値点を見つける技術であり、最も有望な技術方向性の1つです。これは、CV、NLP、バイレベル学習などのさまざまな分野で広く使用されており、これらの分野における従来の最先端の方法を大幅に上回っています。

より平坦な最小値を探索するために、SAM はw における損失関数 L の平坦性を次のように定義します。

GSAM[2]は、が局所極値点のヘッセ行列の最大固有値の近似値であることを証明し、が平坦度(急勾配度)の有効な尺度であることを示している。ただし、これは最小点ではなく平坦な領域を見つけるためにのみ使用できるため、損失関数が損失値がまだ大きい点(周囲の領域が非常に平坦であるにもかかわらず)に収束する可能性があります。したがって、SAM は損失関数として採用します。これは、より平坦な表面を見つけることと、 と の間の損失値を小さくすることとの間の妥協点として考えることができます。ここで、両方に等しい重みが与えられます。

本論文では、 を正規化項として考慮し、 の構築を再考します。私たちは、WSAM (Weighted Sharpness-Aware Minimization) と呼ばれる、より一般的で効果的なアルゴリズムを開発しました。このアルゴリズムの損失関数は、重み付き平坦性項を正則化項として追加し、ハイパーパラメータが平坦性の重みを制御します。方法セクションでは、損失関数をガイドして、より平坦な、またはより小さな極値を見つける方法を示しました。私たちの主な貢献は次のようにまとめられます。


  • 我々は、平坦性を正則化項として扱い、異なるタスクに異なる重みを与える WSAM を提案します。現在のステップの平坦性を正確に反映することを目的として、更新式の正則化項を処理するための「重み分離」手法を提案します。ベース オプティマイザーが SGD ではない場合 (SGDM や Adam など)、WSAM は形式が SAM と大きく異なります。アブレーション実験では、この技術によりほとんどの場合に結果を改善できることが示されています。
  • 公開データセットの一般的なタスクにおける WSAM の有効性を検証します。実験結果によると、SAM とその変種と比較して、WSAM はほとんどの場合で一般化パフォーマンスが優れています。

前提条件

SAMは式(1)で定義されるミニマックス最適化問題を解く手法である。

まず、SAMはwの周りの1次テイラー展開を使用して内部最大化問題を近似します。つまり

次に、SAM は の近似勾配を取って w を更新します。つまり、

2 番目の近似は計算を高速化することです。その他の勾配ベースのオプティマイザー (ベースオプティマイザーと呼ばれる) は、SAM の一般的なフレームワークに組み込むことができます。詳細については、アルゴリズム 1 を参照してください。アルゴリズム 1 のとを変更することで、SGD、SGDM、Adam などの異なる基本オプティマイザーを取得できます (表 1 を参照)。ベースオプティマイザがSGDの場合、アルゴリズム1はSAM論文[1]の元のSAMにフォールバックすることに注意してください。

方法の紹介

WSAM設計の詳細

ここでは、正規損失と平坦性項で構成される正式な定義を示します式(1)から、

= 0 の場合従来の損失に退化します。 = 1/2 の場合、と同等になります。 > 1/2 の場合、平坦性に重点が置かれるため、SAM よりも損失値が小さい点を見つけるよりも、曲率が小さい点を見つける方が簡単です。逆もまた同様です。

異なるベースオプティマイザを使用した WSAM の一般的なフレームワークは、異なる合計を選択することで実装できます。アルゴリズム 2 を参照してください。たとえば、およびのとき SGD を基本オプティマイザとして使用した WSAM が得られます。アルゴリズム 3 を参照してください。ここでは、「重み分離」手法を採用しています。つまり、平坦性項は、勾配を計算し重みを更新するための基本オプティマイザーと統合されるのではなく、独立して計算されます (アルゴリズム 2 の 7 行目の最後の項)。このように、正規化の効果は、追加情報なしで現在のステップの平坦性のみを反映します。比較のために、アルゴリズム 4 では、「重み分離」なしの WSAM (Coupled-WSAM と呼ばれる) を示します。たとえば、基本オプティマイザーが SGDM の場合、Coupled-WSAM の正規化項は平坦度の指数移動平均になります。実験セクションで示したように、「重み分離」はほとんどの場合に一般化パフォーマンスを向上させることができます。

図 1 は、さまざまな値での WSAM 更新プロセスを示しています。 <1/2 のときは間であり、 が増加するにつれて徐々にずれていきます

簡単な例

WSAM における γ の効果と利点をよりわかりやすく説明するために、2 次元で簡単な例を設定しました。図2に示すように、損失関数は左下隅に比較的不均一な極値点(位置:(-16.8、12.8)、損失値:0.28)を持ち、右上隅に平坦な極値点(位置:(19.8、29.9)、損失値:0.36)を持っています。損失関数は次のように定義されます: 、ここで は単変量ガウスモデルと 2 つの正規分布、つまり、ここでおよび の間の KL ダイバージェンスです

基本オプティマイザーとしてモメンタム 0.9 の SGDM を使用し、 SAM と WSAM に対して α = 2 を設定します。初期点 (-6, 10) から始めて、学習率 5 を使用して 150 ステップ内で損失関数が最適化されます。 SAM は損失は少ないが平坦性は低い極値に収束し、 = 0.6 の WSAM でも同じことが当てはまります。ただし、 =0.95 では損失関数が平坦な極値に収束し、より強力な平坦性正規化が役割を果たしていることを示しています。

実験

WSAMの有効性を検証するために、さまざまなタスクで実験を実施しました。

画像分類

まず、Cifar10 および Cifar100 データセットで WSAM がモデルをゼロからトレーニングする場合の効果を調べます。選択したモデルには、ResNet18 と WideResNet-28-10 が含まれます。 ResNet18 と WideResNet-28-10 に対してそれぞれ 128 と 256 の定義済みバッチ サイズを使用して、Cifar10 と Cifar100 でモデルをトレーニングします。ここで使用される基本オプティマイザーは、モメンタム 0.9 の SGDM です。 SAM[1]の設定によれば、各基本オプティマイザによって実行されるエポックの数はSAMクラスオプティマイザの2倍である。両方のモデルを 400 エポック (SAM のようなオプティマイザーの場合は 200 エポック) トレーニングし、コサイン スケジューラを使用して学習率を低下させました。ここでは、カットアウトや AutoAugment などの他の高度なデータ拡張方法は使用しませんでした。

どちらのモデルでも、ジョイント グリッド サーチを使用して、基本オプティマイザーの学習率と重み減衰係数を決定し、後続の SAM のようなオプティマイザー実験ではそれらを変更しません。学習率と重み減衰係数の探索範囲はそれぞれ{0.05, 0.1}と{1e-4, 5e-4, 1e-3}です。すべての SAM タイプのオプティマイザーにはハイパーパラメータ(近傍サイズ) があるため、次に SAM オプティマイザーで最適なものを検索し、他の SAM タイプのオプティマイザーに同じ値を使用します。検索範囲は{0.01、0.02、0.05、0.1、0.2、0.5}です。最後に、元の論文で推奨されている範囲内で、他の SAM タイプのオプティマイザーに固有のハイパーパラメータを検索します。 GSAM[2]の場合、範囲{0.01、0.02、0.03、0.1、0.2、0.3}で検索します。 ESAM[3]については、{0.4, 0.5, 0.6}の範囲で検索し、{0.4, 0.5, 0.6}の範囲で検索し、{0.4, 0.5, 0.6}の範囲で検索しました。 WSAM の場合、範囲 {0.5、0.6、0.7、0.8、0.82、0.84、0.86、0.88、0.9、0.92、0.94、0.96} で検索します。異なるランダムシードを使用して実験を 5 回繰り返し、平均誤差と標準偏差を計算しました。単一の NVIDIA A100 GPU で実験を実施します。各モデルのオプティマイザーハイパーパラメータは表3にまとめられています。

表2は、異なる最適化手法によるCifar10とCifar100のテストセットにおけるResNet18とWRN-28-10のトップ1エラー率を示しています。基本オプティマイザーと比較すると、SAM オプティマイザーの効果は大幅に向上しています。同時に、WSAM は他の SAM オプティマイザーよりも大幅に優れています。

ImageNetに関する追加トレーニング

さらに、ImageNet データセットで Data-Efficient Image Transformers ネットワーク構造を使用して実験を実施しました。事前にトレーニングした DeiT ベースのチェックポイントを復元し、3 つのエポックにわたってトレーニングを継続しました。モデルはバッチ サイズ 256 でトレーニングされ、基本オプティマイザーはモメンタム 0.9 の SGDM、重み減衰係数は 1e-4、学習率は 1e-5 です。 4 つの NVIDIA A100 GPU で実行を 5 回繰り返し、平均誤差と標準偏差を計算しました。

{0.05, 0.1, 0.5, 1.0,⋯, 6.0} の中で最適なSAMを検索します。最良= 5.5 は、他の SAM クラス オプティマイザーに直接使用されます。その後、{0.01、0.02、0.03、0.1、0.2、0.3} の範囲で GSAM のベストを、また 0.80 から 0.98 の範囲で WSAM のベストを、ステップ サイズ 0.02 で検索します

モデルの初期のトップ 1 エラー率は 18.2% です。さらに 3 つのエポック後のエラー率は表 4 に示されています。 3 つの SAM のようなオプティマイザーの間には大きな違いは見られませんが、いずれもベース オプティマイザーよりもパフォーマンスが優れており、より平坦な極値を見つけることができ、より優れた一般化能力を備えていることがわかります。

ラベルノイズに対する堅牢性

以前の研究[1, 4, 5]に示されているように、SAM型オプティマイザーは、トレーニングセットにラベルノイズがある場合でも優れた堅牢性を示します。ここでは、WSAM の堅牢性を SAM、ESAM、GSAM と比較します。 Cifar10 データセットで ResNet18 を 200 エポックトレーニングし、ノイズ レベルが 20%、40%、60%、80% の対称ラベル ノイズを挿入します。ベースオプティマイザーとしてモーメンタム 0.9、バッチサイズ 128、学習率 0.05、重み減衰係数 1e-3、学習率を減衰させるコサインスケジューラを備えた SGDM を使用します。各ラベルのノイズ レベルについて、{0.01、0.02、0.05、0.1、0.2、0.5} の範囲で SAM をグリッド検索して共通値を決定します。次に、他のオプティマイザー固有のハイパーパラメータを個別に検索して、最適な一般化パフォーマンスを見つけました。結果を再現するために必要なハイパーパラメータを表5に示します。堅牢性テストの結果を表 6 に示します。WSAM は一般に SAM、ESAM、GSAM よりも堅牢です。

幾何学の影響を探る

SAMのような最適化装置はASAM [4]やFisher SAM [5]などの技術と組み合わせて、探索近傍の形状を適応的に調整することができます。 Cifar10 上の WRN-28-10 で実験を行い、適応型法とフィッシャー情報法をそれぞれ使用した場合の SAM と WSAM のパフォーマンスを比較して、探索領域の形状が SAM のような最適化ツールの一般化パフォーマンスにどのように影響するかを理解します。

パラメータとを除いて、画像分類では設定を再利用します。これまでの研究[4, 5]によれば、ASAMとFisher SAMは通常より大きい。 {0.1、0.5、1.0、…、6.0} の中で最適なものを検索し、ASAM と Fisher SAM の両方で最適なのは5.0 です。その後、ステップサイズ 0.02 で 0.80 から 0.94 の間で WSAM のベストを検索し、両方の方法のベストは0.88 でした。

驚くべきことに、表 7 に示すように、ベースライン WSAM は複数の候補の間でも優れた一般化を示しています。したがって、固定ベースラインで WSAM を直接使用することをお勧めします。

アブレーション実験

このセクションでは、WSAM における「重量分離」技術の重要性をより深く理解するために、アブレーション実験を行います。 WSAM の設計詳細で説明したように、「重み分離」のない WSAM バリアント (アルゴリズム 4) の結合 WSAM を元の方法と比較します。

結果は表8に示されています。ほとんどの場合、Coupled-WSAM は SAM よりも優れた結果を生み出し、WSAM はほとんどの場合にさらに結果を改善し、「重み分離」技術の有効性を証明します。

極値点分析

ここでは、WSAM オプティマイザーと SAM オプティマイザーによって検出された極端なポイントの違いを比較することで、WSAM オプティマイザーの理解をさらに深めます。極点における平坦さ(急勾配)は、ヘッセ行列の最大固有値によって記述できます。固有値が大きいほど、平坦性は低下します。この最大固有値を計算するには、Power Iteration アルゴリズムを使用します。

表 9 は、SAM オプティマイザーと WSAM オプティマイザーによって検出された極端なポイントの違いを示しています。バニラ オプティマイザーは損失値が小さいが平坦性が低い極端な点を見つけるのに対し、SAM は損失値が大きいがより平坦な極端な点を見つけるため、一般化パフォーマンスが向上することがわかります。興味深いことに、WSAM によって発見された極値ポイントは、SAM よりも損失値がはるかに小さいだけでなく、SAM に非常に近い平坦性も持っています。これは、極値点を見つけるプロセスにおいて、WSAM がより平坦な領域を検索しようとしながら、より小さな損失値を確保することを優先していることを示しています。

ハイパーパラメータ感度

SAM と比較して、WSAM には平坦性 (急峻さ) 項のサイズをスケーリングするための追加のハイパーパラメータがあります。ここでは、このハイパーパラメータに対する WSAM の一般化パフォーマンスの感度をテストします。 Cifar10 および Cifar100 で WSAM を使用して、幅広い値を使用して ResNet18 および WRN-28-10 モデルをトレーニングしました。図 3 に示すように、結果は WSAM がハイパーパラメータの選択に影響を受けないことを示しています。また、WSAM の最適な一般化パフォーマンスは、ほぼ常に 0.8 ~ 0.95 の間であることもわかりました。

<<:  何百万人ものネットユーザーがDALL-E 3の新しいゲームプレイを視聴しました!アイアンマンとテスラはどれも「ヒット」、強迫性障害に優しい、ブロガーがヒントを共有

>>: 

ブログ    
ブログ    

推薦する

ハーバード大学とMITが協力し、新型コロナウイルスに遭遇すると自動的に光るスマートマスクを開発

[[326611]] 「新型コロナウイルスにさらされると、マスクが自動的に点灯し、検査員に警告を発し...

GenAI 時代のデータ ガバナンスの青写真

ML と GenAI の世界に深く入り込むにつれて、データ品質への重点が重要になります。 KMS T...

...

こんにちは、音声認識について学びましょう!

[51CTO.com からのオリジナル記事] 音声認識は自動音声認識とも呼ばれ、人間の音声に含まれ...

ネットユーザーの83%を騙した!画像生成の頂点、DALL-E 2 は実際にチューリングテストに合格したのか?

数日前、休暇中だったネットユーザーが「DALL-E 2」にアクセスできたことを知った。 2秒間考えた...

データ管理はAI革命の最大の課題となるでしょうか?

最新のデータへの投資は人工知能の拡張を成功させる上で重要ですが、調査によると、企業の半数がコストの障...

この記事では機械学習における3つの特徴選択手法を紹介します。

機械学習では特徴を選択する必要があり、人生でも同じではないでしょうか?特徴選択とは、利用可能な多数の...

...

AIが気候変動に効果的に対抗する方法

人工知能(AI)の活用は気候変動との闘いに貢献することができます。既存の AI システムには、天気を...

AI不正対策!ディープフェイク音声・動画検出技術がCESでデビュー、精度は90%以上

真実とは程遠いが、アメリカの消費者向けニュースおよびビジネスチャンネルCNBCのロゴ入りのビデオでは...

アルゴリズム取引におけるビッグデータ分析の活用

ウォーレン・バフェットの資産が 5000G あることをご存知ですか? 反対派や懐疑派の意見に反して、...

80億人民元を超える資金で医療AIは「V字カーブ」を描いている

[[373863]] 「人工知能は将来の生産性の中核である」という見解に疑問を抱く人はほとんどいませ...

ヘルスケアにおける AI: 注目すべき 3 つのトレンド

COVID-19 パンデミック、メンタルヘルス危機、医療費の高騰、人口の高齢化により、業界のリーダ...

ロボットが家庭に入り、人工知能の夢はもはや高価ではない

[[221538]]人工知能とは何ですか? 「第一次産業革命における蒸気機関、第二次産業革命における...