ディープニューラルネットワークをデバッグするにはどのような方法を使用しますか? 4つの簡単な方法をご紹介します

ディープニューラルネットワークをデバッグするにはどのような方法を使用しますか? 4つの簡単な方法をご紹介します

データセットの構築、ニューラル ネットワークのコーディング、モデルのトレーニングに何週間も費やした後、結果が満足のいくものではないことに気付いた場合はどうしますか?

ディープラーニングはブラックボックスとして見られることが多く、私もそれに異論はありませんが、学習された何万ものパラメータの意味を説明できますか?

しかし、ブラックボックスの観点は、機械学習の実践者にとって明白な疑問を提起します。モデルをどのようにデバッグするのか?

この記事では、Cardiogram で Apple Watch、Garmin、WearOS のデータを使用して病気を予測するディープ ニューラル ネットワークである DeepHeart をデバッグするために使用したいくつかの手法について説明します。

Cardiogram では、DNN の構築は錬金術ではなくエンジニアリングであると考えています。

あなたの心はあなたについて多くのことを明らかにしてくれます。 DeepHeart は、Apple Watch、Garmin、WearOS からの心拍数データを使用して、糖尿病、高血圧、睡眠時無呼吸のリスクを予測します。

1. 合成出力の予測

入力データから構築された合成出力タスクを予測して、モデルの機能をテストします。

私たちはこの技術を睡眠時無呼吸を検出するモデルの構築に使用しました。睡眠時無呼吸スクリーニングに関する既存の文献では、スクリーニングのメカニズムとして、昼間と夜間の心拍数の標準偏差の差を使用しています。そこで、入力データの各週ごとに合成出力タスクを作成しました。

標準偏差(昼間の心拍数) – 標準偏差(夜間の心拍数)

この機能を学習するには、モデルは次のことができる必要があります。

  • 昼と夜の区別
  • 過去数日間のデータを思い出す

これらは両方とも睡眠時無呼吸を予測するための前提条件であるため、新しいアーキテクチャを試す最初のステップは、この合成タスクを学習できるかどうかを確認することでした。

合成タスクに関してネットワークを事前トレーニングすることにより、このような合成タスクを半教師あり方式で使用することもできます。このアプローチは、ラベル付きデータが少なく、ラベルなしデータが大量にある場合に役立ちます。

2. 活性化値の可視化

トレーニングされたモデルの内部の仕組みを理解するのは困難です。何千もの行列乗算をどうやって理解するのでしょうか?

この優れた Distill の記事「ニューラル ネットワークを使用した手書きの 4 つの実験」では、著者らがヒートマップにユニットのアクティベーションをプロットして手書きモデルを分析しています。これは DNN の「ボンネットを開ける」ための優れた方法であることがわかりました。

私たちは、ネットワーク内のいくつかの層の活性化を調べ、たとえば、ユーザーが眠っているとき、仕事をしているとき、または不安なときに活性化されるユニットは何かといった意味的な特性を発見することを期待しました。

モデルからアクティベーションを抽出するために Keras で記述されたコードはシンプルです。次のコード スニペットは、入力データを指定してレイヤーの出力 (つまり、そのアクティベーション値) を取得する Keras 関数 last_output_fn を作成します。

  1. kerasからバックエンドをKとしてインポートします
  2.  
  3. def extract_layer_output(モデル、レイヤー名、入力データ):
  4. レイヤー出力関数= K .function([モデル.レイヤー[0].入力],
  5. [model.get_layer(レイヤー名).出力])
  6.  
  7. レイヤー出力=レイヤー出力関数([入力データ])
  8.  
  9. # layer_output.shape は (num_units, num_timesteps) です
  10. レイヤー出力[0]を返す

ネットワークのいくつかの層の活性化を視覚化しました。 2 番目の畳み込み層 (幅 128 の時間畳み込み層) の活性化を調べていたところ、奇妙なことに気付きました。

各タイムステップにおける畳み込み層の各ユニットの活性化値。青い網掛けはアクティベーション値を表します。

活性化値は時間の経過とともに変化しません。入力値の影響を受けず、「デッドニューロン」と呼ばれます。

ReLU活性化関数、f(x) = max(0, x)

このアーキテクチャは、入力が負の場合に 0 を出力する ReLU 活性化関数を使用します。これはニューラル ネットワークの比較的浅い層ですが、まさにこのようなことが起こっています。

トレーニング中のある時点で、大きな勾配によって特定のレイヤーのすべてのバイアス項が負の数に変換され、ReLU 関数の入力が小さな負の数になります。したがって、この層の出力はすべて 0 になります。入力が 0 未満の場合、ReLU の勾配は 0 であり、この問題は勾配降下法では解決できないためです。

畳み込み層の出力がすべてゼロの場合、後続の層のユニットはそのバイアス項の値を出力します。このレイヤーの各ユニットが異なる値を出力するのは、バイアス項が異なるためです。

この問題は、ReLU を Leaky ReLU に置き換えることで解決します。Leaky ReLU により、入力が負の場合でも勾配が伝播できるようになります。

この分析で「死んだニューロン」が見つかるとは思っていませんでしたが、見つけるのが最も難しいエラーは、見つけるつもりのないエラーです。

3. 勾配分析

もちろん、勾配の役割は損失関数を最適化するだけではありません。勾配降下法では、Δパラメータに対応するΔ損失を計算します。ただし、勾配は通常、ある変数を変更すると別の変数にどのような影響が及ぶかを判断するために計算されます。勾配降下法では勾配の計算が必要となるため、TensorFlow などのフレームワークでは勾配を計算する関数が提供されています。

勾配分析を使用して、ディープ ニューラル ネットワークがデータ内の長期的な依存関係を捉えられるかどうかを判断します。 DNN の入力データは非常に長く、心拍数または歩数データの 4096 時間ステップです。私たちのモデル アーキテクチャがデータ内の長期的な依存関係を捉えられることは非常に重要です。たとえば、心拍数の回復時間から糖尿病を予測できます。これは、運動後に安静時の心拍数に戻るまでにかかる時間です。これを計算するには、ディープ ニューラル ネットワークが、休憩中の心拍数を把握し、トレーニングを終了した時間を記憶できる必要があります。

モデルが長期的な依存関係を追跡できるかどうかを測定する簡単な方法は、入力データの各タイムステップが出力予測に与える影響を調べることです。後の時間ステップの影響が特に大きい場合、モデルは以前のデータを効果的に使用していません。

すべての時間ステップ t について、計算する勾配は Δinput_t に対する Δoutput です。 Keras と TensorFlow を使用してこの勾配を計算するコード例を次に示します。

  1. def gradient_output_wrt_input(モデル、データ):
  2. # [:, 2048, 0] はバッチ内のすべてのユーザー、中間タイムステップ、0 番目のタスク (糖尿病) を意味します
  3. output_tensor =モデル.model.get_layer('raw_output').output[:, 2048, 0]
  4. # output_tensor.shape == (num_users)
  5.  
  6. # すべてのユーザーの平均出力。結果はスカラーです。
  7. 出力テンソル合計= tf.reduce_mean (出力テンソル)
  8.  
  9. 入力=モデル.model.inputs # (num_users x num_timesteps x num_input_channels)
  10. 勾配テンソル= tf.gradients (出力テンソル合計、入力)
  11. # gradient_tensors.shape == (num_users x num_timesteps x num_input_channels)
  12.  
  13. # ユーザー全体の平均
  14. 勾配テンソル= tf .reduce_mean(勾配テンソル、= 0 )
  15. # gradient_tensors.shape == (num_timesteps x num_input_channels)
  16. # 例えば gradient_tensor[10, 0] は最後の出力の 10 番目の入力心拍数に対する微分です
  17.  
  18. # Keras関数に変換する
  19. k_gradients = K .function(入力inputs = 入力、出力= gradient_tensors )
  20.  
  21. # データセットに関数を適用する
  22. k_gradients([data.X]) を返す

上記のコードでは、平均プーリングの前に中間点の時間ステップ 2048 で出力を計算しました。最後のタイム ステップではなく中間点を使用する理由は、LSTM セルが双方向であるためです。つまり、セルの半分では、4095 が実際には最初のタイム ステップになります。結果として得られる勾配を視覚化します。

Δ出力_2048 / Δ入力_t

Y 軸は対数スケールであることに注意してください。タイムステップ 2048 では、入力に対する出力の勾配は 0.001 です。しかし、タイムステップ 2500 では、対応する勾配は 100 万倍小さくなります。勾配分析により、このアーキテクチャでは長期的な依存関係をキャプチャできないことがわかりました。

4. 分析モデル予測

AUROC や平均絶対誤差などの指標を調べて、モデルの予測を分析したことがあるかもしれません。さらに分析を行ってモデルの動作を理解することもできます。

たとえば、DNN が実際に心拍数入力を使用して予測を生成するのか、それともその学習が提供されたメタデータに大きく依存するのかを知りたいと思い、性別や年齢などのユーザー メタデータを使用して LSTM 状態を初期化しました。これを理解するために、私たちはモデルをメタデータでトレーニングされたロジスティック回帰モデルと比較しました。

DNN モデルは 1 週間分のユーザー データを受信するため、下の散布図では各ドットが 1 週間分のユーザー データを表します。

このグラフは、予測の相関性があまり高くないことから、私たちの仮説を裏付けています。

集計分析を実行するだけでなく、最良と最悪のサンプルを調べることも有益です。バイナリ分類タスクの場合、最もひどい偽陽性と偽陰性(つまり、予測がラベルから最も遠いケース)を調べる必要があります。損失パターンを特定し、真の陽性と真の陰性に現れるパターンを除外します。

損失パターンについての仮説ができたら、層別分析を通じてそれをテストします。たとえば、最も高い損失がすべて第 1 世代の Apple Watch から発生している場合、第 1 世代の Apple Watch を使用してチューニング セット内のユーザー セットの精度メトリックを計算し、それらのメトリックを残りのチューニング セットで計算されたメトリックと比較できます。

オリジナルリンク: https://blog.cardiogr.am/4-ways-to-debug-your-deep-neural-network-e5edb14a12d7

[この記事は51CTOコラム「Machine Heart」、WeChatパブリックアカウント「Machine Heart(id:almosthuman2014)」によるオリジナル翻訳です]

この著者の他の記事を読むにはここをクリックしてください

<<:  人工知能の知られざる歴史: 目に見えない女性プログラマーたち

>>:  スーパーパートナー:IoT、AI、クラウドが強力な同盟を形成

ブログ    
ブログ    
ブログ    

推薦する

中国の科学者が色を変えることができる柔らかいロボットを開発

ああ、これはまだ私が知っているロボットですか? 「カモフラージュして色を変える」と「柔らかく変形する...

データサイエンスにおける一般的な課題は何ですか?

2017 年後半を迎えるにあたり、データ サイエンスと機械学習を活用する企業が直面する共通の課題に...

...

ICLR 2022: AI が「目に見えないもの」を認識する方法

この記事はAI新メディアQuantum Bit(公開アカウントID:QbitAI)より許可を得て転載...

GPT-2はGPT-4を監督できる、イリヤがOpenAI初のスーパーアライメント論文を主導:AIアライメントAIは実証的な結果を達成

過去1年間、「次のトークンを予測する」ことを本質とする大規模なモデルが人間の世界の多くのタスクに浸透...

AI医用画像の春が再び到来?

概要: AI医用画像診断市場は急速な成長期を迎えつつあり、医師の負担を軽減しながら医療の質の向上も期...

AI は言語をより早く習得するために何ができるでしょうか?

新しい言語を学ぶことは間違いなく挑戦です。特に 18 歳以上の人にとっては、これまで触れたことのない...

CIOがAIのビジネスケースを作成する方法

近年、AI プロジェクトに対する組織の関心は着実に高まっています。調査会社ガートナーの調査によると、...

...

網膜症治療のAIが成熟する中、なぜ医療業界は「無反応」なのか?

網膜は人体の中で唯一、血管や神経細胞の変化を非侵襲的に直接観察できる組織であり、さまざまな慢性疾患の...

AI導入において、テクノロジーは最大の課題ではないが、人材は

[[427056]]写真: ゲッティ従来型企業の経営幹部が人工知能 (AI) や機械学習 (ML) ...

AIを活用してよりスマートな電子データ交換を実現

電子データ交換 (EDI) の歴史は、企業がより効率的に電子的にデータを交換する方法を模索し始めた ...

AI インテリジェント音声認識アルゴリズム パート 2

[[397599]] 1. ニューラルネットワーク現在一般的に使用されている音声認識フレームワーク...