Raspberry Pi で Stable Diffusion を実行すると、260 MB の RAM に 10 億のパラメータ モデルが「保持」されます。

Raspberry Pi で Stable Diffusion を実行すると、260 MB の RAM に 10 億のパラメータ モデルが「保持」されます。

Stable Diffusion は 11 か月前に誕生し、消費者向け GPU で実行できるというニュースは多くの研究者を勇気づけました。それだけでなく、Apple の担当者はすぐに介入し、Stable Diffusion を iPhone、iPad、Mac に「詰め込んだ」のです。これにより、Stable Diffusion のハードウェア要件が大幅に削減され、徐々に誰もが使用できる「ブラック テクノロジー」になります。

今では、Raspberry Pi Zero 2 でも動作します。

写真

Raspberry Pi Zero 2 「同じくらい小さい。5倍速い。」

これはどのようなコンセプトですか? Stable Diffusion を実行するのは簡単な作業ではありません。10 億のパラメータを持つ大規模な Transformer モデルが含まれており、推奨される最小 RAM/VRAM は通常 8 GB です。 RPI Zero 2 は、512MB のメモリを搭載したマイクロコンピュータです。

つまり、RPI Zero 2 で Stable Diffusion を実行するのは非常に困難です。さらに、著者らは実行中にストレージスペースを増やしたり、中間結果をディスクにオフロードしたりしませんでした。

一般に、主要な機械学習フレームワークとライブラリは、推論の遅延を最小限に抑えることやスループットを最大化することに重点を置いていますが、これらはすべてメモリ使用量を犠牲にして行われます。そこで著者は、メモリ消費を最小限に抑えることに重点を置いた、超小型の逆アセンブル推論ライブラリを作成することにしました。

OnnxStream がそれを実現します。

 

プロジェクトアドレス: https://github.com/vitoplantamura/OnnxStream

OnnxStream は、WeightsProvider から派生したクラスであるモデルの重みを提供するコンポーネントから推論エンジンを分離するというアイデアに基づいています。 WeightsProvider の特殊化により、あらゆる種類のモデル パラメータの読み込み、キャッシュ、およびプリフェッチを実装できます。たとえば、カスタム WeightsProvider は、ディスクに何もロードまたは書き込まずに、HTTP サーバーから直接データをダウンロードすることを決定する場合があります (これが、OnnxStream の名前に Stream が含まれている理由です)。使用可能なデフォルトの WeightsProviders は、DiskNoCache と DiskPrefetch の 2 つです。

Microsoft の推論フレームワーク OnnxStream と比較すると、OnnxStream は同じ効果を得るために 1/55 のメモリしか消費しませんが、速度 (CPU 上) は前者よりも 0.5 ~ 2 倍しか遅くありません。

次に、RPI Zero 2 で実行されている Stable Diffusion の効果とその背後にある方法について説明します。これは遅いですが、より小型で制限のあるデバイスで大規模なモデルを実行するという斬新な試みであることに留意することが重要です。

写真

ネットユーザーはこのプロジェクトがクールだと思っている

Raspberry Pi Zero 2 で安定した拡散を実行する

VAE デコーダーは、Stable Diffusion の中で、単精度または半精度で RPI Zero 2 RAM に適合できない唯一のモデルです。これは、モデル内に残差接続、非常に大きなテンソル、および畳み込みが存在するためです。唯一の解決策は静的量子化(8 ビット)です。

以下の画像は、著者のリポジトリに含まれている Stable Diffusion サンプル実装によって、さまざまな精度の VAE デコーダーを備えた OnnxStream を使用して生成されました。

最初の画像は、RPI Zero 2 によって生成されたのと同じ潜在変数を使用して、著者の PC で生成されました。

精度W16A16のVAEデコーダーの生成結果

W8A32の精度を持つVAEデコーダーの生成結果

3 番目の画像は、RPI Zero 2 によって約 3 時間で生成されました。

図1: W8A8の精度を持つVAEデコーダーの生成効果

OnnxStreamの特徴

  • 推論エンジンをWeightsProviderから分離する
  • WeightsProviderはDiskNoCache、DiskPrefetch、またはカスタムにすることができます
  • 注目スライス
  • 動的量子化(8 ビット符号なし、非対称、パーセンタイル)
  • 静的量子化 (W8A8 符号なし、非対称、パーセンタイル)
  • 量子化モデルを簡単に調整
  • FP16 をサポート (FP16 操作の有無にかかわらず)
  • 24 個の ONNX 演算子 (最もよく使用される演算子) を実装しました
  • 操作は順番に実行されますが、すべての演算子はマルチスレッドです。
  • 単一の実装ファイル + ヘッダーファイル
  • XNNPACK 呼び出しは XnnPack クラスにカプセル化されます (将来の置き換え用)

また、OnnxStream は、MatMul、Convolution、要素ごとの Add/Sub/Mul/Div、Sigmoid、Softmax などの特定のプリミティブを高速化するために XNNPACK に依存していることに注意してください。

パフォーマンス比較

Stable Diffusion は、テキスト エンコーダー (672 の操作と 1 億 2,300 万のパラメーター)、UNET モデル (2,050 の操作と 8 億 5,400 万のパラメーター)、VAE デコーダー (276 の操作と 4,900 万のパラメーター) の 3 つのモデルで構成されています。

バッチ サイズが 1 であると仮定すると、完全な画像を生成するには 10 ステップが必要であり、良好な結果を得るには、テキスト エンコーダーを 2 回実行し、UNET モデルを 20 回 (つまり 2 * 10)、VAE デコーダーを 1 回実行する必要があります (Euler Ancestral スケジューラを使用)。

この表には、安定拡散の 3 つのモデルの異なる推論時間と、メモリ消費量 (Windows のピーク ワーキング セット サイズまたは Linux の最大常駐セット サイズ) が示されています。

UNET モデル (FP16 精度で実行する場合、OnnxStream で FP16 演算が有効) では、OnnxStream のメモリ消費量は OnnxRuntime の 1/55 に過ぎず、速度は 0.5 ~ 2 倍しか遅くないことがわかります。

このテストに関して注意すべき点は次のとおりです。

  • OnnxRuntime の最初の実行はウォームアップ推論です。これは、最初の実行の前に InferenceSession が作成され、その後のすべての実行で再利用されるためです。 OnnxStream は純粋に「積極的」に設計されているため、事前ウォームアップされた推論はありません (ただし、後続の実行では、オペレーティング システムの重みファイルのキャッシュの恩恵を受けることができます)。
  • 現在、OnnxStream はバッチ サイズ != 1 の入力をサポートしていません。これは、バッチ サイズ = 2 を使用して UNET モデルを実行するときに拡散プロセス全体を大幅に高速化できる OnnxRuntime とは異なります。
  • テストでは、OnnxRuntime の SessionOptions (EnableCpuMemArena や ExecutionMode など) を変更しても、結果に目立った影響はありませんでした。
  • メモリ消費量と推論時間に関して、OnnxRuntime のパフォーマンスは NCNN (別のフレームワーク) と非常に似ています。
  • Windows Server 2019、16GB RAM、8750H CPU (AVX2)、970 EVO Plus SSD、VMWare 上の 8 個の仮想コアでテスト済み。

注意スライスと量子化

UNET モデルを実行する際には、「アテンション スライシング」手法を採用し、VAE デコーダーに W8A8 量子化を使用しました。これは、モデルのメモリ消費を RPI Zero 2 での実行に適したレベルまで削減するために重要でした。

インターネット上には量子化ニューラル ネットワークに関する情報は多数ありますが、「アテンション スライシング」に関する情報はほとんどありません。

ここでの考え方は単純です。目標は、UNET モデル内のさまざまなマルチヘッド アテンションのスケールされたドット積アテンションを計算するときに、完全な Q@K^T マトリックスを生成しないようにすることです。 UNETモデルでは、アテンションヘッドの数が8の場合、Qの形状は(8,4096,40)、K^Tの形状は(8,40,4096)になります。したがって、最初の MatMul の最終的な形状は (8,4096,4096) となり、これは 512MB のテンソル (FP32 精度) になります。

写真

解決策は、Q を垂直に分割し、各 Q ブロックに対して通常どおりアテンション操作を実行することです。 Q_sliced の形状は (1,x,40) で、x は 4096 (この場合は) であり、onnxstream::Model::m_attention_fused_ops_parts (デフォルトは 2 ですが、カスタマイズできます) で割られます。

この簡単なトリックにより、FP32 精度で実行する場合、UNET モデルの全体的なメモリ消費量を 1.1 GB から 300 MB に削減できます。より効率的な代替手段は FlashAttention を使用することですが、FlashAttention では、著者が示した例の XnnPack をバイパスして、サポートされているアーキテクチャ (AVX、NEON など) ごとにカスタム カーネルを作成する必要があります。

詳細については、プロジェクトの GitHub ページを参照してください。

<<:  ImageNet-1K 圧縮 20 倍、Top-1 精度が初めて 60% を超える: 大規模データセット蒸留の転換点

>>:  なぜ私はLangChainを諦めたのでしょうか?

ブログ    
ブログ    
ブログ    

推薦する

今後 10 年間で人工知能が私たちの生活を支配するようになるとき、携帯電話はどのようなものになるでしょうか?

テクノロジー業界のほとんどの人は、今後 10 年以内にユビキタス テクノロジーが 1 日のあらゆる瞬...

2022 年ソフトウェア エンジニア レポートが公開されました。最も高い年収はサイバーセキュリティ業界、機械学習はNLPに勝てない

2022 年に雇用主の間で最も人気のあるプログラミング言語はどれですか? 地域や職種によってソフトウ...

...

AIに関する哲学的考察 - 認知不変性とAI

米国国防高等研究計画局(DARPA)はかつて、第3波AIの概念を提唱しました。その議論では、第3波A...

わかりやすく解説: 機械学習と統計モデリングの違い

これらは互いに大きく異なっており、すべてのデータ サイエンティストはその理由と方法を理解する必要があ...

...

IDCレポート:ジェネレーティブAIは爆発的な産業探査の時代に入り、技術供給側は商業化の初期段階にある

9月22日、IDCコンサルティングの公式WeChatアカウントによると、2023年下半期以降、ますま...

...

指紋、顔、音声認識技術は、本当に簡単に解読できます。

【AI世代編集部注】顔認識は今年、CCTVの315ガラで痛烈に批判された。この技術は人々が安心して...

人工知能にはどのような分野が含まれますか?どのように機能しますか?

現代の産業技術の発展により、私たちの生活は大きく改善されました。新しい家具が次々と登場しています。キ...

...

機械学習: 教師なし学習: 9 つのクラスタリング アルゴリズム

今日は、機械学習の教師なし学習における一般的なクラスタリング手法をいくつか紹介したいと思います。教師...

...

AI界のお笑い王に100万の賞金!北京郵電大学、南洋理工大学などが「砂像動画」データセットを公開 FunQA:アルゴリズムで人間のユーモアを学習

人は直感に反する動画(ユーモラスで独創的で視覚的に魅力的な動画)から容易に喜びを得ることができます。...

...