PyTorchは死んだ。JAX万歳
(neel04.github.io)- PyTorchが生産性の損失と開発時間の浪費を招く理由は、「フレームワーク自体が悪いからではなく、現在適用されているユースケースに合わせて設計されていないから」である。
PyTorchの哲学
- PyTorchの哲学は、動的で、デバッグしやすく、Pythonicであること
- 一方、TensorFlow 1.xはXLAコンパイラを強力に活用し、静的だが高性能なフレームワークになろうとしていた
- TensorFlowの開発者たちは、コミュニティが1.x APIを嫌っていることに気づき、Kerasをメインインターフェースとして使うことを決め、XLAコンパイラの役割を縮小した
- PyTorchは自らのルーツを守り、TensorFlowの静的で遅延的なアプローチとは異なり、
torch.Tensorが即時に評価される、より動的な「即時実行」アプローチを採用した - これが成果を上げ、多くの研究がPyTorchへ移行した
- 2021年にGPT-3が登場すると、性能とスケーラビリティが主要な関心事になった
- PyTorchはこうした需要にある程度うまく対応したが、その哲学を前提に設計されていなかったため、次第に技術的負債が蓄積し、基盤が揺らぎ始めた
- PyTorch開発者はどんな妥協点も望まず、2つの路線を同時に追求することを選んだ
- XLAコンパイラを性能と安定性に優れたデフォルトバックエンドとして使う
torch.compileスタックを構築し、必要に応じてユーザーがコンパイラを呼び出せる自由を与える
- 長期戦略の欠如は深刻な問題である
- PyTorchはJAXのようなコンパイラ中心の哲学にコミットしたがっていないが、良い代替案も見当たらない
- この問題に対する競合製品の解決策は何か?
JAXのコンパイラベース開発
- JAXはTensorFlowの強力なコンパイラスタックであるXLAを活用している
- XLAは強力なコンパイラだが、エンドユーザーに対してはすべて抽象化されている
- 関数がpureでありさえすれば、
@jax.jitデコレータを使って関数をJITコンパイルし、XLAで使えるようにできる - XLAは、生成されたグラフが正しいかの検証、JAXでシャーディングを用いた自動並列化を処理するGSPMDパーティショナ、グラフ最適化、演算子およびカーネル融合、レイテンシ隠蔽スケジューリング、非同期通信のオーバーラップ、
tritonのような別バックエンド向けコード生成などを、すべて裏側で処理する - JAXの制約に従いさえすれば、XLAが自動で処理してくれる
- たとえば並列化の際に
torch.distributed.barrier()のような通信プリミティブは不要である - DDP対応はシンプルなコードで可能だ
- XLAのアプローチは、計算がシャーディングに従うというものだ。したがって、入力配列がある軸に沿ってシャーディングされていれば、XLAが下位計算について自動で処理する
- 「コンパイラベース開発」というアイデアは、Rustコンパイラの動作と似ている
- PyTorchの限界
- PyTorch開発者が、柔軟性と自由という中核哲学を維持する代わりに、新機能のためにコンパイラスタックを統合し依存する選択をしたことに不満がある
- PyTorch 2.xの公式ロードマップによれば、XLAをTorchと完全統合する長期計画が明確に示されている
- これはひどいアイデアだ。RustコンパイラにC++コードを無理やり押し込む方が、Rust自体を使うより良い体験になると言っているようなものだ
- TorchはJAXと違って、XLAを中心に設計されていない
- もしPyTorchがXLAベースのコンパイラスタックを使うと決めたのなら、理想的なフレームワークはそれを中心に特化して設計・構築されたものではないのか?
- PyTorchが望むコンパイラバックエンドを選べる「マルチバックエンド」アプローチを追求したとしても、断片化の問題を悪化させ、すべてのコンパイラスタックの制約を尊重しようとしながらAPIを完全に壊してしまうのではないか?
- TPUでTorch/XLAを使ったことがある人なら誰でも、深刻なPTSDに苦しむ
Multi-Backendは失敗した
- PyTorchは一度にすべてをやろうとして、悲惨な失敗をした
- 「マルチバックエンド」という設計判断は、この問題を指数関数的に悪化させた
- 理論上は好きなスタックを選べるように聞こえるが、実際には理解しがたいトレースバックと非互換性問題が絡み合った混沌である
- バックエンド間の制約条件とPyTorch APIの衝突
- これらのバックエンドを動かすこと自体が難しいのではなく、これらのバックエンドが期待する制約条件が、PyTorchの柔軟でPythonicなAPIとうまく噛み合わない
- APIの一貫性を保つことと、バックエンドの制約に従うことのあいだにはトレードオフがある
- 結果として、開発者は単一バックエンドに実際に統合・コミットする代わりに、よりコード生成に依存しようとする
- PyTorchの戦略不在
- PyTorchは意味のあるトレードオフを拒否するため、あらゆる決定が妥協の産物のように感じられる
- 一貫性も、全体戦略もない
- 最終的にユーザーに大きなフラストレーションをもたらし、噛み合っていない機能の寄せ集めのように見える
- エコシステムを殺すのに、これ以上速い方法はない
- JAXのアプローチに従うべきでない理由
- PyTorchはJAXの「統合コンパイラ+バックエンド」アプローチに従うべきではない
- JAXはXLAとともに動くよう明示的に設計されているからだ
- PyTorchのフロントエンドをJAXのものに置き換えることが戦略になるはずはない
- XLAを前提としてJAXより優れたAPIを考案することは、事実上不可能である
- 開発者が新しく異なるアイデアを試すこと自体を非難するつもりはない
- しかしPyTorchが時の試練に耐えるには、理想的なチュートリアル条件の外ではすぐ崩れる派手な新機能を提供することより、基盤を強化することにもっと重点を置くべきだ
PyTorchの断片化とJAXの関数型プログラミング
- JAXの関数型API
- JAXの関数はpureでなければならない。つまり、グローバルな副作用があってはならない
- 数学関数のように、同じデータが与えられれば実行コンテキストに関係なく常に同じ出力を返さなければならない
- この設計哲学のおかげで、JAXの関数は構成可能で相互運用性にも優れる
- 開発の複雑さが減り、関数は特定のシグネチャと明確に定義された具体的な処理として定義される
- 型が守られていれば、関数は即座に動作することが保証される
- これは科学計算、特にディープラーニングで必要とされる作業の種類に適している
- optax APIの例
- 関数型アプローチのおかげで、optaxには「chain」というものがある
- これは勾配に順番に適用される複数の関数を含む
- 根本的な構成要素はGradientTransformationである
- 強力でありながら表現力の高いAPIを実現している
- たとえば勾配をクリッピングしたり、勾配のEMAを取ったり、オプティマイザを組み合わせたりすることが非常に簡単になる
- 関数型設計の利点
- 関数型設計のもう1つの優れた帰結が
vmapである - これは「vectorized map」を意味し、その機能を正確に表している
- すべてをmapでき、
vmapでありさえすればXLAが自動で融合・最適化してくれる - 関数を書くときにバッチ次元を意識する必要がない
- すべてのコードを
vmapすればよいだけだ - これは
ein-*系の操作があまり必要なくなることを意味する - 2D/3Dテンソル操作を把握するのがより直感的になり、可読性もはるかに良くなる
- 個々の構成要素を分離して推論するだけでよいため、うまく動く複雑なコードをより簡単に書ける
- pure性の制約を守り、正しいシグネチャさえあれば、構成可能性のような他の利点もすべて享受できる
- 関数型設計のもう1つの優れた帰結が
- PyTorchエコシステムの問題点
- torchでは、どのスタック(FSDP + マルチノード +
torch.compileなど)を使っていても、常に何かが壊れる可能性がある - 複数の要素が正しく連携する必要があり、どれか1つでも失敗すれば午前3時までデバッグする羽目になる
- PyTorchが提供する数十もの機能のすべての組み合わせをテストすることはできないため、開発中に見つからないバグが常に存在する
- 相当な努力なしに、うまく動作するコードを書くことは不可能である
- torchエコシステムは非常に肥大化し、バグも多くなっている
- 共有抽象化が存在しないため、他の「ソリューション」と連携するよう設計されていない新しいライブラリやフレームワークが登場する
- それはやがて依存関係と
requirements.txtの混乱へと急速に変質する - GitHub Issueやフォーラム議論の70〜80%は、単に異なるライブラリ間でエラーが発生することが原因である
- これを解決する方法はほとんどない
- torchでは、どのスタック(FSDP + マルチノード +
- 解決策の不在
- これはOOPと設計の問題である
- PyTreeのような基本的でPyTorchらしいオブジェクトが、抽象化の共通基盤を築く助けになったはずだと思う
- 関数型プログラミングのパラダイムを採用することもできない
- そうすれば、既存のtorchコードベースとの後方互換性をすべて壊しつつ、JAXの劣化版に収束してしまうだろう
- PyTorchはこの点で完全に壊れた状態に見える
JAXの再現性の優位
- シード処理
- PyTorchのシード処理は理想的ではない
- 一般には複数行のコードを実行する必要がある
- 簡単に忘れたり、設定を誤ったりしやすい
- JAXは明示的なキーを作成し、乱数が必要なすべての関数にそれを渡すことを強制する
- このアプローチは、RNGが常に静的にシードされるため、この問題を完全に解消する
- JAXには独自のNumPyである
jax.numpyがあるため、別途シードを設定する必要がない - こうした小さなQoL上の判断が、フレームワーク全体のユーザー体験を大きく改善しうる
- 移植性
- PyTorchコードベースを使うとき最大の問題の1つは、移植性の欠如である
- CUDA/GPU向けに書かれたコードベースは、TPU、NPU、AMD GPUなどの非Nvidiaハードウェア上ではうまく動作しない
- 1ノード向けに書かれたPyTorchコードをマルチノードへ移植するのは難しい
- マルチノード化には、しばしば数十時間の開発時間と大幅なコード変更が必要になる
- JAXのコンパイラ中心アプローチは、この点で優位性がある
- XLAはデバイスバックエンド間の切り替えを処理し、最小限のコード変更でGPU/TPU/マルチノード/マルチスライス上でうまく動作する
- ハードウェアベンダーが自社デバイスをサポートしやすくなり、デバイス間の切り替えも容易になる
- すべての人が同じハードウェアにアクセスできるわけではないので、さまざまな種類のハードウェアで移植可能なコードベースは、ディープラーニングを初学者・中級者にとってよりアクセスしやすくする小さな一歩になりうる
- 自動スケーリング
- それ自体でうまく自動スケーリングできるコードベースは、再現に非常に役立つ
- 理想的には、最小限のコード変更で、ネットワーク境界に縛られず自動的に起こるべきである
- JAXはこれをうまく実現している
- JAXコードを書くとき、通信プリミティブを指定したり、
torch.distributed.barrier()を至る所に置いたりする必要はない - XLAは利用可能なハードウェアを考慮して、自動的にそれらを挿入する
- JAXが検出できるすべてのデバイスは、ネットワーク、トポロジー、構成などに関係なく自動的に利用される
- 計算を自動で同期・準備し、カーネルの非同期実行を最大化してレイテンシを最小化する最適化パスを適用する
- 人間が行う必要があるのは、入力配列のバッチ次元のように、デバイスへ分散したいテンソルのシャーディングを指定することだけである
- XLAの「計算はシャーディングに従う」アプローチにより、残りは自動で解決される
- スケールに合わせて検証された実験を、趣味レベルでも簡単に実行・検証し、場合によっては再現できる
- これは忘れられたアイデアの再発見を容易にし、最小限の労力でより大きな規模で簡単に試験できるため、そのような実験を促進しうる
JAXの欠点
- ガバナンス構造
- 現在、XLAはTensorFlowのガバナンス下にある
- PyTorchのような独立した組織体を設立する議論はあったが、具体的な取り組みはあまり進んでいない
- Googleには人気のない製品を打ち切るという評判があるため、Googleへの信頼は高くない
- JAXは技術的にはDeepMindのプロジェクトであり、Google全体のAI推進において中核的意味を持つが、独立したガバナンスはエコシステム全体に長期的な大きな利益をもたらすように思える
- 別個のガバナンス組織があれば、プロジェクト開発に指針を与えられるだろう
- それにより具体的な構造が提供され、Googleの悪名高い官僚主義から切り離されることで、多くの問題を一度に回避できる
- JAXに必ずしもこの種の公式構造が必要というわけではないが、Google上層部の判断に関係なくJAX開発が長く続くという保証があれば望ましい
- それは、いつか保守されなくなるかもしれないツールの統合にリソースを投じることをためらう企業や大規模研究所での採用を、明らかに後押しするはずだ
- XLAのオープンソース化
- 長いあいだXLAはクローズドソースのプロジェクトだった
- しかし、それをオープンソースにする努力が進められ、現在のOpenXLAは内部XLAビルドよりはるかに優れた性能を示している
- それでもXLA内部に関する文書は依然として不足している
- リソースの大半はライブトークや時折出る論文に限られ、しかもしばしば古い
- 予定機能に関する公開ロードマップがあれば、人々は進捗を追跡しやすくなり、特に興味深い部分への貢献もしやすくなるだろう
- XLAコンパイラスタックの各段階を分析し、詳細を説明するEdward Yangスタイルの短いブログ記事があれば、XLAに何ができて何ができないのかを実務者がよりよく評価する助けになるだろう
- それがリソース集約的であり、他により有効な使い道があることは理解しているが、人はツールを理解するとより信頼するものであり、エコシステム全体にわたって前向きな波及効果があり、皆の利益になると考える
- エコシステム統合
flaxはJAXエコシステムの悩みの種である- 直感的でないAPI、簡潔すぎる構文を持ち、PyTorchから移行する初心者にとっては文字通り地獄である
equinoxを使うのがよいflaxの欠点を解消しようとする開発チームの試みもあったが、最終的には時間の無駄であるequinoxスタイルのAPIが欲しいなら、equinoxを使えばよいflaxに特別に優れている点はそれほど多くなく、equinoxで再現するのも難しくない- 現在のJAXエコシステムの多くは
flax中心に設計されている equinoxは本質的にPyTreeとやり取りするため、すべてのライブラリと相互運用可能だが、少しのeqx.partitionとfilterは必要になる- この現状を変えたい。
equinoxがあらゆる場所で第一級のサポートを受けるべきだ - これは賛否のある意見だが、典型的なサンクコストの誤謬である
equinoxは、JAXフレームワークが本来あるべき姿により近い形でうまく機能するequinoxのドキュメントで要約されているようにequinoxとflaxを比較すると、equinoxのほうが優れている- JAXエコシステムの管理者たちが
equinoxの人気を認識し、それに応じて調整しているのはよいことだが、Googleやflaxチームにも公式にもっと支援を示してほしい - JAXを試してみたいなら、
equinoxを使うのがよい
- 鋭い角
- API設計上の判断とXLAの制約により、JAXには注意すべき「鋭い角」がある
- それについてはよく書かれたドキュメントで非常に簡潔に説明されている
- JAXを使う前に少なくとも一度は読んでおくのがよい
- RTFMすることが、いつもそうであるように、多くの時間と労力を節約してくれる
結論
- このブログ記事は、PyTorchが実際の研究ワークロード、特にGPUに最適だという繰り返し語られる神話を正すためのものだった。もはやそうではない
- 実際、この分野全体にとってPyTorchコードをすべてJAXへ移植することが非常に有益だと主張できるほど極端ですらある
- 自動並列化、再現性、クリーンな関数型APIなどは些細な機能ではなく、多くの研究コードベースに大きく役立つはずだ
- この分野を少しでも良くしたいなら、コードベースをJAXで書き直すことを検討してみてほしい
8件のコメント
時代は流れ続けます。
2022年のPyTorchとTensorFlowの比較
torchとonnxで耐えます
学部生が書いた文章だ…すごい
PyTorchはHuggingfaceがなかったら、マジでww
JAX万歳! 最近使ってみたのですが、NNX APIがとても気に入りました。
JAXの最大の問題は、それがGoogleだという点。Googleはオープンソースを捨てることでかなり有名で(Tflite、android things、dart、angular、bazel など)、tensorflowもある時期からアップデートがあまり行われなくなり始めた。一方で torch は巨大なオープンソースを運営する Facebook から始まり、非常によく運営されていて、すでに torch 財団によって運営されている。torch の欠点が確かに当たっている部分はあるが、そのオープンソースを誰が持続可能に運営するのかという点において、JAXはすでに大きなリスクを抱えて始まっているように思う。
少なくとも Dart は Flutter によって、しばらくの間はうまく生き残れそうですね。
フェイスブックはReactやDjangoなど、少なくとも自分たちが使っている技術スタックについては義理堅く(?)継続的に貢献しているように見えますが、グーグルは少しでも旧式になると使い古したぞうきんのように捨てる気がします……