ディープラーニングにおけるバッチ正規化の落とし穴

ディープラーニングにおけるバッチ正規化の落とし穴

[[191744]]

バッチ正規化は、ディープラーニングにおいて最近登場した効果的な手法です。その有効性は広く実証されており、研究やアプリケーションに急速に応用されています。この投稿は、読者がバッチ正規化とは何かを知っており、その仕組みについてある程度理解していることを前提としています。この概念を初めて知ったり、復習する必要がある場合は、次のリンク (http://blog.csdn.net/malefactor/article/details/51476961) でバッチ正規化の簡単な概要を参照してください。

この論文では、2 つの異なる方法を使用してニューラル ネットワークを実装します。各ステップで同じデータが入力されます。ネットワークには、まったく同じ損失関数、まったく同じハイパーパラメータ、まったく同じオプティマイザーがあります。その後、まったく同じ数の GPU でトレーニングが実行されます。結果として、一方のバージョンの分類精度は他方のバージョンよりも 2% 低く、このパフォーマンスの低下は非常に安定しているように見えます。

単純な MNIST と SVHN の分類問題を例に挙げてみましょう。

最初の実装では、MNIST データのバッチと SVHN データのバッチが抽出され、結合されてからネットワークに送られます。

2 番目の実装では、ネットワークのコピーが 2 つ作成され、重みが共有されます。 1 つのコピーには MNIST データが入力され、もう 1 つのコピーには SVHN データが入力されます。

どちらの実装でも、データの半分は MNIST で、残りの半分は SVHN であることに注意してください。さらに、2 番目の実装では重みを共有するため、2 つのモデルのパラメーターの数は同じになり、同じ方法で更新されます。

単純に考えると、これら 2 つのモデルのトレーニング中の勾配は同じになるはずです。これも事実です。しかし、バッチ正規化を追加すると状況は変わります。最初の実装では、同じデータ バッチに MNIST データと SVHN データの両方が含まれています。 2 番目の方法では、モデルは 2 つのバッチでトレーニングされます。1 つのバッチは MNIST データのみでトレーニングされ、もう 1 つのバッチは SVHN データのみでトレーニングされます。

この問題の原因は、トレーニング中に 2 つのネットワークがパラメータを共有する一方で、データ セットの平均と分散の移動平均も共有されるためです。このパラメータの更新は、両方のデータセットにも適用されます。 2 番目のアプローチでは、上部のネットワークは MNIST データからの平均と分散の推定値を使用してトレーニングされ、下部のネットワークは SVHN データからの平均と分散の推定値を使用してトレーニングされます。しかし、移動平均は 2 つのネットワーク間で共有されるため、移動平均は MNIST データと SVHN データの平均に収束します。

したがって、テスト時に、テスト セットで使用されるバッチ正規化のスケールと変換 (1 つのデータセットの平均) は、モデルが期待するもの (両方のデータセットの平均) とは異なります。テスト用の正規化がトレーニング用の正規化と異なる場合、モデルは次の結果を取得します。

このグラフは、5 つのランダム シードを使用した別の類似データセット (この例では MNIST または SVHN ではありません) での最高、中央値、最低のモデル パフォーマンスを示しています。重みを共有する 2 つのネットワークを使用すると、パフォーマンスが大幅に低下するだけでなく、出力の分散も増加します。

この問題は、単一のデータ ミニバッチがデータ分布全体を代表していない場合に発生します。つまり、入力をシャッフルすることを忘れずにバッチ正規化を使用するのは危険です。これは、最近人気の高い敵対的生成ネットワーク (GAN) でも非常に重要です。 GAN の識別器は通常、偽のデータと実際のデータの混合でトレーニングされます。識別器でバッチ正規化が使用されている場合、純粋に偽のデータのバッチと純粋に実際のデータのバッチを交互に使用するのは誤りです。各小バッチには、両方が均等に混合されている必要があります (それぞれ 50%)。

実際には、バッチ正規化変数を分離し、他の変数を共有するネットワーク構造を使用すると、最良の結果が得られることに注意してください。これは実装が複雑ですが、他の方法よりも確かに効果的です (下の図を参照)。

バッチ正規化:諸悪の根源

上記の問題を考慮して、著者は、可能であればバッチ正規化を使用しないという結論に達しました。

この結論はエンジニアリングの観点から分析されます。

一般的に、コードに問題がある場合、その理由は次の 2 つに限ります。

  1. 明らかに間違いです。たとえば、間違った変数名を入力したり、関数を呼び出すのを忘れたりした可能性があります。
  2. コードには、相互作用する他のコードの動作に対する暗黙の依存関係があり、それらの依存関係の一部が満たされていません。これらのエラーは、コードがどの条件に依存しているかを把握するのに通常長い時間がかかるため、より有害になることがよくあります。

これら両方の間違いは避けられません。 2 番目のタイプのエラーは、より単純な方法を使用し、既存のコードを再利用することで軽減できます。

バッチ正規化方法には、次の 2 つの基本的なプロパティがあります。

  1. トレーニング中、単一の入力 xi の出力はミニバッチ内の他の xj の影響を受けます。
  2. テスト時に、モデルの計算パスが変更されます。正規化にはミニバッチ平均ではなく移動平均が使用されるようになったためです。

これらの特性を持つ最適化方法は他にほとんどありません。これにより、バッチ正規化コードを実装する人は、入力ミニバッチが無相関であるか、トレーニング操作とテスト操作が同じであると想定しやすくなります。このアプローチに疑問を抱く人は誰もいないだろう。

もちろん、バッチ正規化は Java 正規化のブラック ボックス バージョンと考えることができますが、これは非常にうまく機能します。しかし、実際には抽象化には常に漏れがあり、バッチ正規化も例外ではなく、その特性により漏れがさらに生じやすくなります。

なぜ人々はバッチ正規化をあきらめないのでしょうか?

コンピュータ サイエンス コミュニティには、ダイクストラの「GoTo ステートメントは有害である」という有名な記事があります。この中で、ダイクストラは、goto 文はコードを読みにくくするので避けるべきであり、goto を使用するプログラムは goto 文なしで書き直すことができると主張しています。

著者は「バッチ正規化は有害である」という見解を述べたいと思っていますが、十分な理由が見つかりません。結局のところ、バッチ正規化は非常に便利です。

はい、バッチ正規化には問題があります。しかし、すべてを正しく行えば、モデルのトレーニングははるかに速くなります。バッチ正規化の論文が 1400 回以上引用されているのには、十分な理由があります。

バッチ正規化には多くの代替手段がありますが、それらにも独自の欠点があります。レイヤー正規化は、RNN で使用するとより効果的ですが、畳み込みレイヤーで使用すると問題が発生することがあります。重み正規化とコサイン正規化はどちらも比較的新しい正規化方法です。重み正規化の記事では、バッチ正規化が機能しないいくつかの問題に重み正規化を適用できると述べられています。しかし、これらの方法は今のところあまり使われておらず、おそらく時間の問題でしょう。レイヤー正規化、重み正規化、コサイン正規化はすべて、上記のバッチ正規化の問題に対処します。新しい問題に取り組んでいてリスクを負いたい場合には、これらの正規化方法を試してみることをお勧めします。結局、どの方法を使用する場合でも、ハイパーパラメータの調整が必要になります。一度調整すると、さまざまな方法間の違いは小さくなるはずです。

(勇気があれば、バッチ再正規化を試すこともできますが、テスト時には移動平均のみが使用されます。)

バッチ正規化の使用は、ディープラーニングにおける「悪魔の契約」と見なすことができます。得られるものは効率的なトレーニングですが、失うものは異常な結果(狂気)の可能性です。全員がこの契約書に署名します。

翻訳者メモ

「バッチ正規化は有害である」および「バッチ正規化の使用はできるだけ避ける」という著者の見解は、やや極端です。しかし、この記事で言及されているバッチ正規化の罠には注意する必要があります。バッチ正規化の有効性のため、多くのディープラーニング研究者はそれを「魔法のブラックボックス」として扱い、あらゆる可能な場所に適用しています。この単純で大雑把な方法は、トレーニング速度の向上に非常に効果的だからです。しかし、精度の低下をバッチ正規化に帰することは困難です。結局のところ、バッチ正規化によってトレーニングの精度が低下するとは誰も言及していません。

しかし、トレーニング中とテスト中にデータ セットが矛盾する状況は、実際には非常に一般的です。この問題は、翻訳者がトレーニング データ セットを人工的にシミュレートするときに発生します。バッチ正規化を使用する前に、次の問題を慎重に検討することをお勧めします。

  1. トレーニング データセットの各バッチのサンプルは平均化されていますか?
  2. トレーニング データセットのバッチ平均は、テスト中の移動平均と一致していますか?

それ以外の場合は、この記事で説明されている問題を回避するために、次の方法の 1 つ以上を使用する必要があります。

  1. バッチ平均化を確実にするためにトレーニング データセットをランダムにサンプリングします。
  2. 上記の問題を回避するには、記事の例のようにモデルを変更します。
  3. バッチ正規化の代わりにレイヤー正規化、重み正規化、またはコサイン正規化を使用します。
  4. 正規化方法は使用されません。

<<:  CNN の弱点を見つけ、MNIST の「ルーチン」に注意する

>>:  時空間アルゴリズム研究に基づくビジネス意思決定分析

ブログ    
ブログ    
ブログ    
ブログ    

推薦する

...

DeeCamp 2019は産学連携を促進するためにKuaishouとInnovation Worksを正式に立ち上げました

4月8日、イノベーションワークスが主催する「DeeCamp2019 人工知能サマートレーニングキャン...

AI スタートアップはどうすれば成功できるのでしょうか?ガートナー:「以下の点が不可欠」

[[430175]]デジタル変革の波を受けて、さまざまな新興技術が急速に応用され、普及してきました...

産業用AIoTが「新たな人気」となった4つの主な要因

最近発表された産業用人工知能および人工知能市場レポート 2021-2026 のデータによると、わずか...

...

トマシュ・トゥングズ: AI 組織が直面する 4 つの戦略的課題

編集者注: Tomasz Tunguz 氏は RedPoint のパートナーであり、スタートアップが...

...

人工知能システム:無制限の核融合反応を現実のものに

近年、研究者らはトカマクの停止や損傷の原因となる核分裂反応を研究している。核分裂反応を予測・制御でき...

AI CPUとMicrosoft Windows 12のリリースにより、2024年には世界のAI PC出荷台数が1,300万台を超えると予想

10月13日、Qunzhi Consultingが昨日発表した最新の調査によると、アルゴリズムとハー...

99行のコードでアナと雪の女王の特殊効果の太極拳の進化を実現

コンピュータシミュレーション技術の継続的な発展のおかげで、ますますリアルな現実世界をコンピュータで再...

2021 年にグラフ機械学習にはどのような新たなブレークスルーがあるでしょうか?マギル大学のポスドク研究員が分野の動向を整理

[[443041]]今年ももうすぐ終わり、あと3日で2021年も終わりです。さまざまなAI分野でも...

AI as a Service: AIとクラウドコンピューティングが出会うとき

競争で優位に立つために、ますます多くの企業が自社のアプリケーション、製品、サービス、ビッグデータ分析...

Google Brain Quoc、操作プリミティブから効率的なTransformerバリアントを検索するためのPrimerをリリース

[[426884]]モデルのパフォーマンスを向上させるには、パラメータを調整し、活性化関数を変更する...

機械学習の戦略原則: 基本プロセス、アルゴリズムフレームワーク、プロジェクト管理

著者: cooperyjli、Tencent CDG のデータ アナリスト機械学習は、データの収集、...