1.3msかかります!清華大学の最新のオープンソースモバイルニューラルネットワークアーキテクチャRepViT

1.3msかかります!清華大学の最新のオープンソースモバイルニューラルネットワークアーキテクチャRepViT

論文アドレス: https://arxiv.org/abs/2307.09283

コードアドレス: https://github.com/THU-MIG/RepViT

RepViT は確かに他の主流のモバイル ViT アーキテクチャよりも優れていることがわかります。次に、この作品の貢献を見てみましょう。

  1. この論文では、軽量 ViT は一般に、視覚タスクにおいて軽量 CNN よりも優れたパフォーマンスを発揮すると述べられています。これは主に、モデルがグローバル表現を学習できるようにするマルチヘッド自己注意モジュール ( MSHA ) によるものです。ただし、軽量 ViT と軽量 CNN のアーキテクチャ上の違いは十分に研究されていません。
  2. この研究では、著者らは、軽量 ViT の効果的なアーキテクチャ選択を統合することで、標準的な軽量 CNN (特にMobileNetV3 ) のモバイルフレンドリー性を徐々に改善しました。これにより、純粋な軽量 CNN の新しいファミリー、つまりRepViTが誕生しました。RepViT は MetaFormer 構造を持ちますが、完全に畳み込みで構成されていることは注目に値します。
  3. 実験結果では、 RepViT既存の最先端の軽量 ViT を上回り、ImageNet 分類、COCO-2017 でのオブジェクト検出とインスタンス セグメンテーション、ADE20k でのセマンティック セグメンテーションなど、さまざまな視覚タスクにおいて、既存の最先端の軽量 ViT よりも優れたパフォーマンスと効率を示すことが実証されています。特に、 ImageNetでは、 RepViT iPhone 12で1ms近くのレイテンシと80%以上のTop-1精度を達成しており、これは軽量モデルとしては初の躍進となる。

さて、次に誰もが気にするべきことは、「レイテンシが低くても精度の高いモデルをどのように設計するか」です。

方法

ConvNeXtでは、著者らはResNet50アーキテクチャに基づいて厳密な理論的および実験的分析を実施し、最終的にSwin-Transformerに匹敵する非常に優れた純粋な畳み込みニューラル ネットワーク アーキテクチャを設計しました。同様に、 RepViTも主に軽量 ViTs アーキテクチャ設計を標準の軽量 CNN、つまりMobileNetV3-Lに変換し、徐々にアーキテクチャ設計に統合します。このプロセスでは、著者はさまざまな粒度レベルの設計要素を考慮し、一連の手順を通じて最適化の目標を達成しました。

トレーニングレシピの調整

まず、この論文では、モバイル デバイスでのレイテンシを測定するためのメトリックを紹介し、トレーニング戦略を既存の軽量 ViT と一致させます。このステップは主にモデル トレーニングの一貫性を確保するためのもので、レイテンシ測定とトレーニング戦略の調整という 2 つの概念が含まれます。

レイテンシーメトリクス

実際のモバイル デバイス上でのモデルのパフォーマンスをより正確に測定するために、著者は、デバイス上のモデルの実際のレイテンシをベースライン メトリックとして直接測定することを選択しました。このメトリックは、主にFLOPsやモデル サイズなどの指標を通じてモデルの推論速度を最適化するこれまでの研究とは異なりますが、これらの指標はモバイル アプリケーションの実際のレイテンシを必ずしも適切に反映するわけではありません。

トレーニング戦略の調整

ここで、MobileNetV3-L のトレーニング戦略は、他の軽量 ViTs モデルに合わせて調整されます。これには、 AdamWオプティマイザー [ViTs モデルに必要なオプティマイザー] を使用した 5 エポックのウォームアップ トレーニングと、コサイン アニーリング学習率スケジュールを使用した 300 エポックのトレーニングが含まれます。この調整によりモデルの精度はわずかに低下しますが、公平性は保証されます。

ブロック設計の最適化

次に、一貫したトレーニング設定に基づいて、著者らは最適なブロック設計を検討しました。ブロック設計は CNN アーキテクチャの重要な部分であり、ブロック設計を最適化するとネットワークのパフォーマンスが向上します。

トークンミキサーとチャンネルミキサーを分離

この部分は主にMobileNetV3-Lのブロック構造を改良し、トークンミキサーとチャネルミキサーを分離します。オリジナルの MobileNetV3 ブロック構造には、1x1 の拡張畳み込み、それに続く深さ方向の畳み込み、1x1 の投影層が含まれ、その後、入力と出力が残差接続を介して接続されます。これを基に、RepViT は深さ方向の畳み込みを進め、チャネル ミキサーとトークン ミキサーを分離できるようにします。パフォーマンスを向上させるために、トレーニング中にディープ フィルターにマルチ ブランチ トポロジーを導入する構造の再パラメータ化も導入されています。最終的に、著者らは MobileNetV3 ブロック内のトークン ミキサーとチャネル ミキサーを分離することに成功し、このブロックを RepViT ブロックと名付けました。

膨張率を下げて幅を広げる

チャネルミキサーでは、元の拡張比は 4 です。つまり、MLP ブロックの隠し次元は入力次元の 4 倍になり、多くの計算リソースを消費し、推論時間に大きな影響を与えます。この問題を緩和するために、拡張比率を 2 に減らすことで、パラメータの冗長性とレイテンシを削減し、MobileNetV3-L のレイテンシを 0.65 ミリ秒に削減できます。その後、ネットワークの幅を広げる、つまり各ステージのチャネル数を増やすことで、Top-1 の精度は 73.5% に向上しましたが、レイテンシはわずか 0.89 ミリ秒にしか増加しませんでした。

マクロアーキテクチャ要素の最適化

このステップでは、主にステム、ダウンサンプリング レイヤー、分類子、全体のステージ比などのマクロ アーキテクチャ要素から始めて、モバイル デバイス上の MobileNetV3-L のパフォーマンスをさらに最適化します。これらのマクロアーキテクチャ要素を最適化することで、モデルのパフォーマンスを大幅に向上させることができます。

浅いネットワークは畳み込み抽出器を使用する

写真

ViT は通常、入力画像をステムとして重複しないパッチに分割する「パッチ化」操作を使用します。ただし、このアプローチには、トレーニングの最適化とトレーニング レシピに対する感度に関する問題があります。そのため、著者らは代わりに、多くの軽量 ViT で採用されている早期畳み込みを採用しました。対照的に、MobileNetV3-L は 4 倍ダウンサンプリングにさらに複雑なステムを使用します。その結果、初期のフィルター数は 24 に増加しましたが、合計レイテンシは 0.86 ミリ秒に短縮され、トップ 1 精度は 73.9% に向上しました。

より深いダウンサンプリング層

ViT では、空間ダウンサンプリングは通常、別のパッチマージレイヤーを介して実装されます。したがって、ここでは別のより深いダウンサンプリング レイヤーを使用して、ネットワークの深さを増やし、解像度の低下による情報損失を減らすことができます。具体的には、著者らはまず 1x1 畳み込みを使用してチャネル次元を調整し、次に 2 つの 1x1 畳み込みの入力と出力を残差接続を介して接続してフィードフォワード ネットワークを形成します。さらに、ダウンサンプリング層をさらに深くするために RepViT ブロックを前面に追加し、レイテンシ 0.96 ミリ秒でトップ 1 精度を 75.4% に向上しました。

よりシンプルな分類器

軽量 ViT では、分類器は通常、グローバル平均プーリング層とそれに続く線形層で構成されます。対照的に、MobileNetV3-L はより複雑な分類器を使用します。最終ステージにはより多くのチャネルが含まれるようになったため、著者らはこれをグローバル平均プーリング層と線形層で構成される単純な分類器に置き換え、レイテンシを 0.77 ミリ秒に短縮し、トップ 1 の精度を 74.8% にしました。

全体のステージ比率

ステージ比は、異なるステージのブロック数の比率を表し、各ステージでの計算の分散を示します。この論文では、より最適なステージ比 1:1:7:1 を選択し、ネットワークの深さを 2:2:14:2 に増やして、より深いレイアウトを実現しています。このステップにより、レイテンシが 1.02 ミリ秒でトップ 1 の精度が 76.9% に向上します。

マイクロデザインの調整

次に、RepViT は、適切な畳み込みカーネル サイズの選択や、スクイーズ アンド エキシビション (SE) 層の位置の最適化など、レイヤーごとのマイクロ設計を通じて軽量 CNN を調整します。どちらのアプローチでもモデルのパフォーマンスを大幅に向上できます。

畳み込みカーネルサイズの選択

CNN のパフォーマンスとレイテンシは通常、畳み込みカーネルのサイズによって影響を受けることはよく知られています。たとえば、MHSA のような長距離コンテキスト依存性をモデル化するために、ConvNeXt は大きな畳み込みカーネルを使用し、パフォーマンスが大幅に向上します。ただし、大規模な畳み込みカーネルは、計算の複雑さとメモリ アクセス コストの点から、モバイル デバイスには適していません。 MobileNetV3-L は主に 3x3 畳み込みを使用し、一部のブロックは 5x5 畳み込みを使用します。著者らはこれを 3x3 畳み込みに置き換え、その結果、トップ 1 精度を 76.9% 維持しながら、レイテンシを 1.00 ミリ秒に短縮できました。

SE層の位置

畳み込みに対する自己注意モジュールの利点の 1 つは、入力に基づいて重みを調整できることです。これは、データ駆動型プロパティとして知られています。チャネル アテンション モジュールとして、SE レイヤーはデータ駆動型プロパティの欠如による畳み込みの制限を補い、より優れたパフォーマンスをもたらします。 MobileNetV3-L は、主に最後の 2 つのステージで、いくつかのブロックに SE レイヤーを追加します。ただし、解像度の低いステージでは、高解像度のステージよりも、SE が提供するグローバル平均プーリング操作による精度の向上が少なくなります。著者らは、すべてのステージで SE レイヤーをクロスブロック方式で使用する戦略を設計し、レイテンシの増加を最小限に抑えながら精度の向上を最大化しました。このステップにより、トップ 1 の精度が 77.4% に向上し、レイテンシが 0.87 ミリ秒に短縮されました。 [実は、Baidu はずっと以前に実験と比較を行っており、この結論に達しています。SE 層は深層層の近くに配置した方が優れています。]

ネットワークアーキテクチャ

最終的に、上記の改善戦略を統合することで、 RepViT-M1/M2/M3などの複数のバリアントを持つモデルRepViTの全体的なアーキテクチャが得られました。同様に、さまざまなバリアントは主に、各ステージのチャネルとブロックの数によって区別されます。

実験

画像分類

検出とセグメンテーション

要約する

この論文では、軽量 ViT のアーキテクチャ選択を紹介することで、軽量 CNN の効率的な設計を再検討します。これにより、リソースが制限されたモバイル デバイス向けに設計された軽量 CNN の新しいファミリーである RepViT が誕生しました。さまざまな視覚タスクにおいて、RepViT は既存の最先端の軽量 ViT および CNN を上回り、優れたパフォーマンスとレイテンシを示します。これは、モバイル デバイス向けの純粋に軽量な CNN の可能性を浮き彫りにします。

<<:  最新レビュー!拡散モデルと画像編集の愛憎関係

>>: 

ブログ    
ブログ    
ブログ    
ブログ    
ブログ    

推薦する

機械学習とデータサイエンスに関する必読の無料オンライン電子書籍 10 冊

KDnuggets 編集者の Matthew Mayo が、機械学習とデータ サイエンスに関連する書...

クレジットカード詐欺を検出するための機械学習モデルを構築するにはどうすればよいでしょうか?

[[187627]]機械学習は、Apple の Siri や Google のアシスタントなどのス...

Python での機械学習アルゴリズムの実装: ニューラル ネットワーク

今日は引き続き、パーセプトロンをベースにしたニューラルネットワークモデルを紹介します。パーセプトロン...

...

李開復「2021年を予測」:4つの主要分野が前例のない発展の機会をもたらす

この記事はAI新メディアQuantum Bit(公開アカウントID:QbitAI)より許可を得て転載...

TFserving によるディープラーニング モデルの導入

1. TFservingとは何かモデルをトレーニングし、それを外部の関係者に提供する必要がある場合は...

996の非効率性にノーと言いましょう: ChatGPTはコードコメントとドキュメントを簡単に処理するのに役立ちます

適切なコメントは、Python プロジェクトを成功させる上で非常に重要です。実際には、コメントを書く...

ベンジオ、ヒントン、張亜琴らAI界の巨人たちが新たな共同書簡を発表! AIは危険すぎるので、再配置する必要がある

AI リスク管理は、AI 大手企業によって再び議題に挙げられています。ちょうど今、ベンジオ、ヒントン...

最も孤独なニューラル ネットワーク: たった 1 つのニューロンですが、「クローンをシャドウ」することができます

世界で最も先進的なニューラルネットワークモデルは何ですか?それは人間の脳に違いない。人間の脳には86...

...

サイバー犯罪におけるAI時代の到来

人工知能の分野で画期的な進歩が起こったばかりであり、サイバーセキュリティに携わっている人であれば、そ...

...

高性能 HTTP サーバーの負荷分散アルゴリズムは何ですか?ほとんどのプログラマーは収集しています...

典型的な高同時実行性、大規模ユーザー Web インターネット システムのアーキテクチャ設計では、HT...

開発者の「第2の脳」が登場、GitHub Copilotがアップデートされ、人間の開発参加がさらに減少

Andrej Karpathy 氏が嘆くのは、ソフトウェア開発プロセスにおいてコードを直接記述するこ...

機械学習の実践: Spark と Python を組み合わせるには?

Apache Sparkはビッグデータの処理や活用に最も広く使われているフレームワークの一つであり...