スクラッチから実装する拡散モデル: 新しい理論的視点
(chenyang.co)- 拡散モデルは画像生成を超えて、音声・動画・3D・タンパク質設計・ロボット経路計画のような、多峰性分布のサンプリングが必要な問題に使われており、このチュートリアルは最適化の観点から学習とサンプリングを結び付ける
- 学習過程では、データにノイズを混ぜた (x_\sigma=x_0+\sigma\epsilon) を作り、ニューラルネットワーク (\epsilon_\theta(x,\sigma)) がノイズ方向を予測するよう平均二乗誤差を最小化する
- 学習済み denoiser はデータ集合 (\mathcal{K}) への近似射影として解釈でき、理想的 denoiser は (\sigma)-平滑化された二乗距離関数の勾配と結び付く
- DDIM サンプリングは (f(x)=\frac{1}{2}\mathrm{dist}_{\mathcal{K}}(x)^2) に対する近似勾配降下とみなせ、(\sigma_t) スケジュールが反復回数と denoiser 評価コストを決める
- 勾配推定更新とノイズ追加を組み合わせると、DDIM、DDPM、著者らの改良サンプラーを
gam・muパラメータでまとめて扱え、toy モデルと latent diffusion の例へとつながる
最適化の観点から見た拡散モデル
- 拡散モデルは多峰性分布からサンプルを生成するのに強みを持ち、Stable Diffusion のようなテキスト画像生成ツールだけでなく、音声、動画、3D 生成、タンパク質設計、ロボット経路計画にも応用されている
- チュートリアルの理論的基盤は、ICML 2024 論文 と 関連論文 における最適化的解釈である
- 実装は主に
smalldiffusionを参照しており、本文のコードは元ライブラリより教育用に単純化されている
学習: ノイズ方向の予測
- 拡散モデルは学習例からデータ集合 (\mathcal{K}) を学び、その集合からサンプルを生成することを目指す
- 画像であれば (\mathcal{K} \subset \mathbb{R}^{c\times h \times w}) は現実的な画像に対応するピクセル値の集合である
- 同じ枠組みは音声、動画、ロボット軌道、テキストのような離散領域にも適用できる
- 学習手順は 3 段階で見られる
- (x_0 \sim \mathcal{K}), (\sigma), (\epsilon \sim N(0,I)) をサンプリングする
- (x_\sigma=x_0+\sigma\epsilon) によりノイズが混ざったデータを作る
- (\epsilon_\theta(x_\sigma,\sigma)) が (\epsilon) を予測するよう二乗損失を最小化する
- コードでは
training_loopが各バッチx0ごとにgenerate_train_sampleでsigmaとepsを作り、model(x0 + sigma * eps, sigma)の出力とepsの間の MSE を最適化する - (\sigma) は連続区間から一様サンプリングするのではなく、(N) 個の値に離散化した**(\sigma) スケジュール**から選ぶ
Scheduleクラスは利用可能なsigmasの一覧をラップし、学習中にバッチごとに値をサンプリングする- 本文の例では
ScheduleLogLinear(N, sigma_min=0.02, sigma_max=10)を使う ScheduleDDPMはピクセル空間の拡散モデル、ScheduleLDMは Stable Diffusion のような latent diffusion モデル向けのスケジュールである
Swissroll の toy 例
- toy データセットは初期の拡散論文の一つである Sohl-Dickstein et al. 2015 で使われた渦巻き状の点集合で、(\mathcal{K}\subset\mathbb{R}^2) である
- 単純なデータセットでは denoiser を MLP で実装する
- 入力は (x\in\mathbb{R}^2) と (\sigma) の 2 次元埋め込みを連結した値である
- 出力はノイズ (\epsilon\in\mathbb{R}^2) の予測値である
- 多くの拡散モデルは (\sigma) に sinusoidal positional embedding を使うが、この例では単純な 2 次元埋め込みでも十分に機能する
- 例の学習設定では
ScheduleLogLinear(N=200, sigma_min=0.005, sigma_max=10)とepochs=15000を使う - 学習済み denoiser は (x-\sigma\epsilon_\theta(x,\sigma)) を描くことでベクトル場として可視化できる
- (\sigma) が大きいとき、denoiser はデータ平均を予測する傾向がある
- (\sigma) が低く、入力 (x) がデータに近いと、実際のデータ点を予測する
Denoising を射影として解釈する
- データ集合 (\mathcal{K}) に対する距離関数は (\mathrm{dist}_{\mathcal{K}}(x)=\min{|x-x_0|:x_0\in\mathcal{K}}) と定義される
- (x) の射影 (\mathrm{proj}_{\mathcal{K}}(x)) は、この距離を達成する (\mathcal{K}) 内の点の集合である
- (\mathcal{K}) が閉集合で、(x\notin\mathcal{K}) かつ射影が一意なら、二乗距離関数の勾配は (x-\mathrm{proj}_{\mathcal{K}}(x)) になる
- 距離関数 (\mathrm{dist}_{\mathcal{K}}) は至る所で微分可能ではないため、
minの代わりに softmin を使って (\sigma) で平滑化した二乗距離関数を導入する - 平滑化された距離関数の勾配は、(x) が定める重みに応じて (\mathcal{K}) の点の加重平均の方向を向く
理想的 denoiser と相対誤差モデル
- 理想的 denoiser (\epsilon^*) は、特定の (\sigma) において学習損失を厳密に最小化する denoiser である
- データが有限集合 (\mathcal{K}) 上の離散一様分布なら、理想的 denoiser は閉形式で表される
- 各データ点の重みは (x_\sigma) とその点との距離に応じて決まる
- 小さなデータセットでは
IdealDenoiserで直接計算できる
- toy データでは、理想的 denoiser は (\sigma) が大きいときデータ平均へ向かい、(\sigma) が小さいとき最も近いデータ点へ向かう
- 核心となる定理は、すべての (\sigma>0), (x\in\mathbb{R}^n) に対して (\frac{1}{2}\nabla_x \mathrm{dist}^2_{\mathcal{K}}(x,\sigma)=\sigma\epsilon^*(x,\sigma)) という関係を与える
- 相対誤差モデルは、(x-\sigma\epsilon_\theta(x,\sigma)) が (\mathrm{proj}_{\mathcal{K}}(x)) をよく近似する条件を使う
- (\sqrt{n}\sigma) が (\mathrm{dist}_{\mathcal{K}}(x)) を定数倍の範囲でうまく推定するときに適用される
- 誤差は (\eta\mathrm{dist}_{\mathcal{K}}(x)) 以下に制限されると仮定する
- 低ノイズでは manifold hypothesis のもとで追加ノイズの大半がデータ多様体に直交するため、denoising は射影を近似する
- 高ノイズでは (\sigma) が (\mathcal{K}) の直径より大きければ、データの加重平均を予測する denoiser でも相対誤差は小さい
- CIFAR-10 は理想的 denoiser を計算可能な規模であり、実験ではサンプリング軌跡上の正確な射影と理想的 denoiser の出力との相対誤差が小さいことが示される
サンプリング: 反復 denoising と DDIM
- 学習済み denoiser があれば、ノイズが混ざった (x_t) とノイズ水準 (\sigma_t) から (\hat{x}0^t=x_t-\sigma_t\epsilon\theta(x_t,\sigma_t)) により (x_0) を予測する
- 開始点では (\sigma_T) を (\mathcal{K}) の直径に対して十分大きく取り、(x_T) を (N(0,\sigma_T)) から独立にサンプリングして (\mathcal{K}) から遠く離れた位置に置く
- 高ノイズでは 1 回の denoiser 呼び出しは相対誤差が小さくても絶対誤差が大きいことがあり、理想的 denoiser の予測はデータ平均に近い
- そのためサンプリングでは (\sigma_t) スケジュールに沿って denoiser を繰り返し呼び出し、(x_T,\ldots,x_0) の系列を作る
- 更新式 (x_{t-1}=x_t-(\sigma_t-\sigma_{t-1})\epsilon_\theta(x_t,\sigma_t)) は、座標変換を施した決定論的 DDIM サンプリングアルゴリズムと同じである
- DDIM との同値性の証明は論文の Appendix A にある
距離最小化として見た DDIM
- DDIM は (f(x)=\frac{1}{2}\mathrm{dist}_{\mathcal{K}}(x)^2) に対する近似勾配降下として解釈できる
- ステップサイズは (1-\sigma_{t-1}/\sigma_t) である
- (\nabla f(x_t)) は (\epsilon_\theta(x_t,\sigma_t)) により推定される
- (\sigma_t) スケジュールは、サンプリング中の勾配ステップの回数と大きさを決める
- ステップが少なすぎると (\mathrm{dist}_{\mathcal{K}}(x_t)) が減らず、収束しない可能性がある
- 小さなステップを多数使うと denoiser の評価回数が増え、計算コストが大きくなる
- admissible schedule は、各反復で (\sqrt{n}\sigma_t) が (\mathrm{dist}_{\mathcal{K}}(x_t)) と定数倍の範囲で一致するようなスケジュールである
- 幾何級数的に減少する log-linear な (\sigma_t) 列は admissible schedule である
- 定理によれば、DDIM で生成された (x_t) において (\nabla\mathrm{dist}{\mathcal{K}}(x)) が存在し、(\mathrm{dist}{\mathcal{K}}(x_T)=\sqrt{n}\sigma_T) なら、(x_t) は二乗距離関数の勾配降下によって生成され、(\mathrm{dist}_{\mathcal{K}}(x_t)/\sqrt{n}\approx\sigma_t) が保たれる
- toy 例では、元の log-linear スケジュールから部分サンプリングして 20 ステップの DDIM サンプラーを実装しており、多くのサンプルは元データに近いが、改善の余地が残る
勾配推定ベースの改良サンプラー
- (\nabla\mathrm{dist}{\mathcal{K}}(x)) が (x) と (\mathrm{proj}{\mathcal{K}}(x)) の間で不変である点を利用し、現在の推定と前の推定を混ぜる更新を使う
- 更新式 (\bar{\epsilon}t=\gamma\epsilon\theta(x_t,\sigma_t)+(1-\gamma)\epsilon_\theta(x_{t+1},\sigma_{t+1})) は、前ステップの誤差を現在の推定で補正する方法である
- toy モデルのサンプルでは、この方法は DDIM より速く収束し、サンプルも元データにより近くなる
- DDIM と比べると、このサンプラーはモメンタムを追加したものとして解釈でき、軌跡が overshoot することはあるが、より速く収束しうる
- 生成過程でノイズを追加すると、サンプリング品質が経験的に改善する
- 元の (\sigma_t) スケジュールを保つには、より小さい (\sigma_{t'}) まで denoise した後、(w_t\sim N(0,I)) ノイズを再び加える
- (\mu=\frac{1}{2}) のとき DDPM sampler を正確に復元する
- 全体の更新式 (x_{t-1}=x_t-(\sigma_t-\sigma_{t'})\bar{\epsilon}_t+\eta w_t) は 3 つのサンプラーを一般化する
- DDIM:
gam=1, mu=0 - DDPM:
gam=1, mu=0.5 - 勾配推定サンプラー:
gam=2, mu=0
- DDIM:
より大きなモデルと参考資料
- ここまでの学習コードは toy データだけでなく、画像拡散モデルをゼロから学習するのにも使える
- FashionMNIST 例 は FashionMNIST データセットで学習し、Papers with Code リーダーボード の FID で 2 位のスコアを得る例として提供されている
- サンプリングコードは修正なしで事前学習済み latent diffusion モデルにも利用できる
- 例では
ScheduleLDM(1000)とModelLatentDiffusion('stabilityai/stable-diffusion-2-1-base')を使う - テキスト条件は
An astronaut riding a horseに設定し、50 個の (\sigma) ステップでサンプリングした後に latent をデコードする
- 例では
- (\gamma) モメンタム項の効果は、高解像度テキスト画像生成での比較可視化によって示される
- 追加で見ておくとよい資料
- What are diffusion models: Markov process を逆にたどる離散時間観点の拡散モデル紹介
- Generative modeling by estimating gradients of the data distribution: 確率微分方程式を逆にたどる連続時間観点の拡散モデル紹介
- The annotated diffusion model: PyTorch 拡散モデル実装の詳細解説
1件のコメント
Hacker Newsのコメント
質問があれば答えられます
特に軌道の議論が良かったです。スケジューラのような話題で多くの人がつまずく部分を理解する動機になるからです。SongやLilianの記事ほど完全ではないとしても、ずっと取っつきやすいので他の人にも勧めようと思います
ちなみに友人が以前書いた最小限の拡散実装があり、DDPMの観点ではもう少し「完全」に近いので役に立ちました: https://github.com/VSehwag/minimal-diffusion/
Stable Diffusionでサンプリング手順を少し試した立場としては、DDIMに対する収束時間とステップ数の比較も見てみたかったです。モメンタム、収束、誤差の間に関係があるのか気になります。たとえばモメンタムサンプラ16ステップが、DDIM 20ステップ ± 誤差項とほぼ同等なのか、といった比較があるとよいと思います
get_sigma_embeds(batches, sigma)は最初の入力を使っていないように見えます。sigmaを(batches, 1)の形にブロードキャストする意図だったのでしょうか数学的な細部をはるかに深く扱いつつ、500行未満の非常に分かりやすい最小実装も付いています
Soraや他の動画生成モデルを動かしている拡散トランスフォーマー版にも発展するといいですね。この記事と https://jaykmody.com/blog/gpt-from-scratch/ を組み合わせれば、「ゼロから作る拡散トランスフォーマー」の入門記事が作れそうです
一方で本当に深く掘り下げたいなら、Kingma、Gao、Ricky Tian Qi Chen、Max Wellingの弟子たち(Tomczakはポスドク、Hoogeboomなど)、そして陰の功労者であるAapo Hyvärinenの仕事を読むことを勧めます。Kingma & Gaoの比較的軽めの仕事で、かつSD3論文にも関係する例はこちらです: https://arxiv.org/abs/2303.00848
残念なのは、過去の研究を知って理解していることへの依存が大きく、アクセスしにくい点です。ただ、これを意味のある批判と呼ぶのは難しい面もあります。研究であって、大衆向けの教育資料ではないからです
n_embdサイズのベクトルに射影すればよく、拡散過程そのものはそのままにできます[1] https://yang-song.net/blog/2021/score/
[2] https://lilianweng.github.io/posts/2021-07-11-diffusion-mode...
私たちの観点では、拡散モデルが学習しやすいのは、正確な距離関数の勾配を予測する代わりに、平滑化された距離関数の勾配を予測する学習目標を使っているからです。拡散モデルのサンプリングは、近似的な勾配ステップを何度も踏むのに似ています
拡散モデルをより深く理解したいなら、こうしたブログ記事をすべて読み、異なる解釈を学んでみることを勧めます
ただしこの記事のアプローチは、ノイズ除去器の誤差解析のような、より興味深い実験を可能にしてくれそうです
[1] https://arxiv.org/pdf/2305.03486.pdf
たとえば画像生成器がピアノの鍵盤を作るのが難しいのはなぜでしょうか。黒鍵が2本と3本で交互に並ぶ構造を作るには、中距離の制約をもっとよく表現する必要があるように見えます