PyTorchBigGraph を使用して超大規模グラフ モデルをトレーニングする方法は?

PyTorchBigGraph を使用して超大規模グラフ モデルをトレーニングする方法は?

Facebook は、数十億のノードと数兆のエッジを持つグラフ モデルを効率的にトレーニングできる BigGraph というフレームワークを提案し、その PyTorch 実装をオープンソース化しました。この記事では、その革新性を説明し、大規模なグラフ ネットワークから効率的に知識を抽出できる理由を分析します。

グラフは、機械学習アプリケーションにおける最も基本的なデータ構造の 1 つです。具体的には、グラフ埋め込み法は、ローカル グラフ構造を使用してノードの表現を学習する教師なし学習法です。ソーシャル メディア予測、IoT パターン検出、薬物シーケンス モデリングなどの主流のシナリオにおけるトレーニング データは、グラフ構造として自然に表現できます。これらのシナリオのそれぞれでは、数十億個の接続されたノードを持つグラフが簡単に作成されます。グラフは構造が非常に豊富で、本質的にナビゲート可能であるため、機械学習モデルに適しています。それにもかかわらず、グラフ構造は非常に複雑であり、アプリケーションに合わせて拡張するのが困難です。したがって、最新のディープラーニング フレームワークでは、大規模なグラフ データ構造のサポートが依然として非常に限られています。

Facebook は PyTorch BigGraph: https://github.com/facebookresearch/PyTorch-BigGraph というフレームワークを立ち上げました。これにより、PyTorch モデル内の大規模なグラフ構造のグラフ埋め込みをより迅速かつ簡単に生成できます。

ある意味では、グラフ構造は、ノード間の接続を使用して特定の関係を推測できるため、ラベル付けされたトレーニング データセットの代替として考えることができます。このアプローチは、教師なしグラフ埋め込み法のパターンに従います。この方法では、エッジで接続されたノード ペアの埋め込みがエッジのないノード ペアの埋め込みよりも近くなるようにノード ペアの埋め込みを最適化することで、グラフ内の各ノードのベクトル表現を学習できます。これは、テキストでトレーニングされた word2vec からの単語埋め込みが機能する方法に似ています。

ほとんどのグラフ埋め込み方法は、大規模なグラフ構造に適用すると、かなり限られた結果しか示しません。たとえば、モデルに 20 億のノードがあり、各ノードに 100 個の埋め込みパラメータ (浮動小数点数として表される) がある場合、これらのパラメータを格納するためだけに 800 GB のメモリが必要になるため、多くの標準的なアプローチでは一般的なコモディティ サーバーのメモリ容量を超えてしまいます。これはディープラーニング モデルが直面している大きな課題であり、Facebook が BigGraph フレームワークを開発した理由です。

PyTorch ビッググラフ

PyTorch BigGraph (PBG) の目標は、グラフ埋め込みモデルを拡張して、数十億のノードと数兆のエッジを持つグラフを処理することです。 PBG はなぜこれができるのでしょうか? 4 つの基本的な構成要素を使用するためです。

  1. グラフのパーティション分割により、モデルをメモリに完全にロードする必要がなくなります。
  2. 各マシンでのマルチスレッドコンピューティング
  3. 複数のマシン上での分散実行(オプション)。すべての操作はグラフの切断された部分で実行されます。
  4. バッチネガティブサンプリングでは、エッジごとに 100 個のネガティブサンプルがある場合、マシンごとに 1 秒あたり 100 万を超えるエッジを処理できます。

PBG は、グラフ構造を P 個のランダムに分割されたパーティションに分割し、2 つのパーティションがメモリに収まるようにすることで、従来のグラフ埋め込み方法の欠点の一部を解決します。たとえば、エッジがパーティション p1 で始まり、パーティション p2 で終わる場合、そのエッジはバケット (p1、p2) に配置されます。次に、同じモデル内で、これらのグラフ ノードはソース ノードとターゲット ノードに応じて P2 バケットに分割されます。ノードとエッジの分割が完了したら、一度に 1 つのバケットでトレーニングを実行できます。バケット (p1、p2) のトレーニングでは、パーティション p1 と p2 の埋め込みをメモリに保存するだけで済みます。 PBG 構造により、バケットには少なくとも 1 つの以前にトレーニングされた埋め込みパーティションが含まれるようになります。

PBG のもう一つの大きな革新は、トレーニング メカニズムの並列化と分散です。 PBG は PyTorch 独自の並列化メカニズムを使用して、上記のモジュール分割構造を使用する分散トレーニング モデルを実装します。このモデルでは、各マシンが分離したバケットでのトレーニングを調整します。これは、バケットをワーカーにディスパッチする役割を果たすロック サーバーを使用し、異なるマシン間の通信を最小限に抑えます。各マシンは異なるバケットを使用してモデルを並列にトレーニングできます。

上の図では、マシン 2 の Trainer モジュールがマシン 1 のロック サーバーにバケットを要求し、バケットのパーティションをロックします。次に、トレーナーは使用しなくなったパーティションを保存し、共有パーティション サーバーから必要な新しいパーティションをロードします。この時点で、古いパーティションをロック サーバーに戻すことができます。次に、エッジは共有ファイル システムからロードされ、スレッド内同期なしで複数のスレッドでトレーニングされます。別のスレッドでは、いくつかの共有パラメータのみが共有パラメータ サーバーと継続的に同期されます。モデル チェックポイントは、トレーナーから共有ファイル システムに時々書き込まれます。このモデルでは、最大 P/2 台のマシンを使用して P 個のバケットのセットを並列化できます。

PBG のそれほど直接的ではない革新は、バッチネガティブサンプリングの使用です。従来のグラフ埋め込みモデルは、負のトレーニング例として、真の正のエッジとともにランダムな「偽の」エッジを構築します。これにより、新しい例ごとに重みのごく一部だけを更新すればよいため、トレーニングの速度が大幅に向上します。ただし、負の例はグラフ処理にパフォーマンスのオーバーヘッドをもたらし、ランダムなソース ノードまたはターゲット ノードを通じて実際のエッジを「破損」させる可能性があります。 PBG は、N 個のランダム ノードの単一バッチを再利用して、N 個のトレーニング エッジの破損した負のサンプルを取得する方法を導入します。他の埋め込み方法と比較して、この手法では、計算コストを非常に低く抑えながら、エッジごとに多数の負の例をトレーニングできます。

大規模なグラフでのメモリ効率と計算リソースを向上させるために、PBG は Bn 個のサンプリングされたソース ノードまたはターゲット ノードの単一バッチを使用して、複数の負の例を構築します。通常の設定では、PBG はトレーニング セットから B = 1000 個の正の例のバッチを取得し、それらを 50 個のエッジのブロックに分割します。各ブロックからのターゲット(ソースに相当)埋め込みは、末尾のエンティティ タイプから均一にサンプリングされた 50 個の埋め込みと連結されます。 50 個の正例と 200 個のサンプリング ノードの外積は、9900 個の負例に等しくなります。

バッチネガティブサンプリング方式は、モデルのトレーニング速度に直接影響を与える可能性があります。バッチ処理を行わない場合、トレーニングの速度は負の例の数に反比例します。バッチトレーニングにより方程式を改善し、安定したトレーニング速度を得ることができます。

Facebook は、LiveJournal、Twitter データ、YouTube ユーザー インタラクション データなどのさまざまなデータセットを使用して PBG を評価しました。さらに、PBG は、1 億 2,000 万を超えるノードと 27 億のエッジを含む Freebase ナレッジ グラフを使用してベンチマークされました。また、Freebase の小さなサブセットである FB15k でもテストしました。FB15k には 15,000 個のノードと 600,000 個のエッジが含まれており、マルチリレーション埋め込み方法のベンチマークとしてよく使用されます。 FB15k 実験では、PBG が現在の最良のグラフ埋め込みモデルと同様のパフォーマンスを発揮することが示されています。ただし、完全な Freebase データセットで評価すると、PBG はメモリ消費において 88% の改善を達成します。

PBG は、数十億のノードと数兆のエッジを含むグラフをトレーニングおよび処理できる最初のスケーラブルな方法です。 PBG の最初の実装はオープンソース化されており、将来的にはさらに興味深い貢献が出てくるでしょう。

<<:  Google、少ないパラメータでテキスト分類を行う新モデル「pQRNN」を発表、BERTに匹敵する性能

>>:  AI起業家にとって、これら4つの新たな方向性は注目に値するかもしれない

ブログ    

推薦する

図 | 武術の観点から STL ソート アルゴリズムの秘密を探る

[[410325]]この記事はWeChatの公開アカウント「Backend Research Ins...

最も強力なオープンソースのマルチモーダル生成モデル MM-Interleaved: 最初の機能同期装置

AI がチャットできるだけでなく、「目」を持ち、絵を理解し、絵を描くことで自分自身を表現することさえ...

...

保存しておくべき機械学習チートシート 27 選

機械学習にはさまざまな側面があり、調査を始めたときに、特定のトピックの要点を簡潔にリストしたさまざま...

...

マイクロソフト、感情分析技術の販売を中止し、顔認識ツールの使用を制限

マイクロソフトは、人工知能システムのためのより責任ある枠組みを構築する取り組みの一環として、画像分析...

持続可能な都市計画とスマートシティに人工知能を活用する方法

21 世紀の急速な都市化は、交通渋滞や汚染から住宅不足や公共サービスの逼迫まで、数多くの課題をもたら...

...

...

データマイニングの10の主要なアルゴリズムを、初心者でも一目で理解できるように平易な言葉で説明しました。

優秀なデータ アナリストは、基本的な統計、データベース、データ分析方法、考え方、データ分析ツールのス...

...

上位 10 の古典的なソートアルゴリズムの JS バージョン

序文読者は自分で試してみることができます。ソースコードはここ (https://github.com...

現在人類社会が直面している人工知能のセキュリティ問題!

現在、人類社会が直面している人工知能のセキュリティ問題は、人工知能のアルゴリズムとシステムの特性によ...