決定木のルネッサンス?ニューラルネットワークと組み合わせることで、ImageNetの分類精度が向上し、解釈が容易になります。

決定木のルネッサンス?ニューラルネットワークと組み合わせることで、ImageNetの分類精度が向上し、解釈が容易になります。

ケーキも食べて、ケーキも残したいです! BAIR は、正確性と解釈可能性のバランスをとったニューラルサポート決定木に関する新しい研究を発表しました。

金融や医療などの分野でディープラーニングが継続的に実装されるにつれて、モデルの解釈可能性は非常に大きな問題点になっています。これらの分野では、正確な予測を行い、その動作を説明できるモデルが求められるためです。しかし、ディープ ニューラル ネットワークは解釈可能性に欠けることでも有名であり、これは矛盾を生じさせます。説明可能な AI (XAI) は、モデルの精度と説明可能性の矛盾のバランスを取ろうとしますが、決定の理由を説明する際にモデル自体を直接説明することはありません。

決定木は分類のための古典的な機械学習手法です。理解と解釈が容易で、中規模のデータに対して低い難易度で優れたモデルを得ることができます。以前人気があった Microsoft XiaoIce の思考読み取り機能は、おそらく決定木を使用していました。 Xiaobing はまず、よく知られている人物 (ある程度有名な人物) を想像するように求め、その後 15 個以下の質問をします。「はい」「いいえ」「わかりません」で答えるだけで、Xiaobing は私たちが考えている人物が誰であるかをすぐに推測できます。

周志華先生はかつて『西瓜本』の中で、意思決定ツリーの図を示しました。

決定木図。

決定木には多くの利点がありますが、過去の経験から、ImageNet レベルのデータに直面した場合、そのパフォーマンスはニューラル ネットワークに比べてまだはるかに劣っていることがわかります。

「正確性」と「説明可能性」、どうすれば両方の長所を両立できるのでしょうか?これら 2 つを組み合わせると何が起こるでしょうか?最近、カリフォルニア大学バークレー校とボストン大学の研究者がこのアイデアを実践しました。

彼らはニューラルネットワークを活用した決定木を提案し、ImageNet でトップ 1 分類精度 75.30% を達成しました。決定木の解釈可能性を維持しながら、現在のニューラル ネットワークでのみ達成可能な精度を達成しました。これは、決定木に基づく他の画像分類方法よりも約 14% 高い精度です。

BAIR ブログアドレス: https://bair.berkeley.edu/blog/2020/04/23/decisions/

論文アドレス: https://arxiv.org/abs/2004.00221

オープンソース プロジェクト アドレス: https://github.com/alvinwan/neural-backed-decision-trees

この新しく提案された方法はどの程度解釈可能でしょうか? 2枚の写真を見てみましょう。

OpenAI Microscope で視覚化されたディープ ニューラル ネットワークは次のようになります。

提案手法によるCIFAR100での分類の可視化結果は以下のとおりです。

画像分類においてどの方法が強力な解釈可能性を持っているかは明らかです。

決定木の利点と欠点

ディープラーニングが普及する前は、決定木が正確性と解釈可能性のベンチマークでした。以下では、まず決定木の解釈可能性について説明します。

上の図に示すように、この決定木は入力データ x (「スーパーバーガー」または「ワッフルフライ」) の予測結果を提供するだけでなく、最終的な予測につながる一連の中間決定も出力します。私たちはこれらの中間決定を検証したり、異議を申し立てたりすることができます。

しかし、画像分類データセットでは、決定木の精度はニューラル ネットワークの精度より 40% 遅れています。ニューラル ネットワークと決定木の組み合わせもパフォーマンスが低く、CIFAR10 データセットのニューラル ネットワークとは比較になりませんでした。

この精度の欠陥により、解釈可能性の利点が「無価値」になります。まず、高精度のモデルが必要ですが、このモデルは解釈可能でなければなりません。

ニューラルサポートによる決定木へのアプローチ

今、このジレンマはようやく解決に向かっています。カリフォルニア大学バークレー校とボストン大学の研究者たちは、解釈可能かつ正確なモデルを構築することでこの問題に取り組みました。

この研究の重要なポイントは、ニューラル ネットワークと決定木を組み合わせ、低レベルの意思決定にニューラル ネットワークを使用しながら、高レベルの解釈可能性を維持することです。下の図に示すように、研究者はこのモデルを「ニューラル バック決定木 (NBDT)」と呼んでおり、このモデルは決定木の解釈可能性を維持しながらニューラル ネットワークの精度に匹敵できると述べています。

この図では、各ノードにニューラル ネットワークが含まれています。上の図は、そのようなノードの 1 つと、そこに含まれるニューラル ネットワークを拡大表示したものです。この NBDT では、予測は決定木を通じて行われ、高いレベルの解釈可能性が維持されます。しかし、決定木の各ノードには、低レベルの決定を行うニューラル ネットワークがあります。たとえば、上の図のニューラル ネットワークによって行われた低レベルの決定は、「ソーセージあり」または「ソーセージなし」です。

NBDT は決定木と同じ解釈可能性を持ちます。また、NBDT は予測結果の中間決定を出力できるため、現在のニューラル ネットワークよりも優れています。

下の図に示すように、「犬」を予測するネットワークでは、ニューラル ネットワークは「犬」のみを出力する可能性がありますが、NBDT は「犬」と他の中間結果 (動物、脊索動物、肉食動物など) を出力できます。

さらに、NBDT の予測階層トレースも視覚化され、どの可能性が拒否されたかが示されます。

同時に、NBDT はニューラル ネットワークに匹敵する精度を達成しました。 CIFAR10、CIFAR100、TinyImageNet200などのデータセットでは、NBDTの精度はニューラルネットワークの精度に近い(ギャップ

ニューラルネットワークを利用した決定木の説明

個人予測の弁証法的根拠

最も有益な弁証法的な理由は、モデルがこれまで見たことのないオブジェクトに向けられたものです。たとえば、NBDT (以下に示す) を検討し、それを Zebra で実行します。モデルはシマウマを見たことはありませんが、下の図に示す中間決定は正しいです。シマウマは動物であり、有蹄類でもあります。これまで見たことのない物体の場合、個々の予測の妥当性は非常に重要です。

モデルの行動に対する弁証法的正当化

さらに研究者らは、NBDT を使用すると、精度が向上するにつれて解釈可能性も向上することを発見しました。これは、記事の冒頭で紹介した正確性と解釈可能性の対立に反するものであり、つまり、NBDT は正確性と解釈可能性を備えているだけでなく、正確性と解釈可能性を同じ目標にしているのです。

ResNet10 階層 (左) は WideResNet 階層 (右) よりも劣っています。

たとえば、ResNet10 は CIFAR10 では WideResNet28x10 よりも 4% 精度が低くなります。同様に、精度の低い ResNet^6 階層 (左) では、カエル、猫、飛行機がグループ化されていますが、3 つのクラスすべてに共通する視覚的特徴を見つけるのが難しいため、あまり意味がありません。対照的に、より正確な WideResNet 階層 (右) はより意味があり、動物と車を完全に分離します。したがって、精度が高ければ高いほど、NBDT の解釈が容易になると言えます。

意思決定ルールを理解する

低次元の表形式データを使用する場合、決定木の決定ルールは簡単に解釈できます。たとえば、皿にパンがある場合は、それを適切な子に割り当てます (以下に示すように)。しかし、高次元画像のような入力の場合、決定ルールはそれほど単純ではありません。モデルの決定ルールは、オブジェクトの種類だけでなく、コンテキスト、形状、色などの要素にも基づいています。

この例では、低次元の表形式データを使用して意思決定ルールを簡単に説明する方法を示します。

決定ルールを定量的に説明するために、研究者は WordNet3 の既存の名詞階層を使用しました。この階層により、カテゴリ間で最も具体的な共通の意味を見つけることができます。たとえば、カテゴリ Cat と Dog が指定されている場合、WordNet は Mammal を返します。下の図では、研究者がこれらの WordNet の仮定を定量的に検証しています。

左の依存ツリー (赤い矢印) の WordNet 仮説は Vehicle です。右側の WordNet 仮説 (青い矢印) は Animal です。

注目すべきは、10 クラスの小さなデータセット (CIFAR10 など) では、研究者がすべてのノードに対して WordNet 仮説を見つけることができることです。ただし、1000 個のカテゴリを持つ大規模なデータセット (ImageNet など) では、ノードのサブセットでのみ WordNet 仮説が見つかります。

仕組み

ニューラル バック決定木のトレーニングと推論のプロセスは、次の 4 つのステップに分解できます。

決定木の誘導階層と呼ばれる階層を構築します。

このレイヤーは、Tree Supervision Loss と呼ばれる独自の損失関数を生成します。

推論は、サンプルをニューラル ネットワーク バックボーンに渡すことから始まります。最後の完全に接続された層の前では、バックボーン ネットワークはすべてニューラル ネットワークです。

推論は、最後の完全に接続された層を順次決定ルール方式で実行することで終了し、研究者はこれを埋め込み決定ルールと呼んでいます。

ニューラル バック デシジョン ツリーのトレーニングと推論の概略図。

埋め込まれた決定ルールを実行する

ここではまず推論の問題について議論します。前述したように、NBDT はニューラル ネットワーク バックボーンを使用して各サンプルの特徴を抽出します。次の操作を理解しやすくするために、研究者はまず、次の図に示すように、完全接続層と同等の退化決定木を構築しました。

上記は行列とベクトルの乗算を生成し、それがベクトルの内積となり、ここでは $\hat{y}$ と表記されます。上記の最大出力値のインデックスがカテゴリーの予測となります。

単純な決定木: 研究者は、上の図の「B - 単純な」に示すように、カテゴリごとに 1 つのルート ノードと 1 つのリーフ ノードのみを持つ基本的な決定木を構築しました。各リーフ ノードはルート ノードに直接接続されており、表現ベクトル (W からの行ベクトル) を持ちます。

サンプルから抽出された特徴 x を推論に使用するということは、各子ノード表現ベクトルと x の内積を計算することを意味します。完全接続層と同様に、最大内積のインデックスが予測カテゴリになります。

完全に接続された層と単純な決定木との間の直接的な同等性から、研究者は特別な推論方法、つまり内積を使用する決定木を提案しました。

帰納的階層の構築

このレベルでは、NBDT が決定する必要があるカテゴリのセットが決定されます。この層の構築には事前にトレーニングされたニューラル ネットワークの重みが使用されるため、研究者はこれを誘導層と呼んでいます。

具体的には、研究者らは、上図の「ステップ B」に示すように、完全結合層の重み行列 W の各行ベクトルを d 次元空間内の点とみなします。次に、これらのポイントに対して階層的クラスタリングが実行されます。この階層は、連続的なクラスタリングの後に生成されます。

ツリー教師あり損失を使用したトレーニング

上の図の「A-Hard」の状況を考えてみましょう。緑色のノードが Horse クラスに対応していると仮定します。これは単なるクラスですが、動物(オレンジ)でもあります。その結果、ルートノード (青) に到達するサンプルは右側の動物に配置されるはずであることもわかります。 「動物」ノードに到達したサンプルも、再び右に曲がって「馬」に向かう必要があります。各ノードは正しい子ノードを予測するようにトレーニングされます。研究者たちは、この損失を強制する木を「Tree Supervision Loss」と呼んでいます。言い換えれば、これは実際には各ノードのクロスエントロピー損失です。

使用ガイドライン

Python パッケージ管理ツールを使用して nbdt を直接インストールできます。

  1. pip インストール nbdt

nbdt をインストールすると、任意の画像に対して推論を実行できます。nbdt は Web リンクまたはローカル画像をサポートしています。

  1. 出典: nbdt https://images.pexels.com/photos/126407/pexels-photo-126407.jpeg?auto=compress&cs=tinysrgb&dpr=2&w=32  
  2.  
  3. # またはローカルイメージで実行
  4.  
  5. nbdt /imaginary/path/to/local/image.png

インストールしたくない場合は、問題ありません。研究者は、次のアドレスで Web バージョンのデモと Colab の例を提供しています。

デモ: http://nbdt.alvinwan.com/demo/

コラボ: http://nbdt.alvinwan.com/notebook/

次のコードは、研究者が推論用に提供した事前トレーニング済みモデルを使用する方法を示しています。

  1. nbdt.model からSoftNBDTをインポートします
  2.  
  3. from nbdt.models import ResNet18, wrn28_10_cifar10, wrn28_10_cifar100, wrn28_10 # TinyImagenet200にはwrn28_10 を使用します
  4.  
  5. モデル = wrn28_10_cifar10()
  6.  
  7. モデル = SoftNBDT(
  8.  
  9. 事前トレーニング済み=True、
  10.  
  11. データセット = 'CIFAR10'
  12.  
  13. アーチ = 'wrn28_10_cifar10'
  14.  
  15. モデル=モデル)

さらに、研究者らは、6 行未満のコードで nbdt を独自のニューラル ネットワークと組み合わせる方法も提供しました。詳細については、GitHub オープン ソース プロジェクトをご覧ください。

<<:  これが顔認識と画像認識がますます重要になっている理由です

>>:  専門家が最もよく使う機械学習ツール 15 選

ブログ    
ブログ    

推薦する

第14次5カ年計画期間中、我が国のドローン産業の発展はますます明確になりました

[[421133]]ドローン産業の発展レベルは、国の軍事力、科学技術革新、製造レベルを測る重要な指標...

人工知能は意識のギャップを埋めることができるか?

諺にもあるように、千人の読者には千のハムレットがあり、私たちにとって人工知能 (AI) も同じことが...

...

LLM にとってベクター データベースが重要なのはなぜですか?

翻訳者 |ブガッティレビュー | Chonglou Twitter 、 LinkedIn 、またはニ...

CLIP と LLM を使用したマルチモーダル RAG システムの構築

この記事では、オープンソースの Large Language Multi-Modal モデルを使用し...

我が国の自動販売機の現状と展望はどうなっているのでしょうか? Pinshi Intelligentは新たな戦略を持っています

セルフサービス自動販売機自体は目新しいものではないが、販売品目が普通のボトル入り飲料から絞りたてジュ...

Tongyi Qianwenが再びオープンソース化、Qwen1.5は6つのボリュームモデルを導入、そのパフォーマンスはGPT3.5を上回る

春節の直前に、同義千文モデル(Qwen)バージョン1.5がリリースされました。今朝、新バージョンのニ...

C# バイナリ ツリー トラバーサル アルゴリズムの実装の簡単な分析

C# アルゴリズムは、バイナリ ツリーの定義、既知のバイナリ ツリーの構築方法、および C# でバイ...

産業用 IoT が人工知能の時代へ

インテリジェンスは近年、製造業における最も重要なトレンドです。過去数年間の市場教育を経て、過去2年間...

Open LLM リストが再び更新されました。Llama 2 よりも強力な「Duckbill Puss」が登場します。

OpenAI の GPT-3.5 や GPT-4 などのクローズドソース モデルの優位性に挑戦する...

自動運転:「乗っ取り」という言葉を恐れるのをやめよう

編集者注:過去2年間、ロボタクシーの公共運行は中国の多くの場所で開花しました。これらのロボタクシーに...

人工知能が人間の神経を刺激し、2017年は世界的な技術革新が活発化

[[183471]]図1:2017年1月7日、知能ロボット「小宝」が上海市楊浦区のショッピングモール...

人気のディープラーニングライブラリ23選のランキング

[[209139]] Data Incubator は最近、Github と Stack Overf...

機械学習が金融業界にもたらす破壊的変化

過去 10 年間で、金融業界ではこれまでにない最先端のテクノロジーが数多く導入されました。この変化は...

ロボット革命はビジネス環境を変えている

今世紀の前半には、巨大な片腕の巨人のような産業用ロボットがロボット工学の分野を支配していました。産業...