Stable Diffusion は 11 か月前に誕生し、消費者向け GPU で実行できるというニュースは多くの研究者を勇気づけました。それだけでなく、Apple の担当者はすぐに介入し、Stable Diffusion を iPhone、iPad、Mac に「詰め込んだ」のです。これにより、Stable Diffusion のハードウェア要件が大幅に削減され、徐々に誰もが使用できる「ブラック テクノロジー」になります。 今では、Raspberry Pi Zero 2 でも動作します。 写真 Raspberry Pi Zero 2 「同じくらい小さい。5倍速い。」 これはどのようなコンセプトですか? Stable Diffusion を実行するのは簡単な作業ではありません。10 億のパラメータを持つ大規模な Transformer モデルが含まれており、推奨される最小 RAM/VRAM は通常 8 GB です。 RPI Zero 2 は、512MB のメモリを搭載したマイクロコンピュータです。 つまり、RPI Zero 2 で Stable Diffusion を実行するのは非常に困難です。さらに、著者らは実行中にストレージスペースを増やしたり、中間結果をディスクにオフロードしたりしませんでした。 一般に、主要な機械学習フレームワークとライブラリは、推論の遅延を最小限に抑えることやスループットを最大化することに重点を置いていますが、これらはすべてメモリ使用量を犠牲にして行われます。そこで著者は、メモリ消費を最小限に抑えることに重点を置いた、超小型の逆アセンブル推論ライブラリを作成することにしました。 OnnxStream がそれを実現します。
プロジェクトアドレス: https://github.com/vitoplantamura/OnnxStream OnnxStream は、WeightsProvider から派生したクラスであるモデルの重みを提供するコンポーネントから推論エンジンを分離するというアイデアに基づいています。 WeightsProvider の特殊化により、あらゆる種類のモデル パラメータの読み込み、キャッシュ、およびプリフェッチを実装できます。たとえば、カスタム WeightsProvider は、ディスクに何もロードまたは書き込まずに、HTTP サーバーから直接データをダウンロードすることを決定する場合があります (これが、OnnxStream の名前に Stream が含まれている理由です)。使用可能なデフォルトの WeightsProviders は、DiskNoCache と DiskPrefetch の 2 つです。 Microsoft の推論フレームワーク OnnxStream と比較すると、OnnxStream は同じ効果を得るために 1/55 のメモリしか消費しませんが、速度 (CPU 上) は前者よりも 0.5 ~ 2 倍しか遅くありません。 次に、RPI Zero 2 で実行されている Stable Diffusion の効果とその背後にある方法について説明します。これは遅いですが、より小型で制限のあるデバイスで大規模なモデルを実行するという斬新な試みであることに留意することが重要です。 写真 ネットユーザーはこのプロジェクトがクールだと思っている Raspberry Pi Zero 2 で安定した拡散を実行するVAE デコーダーは、Stable Diffusion の中で、単精度または半精度で RPI Zero 2 RAM に適合できない唯一のモデルです。これは、モデル内に残差接続、非常に大きなテンソル、および畳み込みが存在するためです。唯一の解決策は静的量子化(8 ビット)です。 以下の画像は、著者のリポジトリに含まれている Stable Diffusion サンプル実装によって、さまざまな精度の VAE デコーダーを備えた OnnxStream を使用して生成されました。 最初の画像は、RPI Zero 2 によって生成されたのと同じ潜在変数を使用して、著者の PC で生成されました。 精度W16A16のVAEデコーダーの生成結果 W8A32の精度を持つVAEデコーダーの生成結果 3 番目の画像は、RPI Zero 2 によって約 3 時間で生成されました。 図1: W8A8の精度を持つVAEデコーダーの生成効果 OnnxStreamの特徴
また、OnnxStream は、MatMul、Convolution、要素ごとの Add/Sub/Mul/Div、Sigmoid、Softmax などの特定のプリミティブを高速化するために XNNPACK に依存していることに注意してください。 パフォーマンス比較Stable Diffusion は、テキスト エンコーダー (672 の操作と 1 億 2,300 万のパラメーター)、UNET モデル (2,050 の操作と 8 億 5,400 万のパラメーター)、VAE デコーダー (276 の操作と 4,900 万のパラメーター) の 3 つのモデルで構成されています。 バッチ サイズが 1 であると仮定すると、完全な画像を生成するには 10 ステップが必要であり、良好な結果を得るには、テキスト エンコーダーを 2 回実行し、UNET モデルを 20 回 (つまり 2 * 10)、VAE デコーダーを 1 回実行する必要があります (Euler Ancestral スケジューラを使用)。 この表には、安定拡散の 3 つのモデルの異なる推論時間と、メモリ消費量 (Windows のピーク ワーキング セット サイズまたは Linux の最大常駐セット サイズ) が示されています。 UNET モデル (FP16 精度で実行する場合、OnnxStream で FP16 演算が有効) では、OnnxStream のメモリ消費量は OnnxRuntime の 1/55 に過ぎず、速度は 0.5 ~ 2 倍しか遅くないことがわかります。 このテストに関して注意すべき点は次のとおりです。
注意スライスと量子化UNET モデルを実行する際には、「アテンション スライシング」手法を採用し、VAE デコーダーに W8A8 量子化を使用しました。これは、モデルのメモリ消費を RPI Zero 2 での実行に適したレベルまで削減するために重要でした。 インターネット上には量子化ニューラル ネットワークに関する情報は多数ありますが、「アテンション スライシング」に関する情報はほとんどありません。 ここでの考え方は単純です。目標は、UNET モデル内のさまざまなマルチヘッド アテンションのスケールされたドット積アテンションを計算するときに、完全な Q@K^T マトリックスを生成しないようにすることです。 UNETモデルでは、アテンションヘッドの数が8の場合、Qの形状は(8,4096,40)、K^Tの形状は(8,40,4096)になります。したがって、最初の MatMul の最終的な形状は (8,4096,4096) となり、これは 512MB のテンソル (FP32 精度) になります。 写真 解決策は、Q を垂直に分割し、各 Q ブロックに対して通常どおりアテンション操作を実行することです。 Q_sliced の形状は (1,x,40) で、x は 4096 (この場合は) であり、onnxstream::Model::m_attention_fused_ops_parts (デフォルトは 2 ですが、カスタマイズできます) で割られます。 この簡単なトリックにより、FP32 精度で実行する場合、UNET モデルの全体的なメモリ消費量を 1.1 GB から 300 MB に削減できます。より効率的な代替手段は FlashAttention を使用することですが、FlashAttention では、著者が示した例の XnnPack をバイパスして、サポートされているアーキテクチャ (AVX、NEON など) ごとにカスタム カーネルを作成する必要があります。 詳細については、プロジェクトの GitHub ページを参照してください。 |
<<: ImageNet-1K 圧縮 20 倍、Top-1 精度が初めて 60% を超える: 大規模データセット蒸留の転換点
LLM ロングコンテキスト モデルの究極のソリューションは何ですか?プリンストン大学とMeta AI...
AIは銀行の顧客サービスの性質を変える銀行やその他の金融機関は、コールセンターからチャットボット、よ...
LEACH プロトコルについてはあまり知られていないかもしれません。このプロトコルの説明は、低電力適...
人工知能はここ2年で急速に発展し、狂気のレベルにまで達しました。例えば、ロボットは人間社会の「市民」...
人工知能(AI)技術の環境への影響は最近、幅広い注目を集めていますが、これは今後10年間でAIの中心...
この記事はLeiphone.comから転載したものです。転載する場合は、Leiphone.com公式...
人工知能 (AI) は、コンピューターや機械をインテリジェントに動作させる方法を研究する分野です。機...
1. k-meansアルゴリズムの紹介: k-means アルゴリズムは入力量 k を受け取り、n ...
[[428302]] 2021年9月26日にarXivにアップロードされた論文「人間のガイダンスによ...
[[377893]] [51CTO.com クイック翻訳] データとオープンソースの機械学習フレーム...
[[434349]]この記事はAI新メディアQuantum Bit(公開アカウントID:QbitA...