2 ポイント 投稿者 GN⁺ 2024-03-12 | 1件のコメント | WhatsAppで共有
  • 拡散モデルは画像生成を超えて、音声・動画・3D・タンパク質設計・ロボット経路計画のような、多峰性分布のサンプリングが必要な問題に使われており、このチュートリアルは最適化の観点から学習とサンプリングを結び付ける
  • 学習過程では、データにノイズを混ぜた (x_\sigma=x_0+\sigma\epsilon) を作り、ニューラルネットワーク (\epsilon_\theta(x,\sigma)) がノイズ方向を予測するよう平均二乗誤差を最小化する
  • 学習済み denoiser はデータ集合 (\mathcal{K}) への近似射影として解釈でき、理想的 denoiser は (\sigma)-平滑化された二乗距離関数の勾配と結び付く
  • DDIM サンプリングは (f(x)=\frac{1}{2}\mathrm{dist}_{\mathcal{K}}(x)^2) に対する近似勾配降下とみなせ、(\sigma_t) スケジュールが反復回数と denoiser 評価コストを決める
  • 勾配推定更新とノイズ追加を組み合わせると、DDIM、DDPM、著者らの改良サンプラーを gammu パラメータでまとめて扱え、toy モデルと latent diffusion の例へとつながる

最適化の観点から見た拡散モデル

  • 拡散モデルは多峰性分布からサンプルを生成するのに強みを持ち、Stable Diffusion のようなテキスト画像生成ツールだけでなく、音声、動画、3D 生成、タンパク質設計、ロボット経路計画にも応用されている
  • チュートリアルの理論的基盤は、ICML 2024 論文関連論文 における最適化的解釈である
  • 実装は主に smalldiffusion を参照しており、本文のコードは元ライブラリより教育用に単純化されている

学習: ノイズ方向の予測

  • 拡散モデルは学習例からデータ集合 (\mathcal{K}) を学び、その集合からサンプルを生成することを目指す
    • 画像であれば (\mathcal{K} \subset \mathbb{R}^{c\times h \times w}) は現実的な画像に対応するピクセル値の集合である
    • 同じ枠組みは音声、動画、ロボット軌道、テキストのような離散領域にも適用できる
  • 学習手順は 3 段階で見られる
    • (x_0 \sim \mathcal{K}), (\sigma), (\epsilon \sim N(0,I)) をサンプリングする
    • (x_\sigma=x_0+\sigma\epsilon) によりノイズが混ざったデータを作る
    • (\epsilon_\theta(x_\sigma,\sigma)) が (\epsilon) を予測するよう二乗損失を最小化する
  • コードでは training_loop が各バッチ x0 ごとに generate_train_samplesigmaeps を作り、model(x0 + sigma * eps, sigma) の出力と eps の間の MSE を最適化する
  • (\sigma) は連続区間から一様サンプリングするのではなく、(N) 個の値に離散化した**(\sigma) スケジュール**から選ぶ
    • Schedule クラスは利用可能な sigmas の一覧をラップし、学習中にバッチごとに値をサンプリングする
    • 本文の例では ScheduleLogLinear(N, sigma_min=0.02, sigma_max=10) を使う
    • ScheduleDDPM はピクセル空間の拡散モデル、ScheduleLDM は Stable Diffusion のような latent diffusion モデル向けのスケジュールである

Swissroll の toy 例

  • toy データセットは初期の拡散論文の一つである Sohl-Dickstein et al. 2015 で使われた渦巻き状の点集合で、(\mathcal{K}\subset\mathbb{R}^2) である
  • 単純なデータセットでは denoiser を MLP で実装する
    • 入力は (x\in\mathbb{R}^2) と (\sigma) の 2 次元埋め込みを連結した値である
    • 出力はノイズ (\epsilon\in\mathbb{R}^2) の予測値である
    • 多くの拡散モデルは (\sigma) に sinusoidal positional embedding を使うが、この例では単純な 2 次元埋め込みでも十分に機能する
  • 例の学習設定では ScheduleLogLinear(N=200, sigma_min=0.005, sigma_max=10)epochs=15000 を使う
  • 学習済み denoiser は (x-\sigma\epsilon_\theta(x,\sigma)) を描くことでベクトル場として可視化できる
    • (\sigma) が大きいとき、denoiser はデータ平均を予測する傾向がある
    • (\sigma) が低く、入力 (x) がデータに近いと、実際のデータ点を予測する

Denoising を射影として解釈する

  • データ集合 (\mathcal{K}) に対する距離関数は (\mathrm{dist}_{\mathcal{K}}(x)=\min{|x-x_0|:x_0\in\mathcal{K}}) と定義される
  • (x) の射影 (\mathrm{proj}_{\mathcal{K}}(x)) は、この距離を達成する (\mathcal{K}) 内の点の集合である
  • (\mathcal{K}) が閉集合で、(x\notin\mathcal{K}) かつ射影が一意なら、二乗距離関数の勾配は (x-\mathrm{proj}_{\mathcal{K}}(x)) になる
  • 距離関数 (\mathrm{dist}_{\mathcal{K}}) は至る所で微分可能ではないため、min の代わりに softmin を使って (\sigma) で平滑化した二乗距離関数を導入する
  • 平滑化された距離関数の勾配は、(x) が定める重みに応じて (\mathcal{K}) の点の加重平均の方向を向く

理想的 denoiser と相対誤差モデル

  • 理想的 denoiser (\epsilon^*) は、特定の (\sigma) において学習損失を厳密に最小化する denoiser である
  • データが有限集合 (\mathcal{K}) 上の離散一様分布なら、理想的 denoiser は閉形式で表される
    • 各データ点の重みは (x_\sigma) とその点との距離に応じて決まる
    • 小さなデータセットでは IdealDenoiser で直接計算できる
  • toy データでは、理想的 denoiser は (\sigma) が大きいときデータ平均へ向かい、(\sigma) が小さいとき最も近いデータ点へ向かう
  • 核心となる定理は、すべての (\sigma>0), (x\in\mathbb{R}^n) に対して (\frac{1}{2}\nabla_x \mathrm{dist}^2_{\mathcal{K}}(x,\sigma)=\sigma\epsilon^*(x,\sigma)) という関係を与える
  • 相対誤差モデルは、(x-\sigma\epsilon_\theta(x,\sigma)) が (\mathrm{proj}_{\mathcal{K}}(x)) をよく近似する条件を使う
    • (\sqrt{n}\sigma) が (\mathrm{dist}_{\mathcal{K}}(x)) を定数倍の範囲でうまく推定するときに適用される
    • 誤差は (\eta\mathrm{dist}_{\mathcal{K}}(x)) 以下に制限されると仮定する
    • 低ノイズでは manifold hypothesis のもとで追加ノイズの大半がデータ多様体に直交するため、denoising は射影を近似する
    • 高ノイズでは (\sigma) が (\mathcal{K}) の直径より大きければ、データの加重平均を予測する denoiser でも相対誤差は小さい
  • CIFAR-10 は理想的 denoiser を計算可能な規模であり、実験ではサンプリング軌跡上の正確な射影と理想的 denoiser の出力との相対誤差が小さいことが示される

サンプリング: 反復 denoising と DDIM

  • 学習済み denoiser があれば、ノイズが混ざった (x_t) とノイズ水準 (\sigma_t) から (\hat{x}0^t=x_t-\sigma_t\epsilon\theta(x_t,\sigma_t)) により (x_0) を予測する
  • 開始点では (\sigma_T) を (\mathcal{K}) の直径に対して十分大きく取り、(x_T) を (N(0,\sigma_T)) から独立にサンプリングして (\mathcal{K}) から遠く離れた位置に置く
  • 高ノイズでは 1 回の denoiser 呼び出しは相対誤差が小さくても絶対誤差が大きいことがあり、理想的 denoiser の予測はデータ平均に近い
  • そのためサンプリングでは (\sigma_t) スケジュールに沿って denoiser を繰り返し呼び出し、(x_T,\ldots,x_0) の系列を作る
  • 更新式 (x_{t-1}=x_t-(\sigma_t-\sigma_{t-1})\epsilon_\theta(x_t,\sigma_t)) は、座標変換を施した決定論的 DDIM サンプリングアルゴリズムと同じである
    • DDIM との同値性の証明は論文の Appendix A にある

距離最小化として見た DDIM

  • DDIM は (f(x)=\frac{1}{2}\mathrm{dist}_{\mathcal{K}}(x)^2) に対する近似勾配降下として解釈できる
    • ステップサイズは (1-\sigma_{t-1}/\sigma_t) である
    • (\nabla f(x_t)) は (\epsilon_\theta(x_t,\sigma_t)) により推定される
  • (\sigma_t) スケジュールは、サンプリング中の勾配ステップの回数と大きさを決める
    • ステップが少なすぎると (\mathrm{dist}_{\mathcal{K}}(x_t)) が減らず、収束しない可能性がある
    • 小さなステップを多数使うと denoiser の評価回数が増え、計算コストが大きくなる
  • admissible schedule は、各反復で (\sqrt{n}\sigma_t) が (\mathrm{dist}_{\mathcal{K}}(x_t)) と定数倍の範囲で一致するようなスケジュールである
    • 幾何級数的に減少する log-linear な (\sigma_t) 列は admissible schedule である
  • 定理によれば、DDIM で生成された (x_t) において (\nabla\mathrm{dist}{\mathcal{K}}(x)) が存在し、(\mathrm{dist}{\mathcal{K}}(x_T)=\sqrt{n}\sigma_T) なら、(x_t) は二乗距離関数の勾配降下によって生成され、(\mathrm{dist}_{\mathcal{K}}(x_t)/\sqrt{n}\approx\sigma_t) が保たれる
  • toy 例では、元の log-linear スケジュールから部分サンプリングして 20 ステップの DDIM サンプラーを実装しており、多くのサンプルは元データに近いが、改善の余地が残る

勾配推定ベースの改良サンプラー

  • (\nabla\mathrm{dist}{\mathcal{K}}(x)) が (x) と (\mathrm{proj}{\mathcal{K}}(x)) の間で不変である点を利用し、現在の推定と前の推定を混ぜる更新を使う
  • 更新式 (\bar{\epsilon}t=\gamma\epsilon\theta(x_t,\sigma_t)+(1-\gamma)\epsilon_\theta(x_{t+1},\sigma_{t+1})) は、前ステップの誤差を現在の推定で補正する方法である
  • toy モデルのサンプルでは、この方法は DDIM より速く収束し、サンプルも元データにより近くなる
  • DDIM と比べると、このサンプラーはモメンタムを追加したものとして解釈でき、軌跡が overshoot することはあるが、より速く収束しうる
  • 生成過程でノイズを追加すると、サンプリング品質が経験的に改善する
    • 元の (\sigma_t) スケジュールを保つには、より小さい (\sigma_{t'}) まで denoise した後、(w_t\sim N(0,I)) ノイズを再び加える
    • (\mu=\frac{1}{2}) のとき DDPM sampler を正確に復元する
  • 全体の更新式 (x_{t-1}=x_t-(\sigma_t-\sigma_{t'})\bar{\epsilon}_t+\eta w_t) は 3 つのサンプラーを一般化する
    • DDIM: gam=1, mu=0
    • DDPM: gam=1, mu=0.5
    • 勾配推定サンプラー: gam=2, mu=0

より大きなモデルと参考資料

  • ここまでの学習コードは toy データだけでなく、画像拡散モデルをゼロから学習するのにも使える
  • FashionMNIST 例 は FashionMNIST データセットで学習し、Papers with Code リーダーボード の FID で 2 位のスコアを得る例として提供されている
  • サンプリングコードは修正なしで事前学習済み latent diffusion モデルにも利用できる
    • 例では ScheduleLDM(1000)ModelLatentDiffusion('stabilityai/stable-diffusion-2-1-base') を使う
    • テキスト条件は An astronaut riding a horse に設定し、50 個の (\sigma) ステップでサンプリングした後に latent をデコードする
  • (\gamma) モメンタム項の効果は、高解像度テキスト画像生成での比較可視化によって示される
  • 追加で見ておくとよい資料

1件のコメント

 
GN⁺ 2024-03-12
Hacker Newsのコメント
  • 著者です。拡散モデルを理解しようとしているうちに、コードと数学を大幅に単純化できることに気づき、それでこのブログ記事と拡散ライブラリを作ることになりました
    質問があれば答えられます
    • 研究者の立場からすると気に入らない拡散モデルのブログ記事は多いのですが、これは本当に良かったです。核心にすぐ入る一方で、よく複雑になりがちな部分もきちんと見せていて、道に迷ったり話が散漫になったりしません
      特に軌道の議論が良かったです。スケジューラのような話題で多くの人がつまずく部分を理解する動機になるからです。SongやLilianの記事ほど完全ではないとしても、ずっと取っつきやすいので他の人にも勧めようと思います
      ちなみに友人が以前書いた最小限の拡散実装があり、DDPMの観点ではもう少し「完全」に近いので役に立ちました: https://github.com/VSehwag/minimal-diffusion/
    • 最後のサンプル画像では、モメンタム項が家のデジタルペインティングに悪影響を与えているように見えます。gamma = 2.0 の画像ではドアが消えているので、勾配情報を使うDDIMサンプラの効果を直感的に理解するために、その例の細部がもっと気になります
      Stable Diffusionでサンプリング手順を少し試した立場としては、DDIMに対する収束時間とステップ数の比較も見てみたかったです。モメンタム、収束、誤差の間に関係があるのか気になります。たとえばモメンタムサンプラ16ステップが、DDIM 20ステップ ± 誤差項とほぼ同等なのか、といった比較があるとよいと思います
    • get_sigma_embeds(batches, sigma) は最初の入力を使っていないように見えます。sigma(batches, 1) の形にブロードキャストする意図だったのでしょうか
    • こうした概念の一部は物理の原理に由来するのでしょうか。ニューラルネットワークが生物学的な神経回路を模したと言われるのと似た話なのか、その観点について何か洞察があれば知りたいです
  • もう一つの良い記事も Diffusion Models From Scratch という題名です: https://www.tonyduan.com/diffusion/index.html
    数学的な細部をはるかに深く扱いつつ、500行未満の非常に分かりやすい最小実装も付いています
  • コードがあるのが良いですね。拡散の論文は数式が多いことで有名ですが(https://twitter.com/cto_junior/status/1766518604395155830)、他の多くの人にとってはコードのほうがずっと読みやすく、より正確ですらあるかもしれません。あらゆる理論論文には参照実装コードが付属すべきだと思います
    Soraや他の動画生成モデルを動かしている拡散トランスフォーマー版にも発展するといいですね。この記事と https://jaykmody.com/blog/gpt-from-scratch/ を組み合わせれば、「ゼロから作る拡散トランスフォーマー」の入門記事が作れそうです
    • 拡散の論文が数式だらけで有名なのは確かですが、正直なところ私の知っている拡散研究者たちもだいたい同じ反応をします。多くの人が同じ式を繰り返し書いていて、それらの式は実質的に復習のために近いと思います
      一方で本当に深く掘り下げたいなら、Kingma、Gao、Ricky Tian Qi Chen、Max Wellingの弟子たち(Tomczakはポスドク、Hoogeboomなど)、そして陰の功労者であるAapo Hyvärinenの仕事を読むことを勧めます。Kingma & Gaoの比較的軽めの仕事で、かつSD3論文にも関係する例はこちらです: https://arxiv.org/abs/2303.00848
      残念なのは、過去の研究を知って理解していることへの依存が大きく、アクセスしにくい点です。ただ、これを意味のある批判と呼ぶのは難しい面もあります。研究であって、大衆向けの教育資料ではないからです
    • U-netをトランスフォーマーエンコーダに置き換えるだけで済みます。埋め込みを外し、画像パッチを n_embd サイズのベクトルに射影すればよく、拡散過程そのものはそのままにできます
  • 良い記事ですが、拡散モデルがスコア関数(対数確率の導関数)をモデル化しているという重要な性質[1]と、拡散サンプリングがランジュバン力学[2]に似ている点が欠けている感じがします。これらの見方は、GANより学習しやすい理由をよく説明していると思います。モデリング対象がより簡単だからです
    [1] https://yang-song.net/blog/2021/score/
    [2] https://lilianweng.github.io/posts/2021-07-11-diffusion-mode...
    • その通りです。これらのブログ記事は、本文で説明されていた「データへの射影」という観点とは異なる拡散モデルの解釈を与えています。同じ学習目標とサンプリング過程を解釈する複数の方法として見ることができます
      私たちの観点では、拡散モデルが学習しやすいのは、正確な距離関数の勾配を予測する代わりに、平滑化された距離関数の勾配を予測する学習目標を使っているからです。拡散モデルのサンプリングは、近似的な勾配ステップを何度も踏むのに似ています
      拡散モデルをより深く理解したいなら、こうしたブログ記事をすべて読み、異なる解釈を学んでみることを勧めます
  • とても興味深いです。Iterative alpha-(de)Blending[1] をすぐに思い出しました。この研究も、概念的により単純な拡散モデルを立てようとしており、近似的な反復射影過程として定式化するという結論に至っています
    ただしこの記事のアプローチは、ノイズ除去器の誤差解析のような、より興味深い実験を可能にしてくれそうです
    [1] https://arxiv.org/pdf/2305.03486.pdf
  • 理論の説明は良いです。データセットとは独立した説明のように見えますが、実際の画像生成の具体的な部分が気になります
    たとえば画像生成器がピアノの鍵盤を作るのが難しいのはなぜでしょうか。黒鍵が2本と3本で交互に並ぶ構造を作るには、中距離の制約をもっとよく表現する必要があるように見えます
    • これは指の問題と同じです。数、大きさ、角度、位置などを毎回すべて合わせなければならず、どれか一つでも間違えば人はすぐに気づきます。木の枝のように分岐位置が「間違って」いても人があまり気づかない対象とは違います
  • 拡散のアイデアの一部は、訓練データをものすごく増やすことなのでしょうか。ランダムに拡散させた画像を、元の拡散されていない画像と対比できるようになる、ということでしょうか
  • すべての機械学習モデルは畳み込みです。覚えておいてください
    • この話を何度か投稿している気がしますが、もう少し詳しく説明してもらえますか。たとえば強化学習を畳み込みとして捉えるのは難しいように感じます