Stable Diffusion 3.5を純粋なPyTorchでゼロから再実装
(github.com/yousef-rafat)- miniDiffusion プロジェクトは、Stable Diffusion 3.5 モデルを PyTorch のみを用いてゼロから再実装したオープンソースである
- このプロジェクトの構成は、教育目的 と 実験、ハッキング 用途に重点を置いているのが特徴である
- コードベース全体は約 2800行 で、VAE から DiT、学習およびデータセットスクリプトまで最小限のコードで構成されている
- 主な構成要素として、VAE、CLIP、T5 テキストエンコーダ、マルチモーダル拡散トランスフォーマー、Joint Attention などがある
- まだ 実験的な機能 が含まれており、さらなるテストが必要な状態である
miniDiffusion プロジェクト紹介
miniDiffusion は、Stable Diffusion 3.5 の中核機能を PyTorch のみで再実装したオープンソースプロジェクトである
このプロジェクトは既存の Stable Diffusion 3.5 と比べて、次のような利点がある
- コードベースが約 2,800行 と小規模で、自分で構造を分析して学ぶのに非常に適している
- さまざまな 機械学習の実験 や モデルハッキング に有用に活用できる
- 依存関係が非常に少なく、最小限のライブラリのみを使用している
中核構造と構成ファイル
- dit.py : メインの Stable Diffusion モデル実装部
- dit_components.py : 埋め込み、正規化、パッチ埋め込み、DiT 補助関数の構成
- attention.py : Joint Attention(共同アテンション)アルゴリズムの実装部
- noise.py : Rectified Flow のための Euler ODE スケジューラを含む
- t5_encoder.py, clip.py : T5 および CLIP テキストエンコーダの実装
- tokenizer.py : Byte-Pair および Unigram トークナイザの実装
- metrics.py : FID(Fréchet inception distance)評価指標の実装
- common.py : 学習に必要な補助関数を提供
- common_ds.py : 画像を DiT 用学習データに変換する iterable データセットの実装
- model フォルダ : 学習後のモデルチェックポイントとログを保存
- encoders フォルダ : VAE、CLIP など個別モジュールのチェックポイントを保存
⚠️ 実験的機能とテストの必要性 miniDiffusion にはまだ実験的な機能が含まれており、さらなるテストが必要な状態である
主な機能ごとの詳細構成
Core Image Generation Modules
- VAE、CLIP、T5 テキストエンコーダ の実装
- Byte-Pair、Unigram トークナイザの実装
SD3 Components
- Multi-Modal Diffusion Transformer Model
- Flow-Matching Euler Scheduler の実装
- Logit-Normal Sampling
- Joint Attention アルゴリズムの導入
モデル学習および推論スクリプト
- SD3(Stable Diffusion 3.5)向けの 学習および推論スクリプト を提供
ライセンス
- MIT ライセンス で公開されており、教育および実験目的 で制作されている
このオープンソースプロジェクトの意義と利点
- Stable Diffusion 3.5 クラスの最新画像生成モデル構造を 純粋な PyTorch のみで直接学習・ハッキングできる
- コードが簡潔で独立性が高く、構造分析/モデルチューニング/新規アルゴリズム研究 に最適化されている
- 最新のマルチモーダル、トランスフォーマー、アテンション手法などを直接実習できる
- 商用プロジェクトとは切り離して安全に実験できる基盤を提供する
1件のコメント
Hacker Newsの意見
Fluxのリファレンス実装は本当にミニマルな構造なので、興味がある人なら一度見てみる価値がある
Flux GitHub
minRFプロジェクトはrectified flowを活用して、小さな拡散モデルを学習するときに簡単に始められるのが利点
minRF GitHub
Stable Diffusion 3.5のリファレンス実装もかなり簡潔に書かれていて、参考にしやすい
SD 3.5 GitHub
リファレンス実装はメンテナンスが行き届いておらず、バグが多いことがよくある
miniDiffusionプロジェクトがStable Diffusion 3.5モデルを使っているという意味なのか気になる
関連コード
学習データセットは非常に小さく、ファッション関連の画像だけを含んでいる
ファッションデータセット
そのデータセットは拡散モデルのファインチューニングを試してみるためのもの
純粋なPyTorchを使うことでNVIDIA以外のGPUでも性能上の利点があるのか、あるいはPyTorchがCUDAにあまりに最適化されていて他のGPUベンダーは競争できないのか気になる
PyTorchはApple Siliconでもかなりよく動く
AMDのような非NVIDIAデバイスでも、MLワークロードをVulkan経由で動かすことはできる
PyTorchのROCmサポートは非常にゆっくり進んでおり、動いたとしても速度が遅い
PyTorchはROCmでもきちんと動くが、完全に「同等」と言えるほどかはよく分からない
PyTorchコードで
の代わりに
のように試してみるとよいのでは、という提案
学習者にとって良い資料に見える
初心者でも追えるチュートリアルや解説があるのか気になる
fast.aiにはStable Diffusionを自分で実装してみる講義がある
Stable Diffusionをライセンス制限なしで使えるという意味なのか気になる
正直少し気恥ずかしいのだが、このリポジトリができる前と後で、私たちが新たに得たものは何なのか気になる
個人的にはモデルを作ることを避けてきて、主に成果物を横から見てきた立場
以前からすでにPyTorchベースの推論/学習スクリプトは公開されているものだと漠然と思っていた
少なくとも推論スクリプトはモデル配布時に一緒に付いてくるものだと思っていたし、ファインチューニング/学習スクリプトもあるものだと思っていた
このプロジェクトが「クリーンルーム」または「ダーティルーム」的に既存のものを書き直しただけなのか、それとも既存のPyTorchコードですらCUDA/Cベースで複雑すぎるため純粋なPyTorch版に大きな意味があるのか確信が持てない
ともかくよく分からないので、誰か説明してくれるとありがたい
このプロジェクトの核心的な価値は「依存関係を最小限にした実装」にある
Stability AIはStable DiffusionモデルをStability AI Community Licenseで配布しており、MITと違って「完全に自由」ではない
SD 3.5(あるいはどのバージョンでも)を考えるとき、自分としては学習過程で生成された重みの部分こそが核心だと認識している
Ludwig Maximilian UniversityのCompVizグループが公開したオリジナルの学術ソースが、実運用に使えるものなのか気になる
このDiffusion Transformer(DiT)の実装が、SD 3.5フル版のようにクロストークンアテンションをきちんと実装しているのか、それともコードの可読性のために単純化しているのか気になる