ディープラーニングにおけるバッチ正規化の落とし穴

ディープラーニングにおけるバッチ正規化の落とし穴

[[191744]]

バッチ正規化は、ディープラーニングにおいて最近登場した効果的な手法です。その有効性は広く実証されており、研究やアプリケーションに急速に応用されています。この投稿は、読者がバッチ正規化とは何かを知っており、その仕組みについてある程度理解していることを前提としています。この概念を初めて知ったり、復習する必要がある場合は、次のリンク (http://blog.csdn.net/malefactor/article/details/51476961) でバッチ正規化の簡単な概要を参照してください。

この論文では、2 つの異なる方法を使用してニューラル ネットワークを実装します。各ステップで同じデータが入力されます。ネットワークには、まったく同じ損失関数、まったく同じハイパーパラメータ、まったく同じオプティマイザーがあります。その後、まったく同じ数の GPU でトレーニングが実行されます。結果として、一方のバージョンの分類精度は他方のバージョンよりも 2% 低く、このパフォーマンスの低下は非常に安定しているように見えます。

単純な MNIST と SVHN の分類問題を例に挙げてみましょう。

最初の実装では、MNIST データのバッチと SVHN データのバッチが抽出され、結合されてからネットワークに送られます。

2 番目の実装では、ネットワークのコピーが 2 つ作成され、重みが共有されます。 1 つのコピーには MNIST データが入力され、もう 1 つのコピーには SVHN データが入力されます。

どちらの実装でも、データの半分は MNIST で、残りの半分は SVHN であることに注意してください。さらに、2 番目の実装では重みを共有するため、2 つのモデルのパラメーターの数は同じになり、同じ方法で更新されます。

単純に考えると、これら 2 つのモデルのトレーニング中の勾配は同じになるはずです。これも事実です。しかし、バッチ正規化を追加すると状況は変わります。最初の実装では、同じデータ バッチに MNIST データと SVHN データの両方が含まれています。 2 番目の方法では、モデルは 2 つのバッチでトレーニングされます。1 つのバッチは MNIST データのみでトレーニングされ、もう 1 つのバッチは SVHN データのみでトレーニングされます。

この問題の原因は、トレーニング中に 2 つのネットワークがパラメータを共有する一方で、データ セットの平均と分散の移動平均も共有されるためです。このパラメータの更新は、両方のデータセットにも適用されます。 2 番目のアプローチでは、上部のネットワークは MNIST データからの平均と分散の推定値を使用してトレーニングされ、下部のネットワークは SVHN データからの平均と分散の推定値を使用してトレーニングされます。しかし、移動平均は 2 つのネットワーク間で共有されるため、移動平均は MNIST データと SVHN データの平均に収束します。

したがって、テスト時に、テスト セットで使用されるバッチ正規化のスケールと変換 (1 つのデータセットの平均) は、モデルが期待するもの (両方のデータセットの平均) とは異なります。テスト用の正規化がトレーニング用の正規化と異なる場合、モデルは次の結果を取得します。

このグラフは、5 つのランダム シードを使用した別の類似データセット (この例では MNIST または SVHN ではありません) での最高、中央値、最低のモデル パフォーマンスを示しています。重みを共有する 2 つのネットワークを使用すると、パフォーマンスが大幅に低下するだけでなく、出力の分散も増加します。

この問題は、単一のデータ ミニバッチがデータ分布全体を代表していない場合に発生します。つまり、入力をシャッフルすることを忘れずにバッチ正規化を使用するのは危険です。これは、最近人気の高い敵対的生成ネットワーク (GAN) でも非常に重要です。 GAN の識別器は通常、偽のデータと実際のデータの混合でトレーニングされます。識別器でバッチ正規化が使用されている場合、純粋に偽のデータのバッチと純粋に実際のデータのバッチを交互に使用するのは誤りです。各小バッチには、両方が均等に混合されている必要があります (それぞれ 50%)。

実際には、バッチ正規化変数を分離し、他の変数を共有するネットワーク構造を使用すると、最良の結果が得られることに注意してください。これは実装が複雑ですが、他の方法よりも確かに効果的です (下の図を参照)。

バッチ正規化:諸悪の根源

上記の問題を考慮して、著者は、可能であればバッチ正規化を使用しないという結論に達しました。

この結論はエンジニアリングの観点から分析されます。

一般的に、コードに問題がある場合、その理由は次の 2 つに限ります。

  1. 明らかに間違いです。たとえば、間違った変数名を入力したり、関数を呼び出すのを忘れたりした可能性があります。
  2. コードには、相互作用する他のコードの動作に対する暗黙の依存関係があり、それらの依存関係の一部が満たされていません。これらのエラーは、コードがどの条件に依存しているかを把握するのに通常長い時間がかかるため、より有害になることがよくあります。

これら両方の間違いは避けられません。 2 番目のタイプのエラーは、より単純な方法を使用し、既存のコードを再利用することで軽減できます。

バッチ正規化方法には、次の 2 つの基本的なプロパティがあります。

  1. トレーニング中、単一の入力 xi の出力はミニバッチ内の他の xj の影響を受けます。
  2. テスト時に、モデルの計算パスが変更されます。正規化にはミニバッチ平均ではなく移動平均が使用されるようになったためです。

これらの特性を持つ最適化方法は他にほとんどありません。これにより、バッチ正規化コードを実装する人は、入力ミニバッチが無相関であるか、トレーニング操作とテスト操作が同じであると想定しやすくなります。このアプローチに疑問を抱く人は誰もいないだろう。

もちろん、バッチ正規化は Java 正規化のブラック ボックス バージョンと考えることができますが、これは非常にうまく機能します。しかし、実際には抽象化には常に漏れがあり、バッチ正規化も例外ではなく、その特性により漏れがさらに生じやすくなります。

なぜ人々はバッチ正規化をあきらめないのでしょうか?

コンピュータ サイエンス コミュニティには、ダイクストラの「GoTo ステートメントは有害である」という有名な記事があります。この中で、ダイクストラは、goto 文はコードを読みにくくするので避けるべきであり、goto を使用するプログラムは goto 文なしで書き直すことができると主張しています。

著者は「バッチ正規化は有害である」という見解を述べたいと思っていますが、十分な理由が見つかりません。結局のところ、バッチ正規化は非常に便利です。

はい、バッチ正規化には問題があります。しかし、すべてを正しく行えば、モデルのトレーニングははるかに速くなります。バッチ正規化の論文が 1400 回以上引用されているのには、十分な理由があります。

バッチ正規化には多くの代替手段がありますが、それらにも独自の欠点があります。レイヤー正規化は、RNN で使用するとより効果的ですが、畳み込みレイヤーで使用すると問題が発生することがあります。重み正規化とコサイン正規化はどちらも比較的新しい正規化方法です。重み正規化の記事では、バッチ正規化が機能しないいくつかの問題に重み正規化を適用できると述べられています。しかし、これらの方法は今のところあまり使われておらず、おそらく時間の問題でしょう。レイヤー正規化、重み正規化、コサイン正規化はすべて、上記のバッチ正規化の問題に対処します。新しい問題に取り組んでいてリスクを負いたい場合には、これらの正規化方法を試してみることをお勧めします。結局、どの方法を使用する場合でも、ハイパーパラメータの調整が必要になります。一度調整すると、さまざまな方法間の違いは小さくなるはずです。

(勇気があれば、バッチ再正規化を試すこともできますが、テスト時には移動平均のみが使用されます。)

バッチ正規化の使用は、ディープラーニングにおける「悪魔の契約」と見なすことができます。得られるものは効率的なトレーニングですが、失うものは異常な結果(狂気)の可能性です。全員がこの契約書に署名します。

翻訳者メモ

「バッチ正規化は有害である」および「バッチ正規化の使用はできるだけ避ける」という著者の見解は、やや極端です。しかし、この記事で言及されているバッチ正規化の罠には注意する必要があります。バッチ正規化の有効性のため、多くのディープラーニング研究者はそれを「魔法のブラックボックス」として扱い、あらゆる可能な場所に適用しています。この単純で大雑把な方法は、トレーニング速度の向上に非常に効果的だからです。しかし、精度の低下をバッチ正規化に帰することは困難です。結局のところ、バッチ正規化によってトレーニングの精度が低下するとは誰も言及していません。

しかし、トレーニング中とテスト中にデータ セットが矛盾する状況は、実際には非常に一般的です。この問題は、翻訳者がトレーニング データ セットを人工的にシミュレートするときに発生します。バッチ正規化を使用する前に、次の問題を慎重に検討することをお勧めします。

  1. トレーニング データセットの各バッチのサンプルは平均化されていますか?
  2. トレーニング データセットのバッチ平均は、テスト中の移動平均と一致していますか?

それ以外の場合は、この記事で説明されている問題を回避するために、次の方法の 1 つ以上を使用する必要があります。

  1. バッチ平均化を確実にするためにトレーニング データセットをランダムにサンプリングします。
  2. 上記の問題を回避するには、記事の例のようにモデルを変更します。
  3. バッチ正規化の代わりにレイヤー正規化、重み正規化、またはコサイン正規化を使用します。
  4. 正規化方法は使用されません。

<<:  CNN の弱点を見つけ、MNIST の「ルーチン」に注意する

>>:  時空間アルゴリズム研究に基づくビジネス意思決定分析

ブログ    
ブログ    
ブログ    
ブログ    
ブログ    
ブログ    

推薦する

中国初の真のAI入力方式が発表され、未来の入力方式を革新する

入力がキーボードに別れを告げ、音声、表現、動作が入力方法になると、どのような魔法のような体験になるの...

...

機械学習業界の発展はなぜ「オープンソース」から切り離せないのか

[[187490]] 2016 年末、Google DeepMind は機械学習プラットフォームであ...

AIを活用して混雑した都市での駐車のストレスを軽減

混雑した市街地でドライバーが駐車スペースを見つけるのを助ける人工知能がバース大学で開発されている。こ...

責任あるAIの構築

現在、AI によって完全に有効化されたプロセスを備えている企業はわずか 25% であり、これらの企業...

Kuaishou AIテクノロジーがゲームチェーン全体に力を与える

導入ゲーム業界は近年急速に発展しており、2020年第1四半期だけでも中国のゲーム市場の売上高は700...

高度な脅威検出における人工知能技術の応用

高度な持続的脅威は、その多様な形態、持続性、対立、隠蔽を特徴とし、現在、大手企業が脅威監視において直...

AIモデルのオープンソースの定義を変える必要がある

オープンソースライセンスは進化すべきだと思いますか? 2023年は人工知能(AI)の登場とともに新年...

WAVE SUMMIT 2023は8月16日に開催予定です!パドルパドルとウェンシンの大型モデルが最新の技術成果を展示します

今年は国内のテクノロジーメーカーが各分野で続々と大型モデルを発売し、「モデル戦争」が本格化しているが...

...

AIがデータ統合の状況をどう変えるのか

生成 AI は統合の状況を変えています。 チームの経済性、速度、プロジェクト構造、配信モデルについて...

...

2021年の3つの主要なAIトレンド:IoT、データ駆動型の意思決定、サイバーセキュリティ

この記事は、公開アカウント「Reading the Core」(ID: AI_Discovery)か...

テクノロジーのホットスポット: 言語的機械学習

[[186484]]昨年から半年以上機械学習を勉強してきましたが、そろそろ総括したいと思います。これ...

調査結果: 回答者の 64% が生成 AI による作業の功績を認めている

Salesforce が実施した調査では、生成 AI の使用に関する明確なポリシーが存在しない状況で...