モデルをスケールさせる方法:TPU上のLLMをシステム視点で捉える
(jax-ml.github.io)- ディープラーニング性能を大規模に最適化することは「錬金術」のように見えるが、実際には理解可能なシンプルな原則によってモデル効率を高められる
- 単一アクセラレータから数万台のアクセラレータまで、比較的シンプルな原則があらゆる場面に適用され、これを理解することで次のような有用な作業が可能になる:
- モデルの各部分が理論上の最適値にどの程度近づいているかを大まかに把握
- さまざまなスケールで複数の並列化手法を選択する根拠を得られる
- 大規模Transformerモデルの学習と実行に必要なコストと時間を見積もる
- 特定ハードウェアの特性を活かすアルゴリズムを設計する
- 現在のアルゴリズム性能の限界を明確に理解したうえでハードウェアを設計する
- 必要な前提知識
- LLMとTransformerアーキテクチャの基本概念を理解していることが必要
- 大規模運用方式への理解は必須ではない
- LLM学習の基本知識とJAXの使用経験があればより望ましい
- Transformerアーキテクチャに関するブログ記事と、JAXのLLMスケーリングに関するスライドの参照を推奨
- 目標
- 与えられたハードウェア上でモデルをどのように並列化すればよいかを見積もれる力を身につけること
- 学習と推論にかかる時間とコストを大まかに計算できる能力を養うこと
なぜ関心を持つべきか
- 3〜4年前までは、ML研究者の大半がこのような大規模スケール最適化を深く理解する必要はなかった
- 現在では「小さな」モデルでさえハードウェア限界に近いところで動作するため、効率的な大規模処理の理解が不可欠になっている
- MLの歴史は、システム革新とソフトウェア改善が相互に発展してきた流れとして見ることができる
- 近年のTransformerモデルはハードウェア限界まで活用するようになっており、モデル効率を理解しなければ新しいアーキテクチャや研究が実運用で失敗する可能性が高い
- ベンチマークで20%の性能向上を得ても、ハードウェア効率が20%低下すれば、結局実用性は低くなる
- モデルスケーリングの核心的な目標は、チップ(アクセラレータ)の数を増やしたときにスループットが線形に増加するようにすること
- これを「強スケーリング」と呼ぶ
- チップを追加すると計算時間は短縮されるが、チップ間通信コストが発生する
- 通信が計算より長くかかると「通信律速(Communication Bound)」の状態となり、強スケーリングは不可能になる
- ハードウェアを十分に理解してこうしたボトルネックがどこで発生するかを予測できれば、それを避けるようにモデルを設計・再構成できる
- この本の目標は、TPU(およびGPU)ハードウェアがどのように動作し、Transformerアーキテクチャが現在のハードウェア上でうまく動くようにどのように発展してきたかを説明すること
- 新しいアーキテクチャを設計する研究者にも、現世代のLLMを高速に動かそうとするエンジニアにも役立つことを目指している
全体概要
- この文章は次のように構成されている
- セクション1では、roofline分析を通じてモデルの性能限界を決める要素(通信、計算、メモリ)を説明する
- セクション2、セクション3では、TPUとGPUの内部構造およびチップ間接続方式を扱う
- これにより、次のような疑問に答える
- 特定サイズの行列積は理論上どれほど速く実行できるか
- どの時点で計算がメモリ帯域幅や通信帯域幅に縛られるようになるか
- TPUクラスタはどのような構造で接続され、あるチップから別のチップへデータを移動するのにおよそどれくらいの時間がかかるか
- 分散された行列をどのように効率よく乗算できるか
- これにより、次のような疑問に答える
- セクション4では、Transformerアーキテクチャの数式(行列サイズ、パラメータ数、FLOPs)を詳しく扱う
- セクション5とセクション7が中核であり、複数チップにモデルを並列化するさまざまな方法を紹介する
- Data parallel, Tensor parallel, Pipeline parallel, Expert parallel
- ZeRO, Rematerialisation, Host offload, Gradient accumulation などのメモリ節約手法も扱う
- セクション6、セクション8では、LLaMA-3モデルをTPUで学習・推論する過程を例に、実際のコスト、時間、構成方法を示す
- 最後にセクション9、セクション10では、JAXでモデルをプロファイルし、デバッグし、並列処理を適用する実践的な方法を扱う
詳細内容:本の主要セクション要約
-
パート1: Preliminaries
-
- アルゴリズムを制約する3つの要素: 計算、通信、メモリ
- そこから演算速度の上限を見積もる方法を学ぶ
-
- TPUがどのように計算するか
- Systolic array構造とは何か
- TPUがメモリ帯域幅と通信帯域幅をどのように提供するかについての基本的理解
-
- モデルパラメータを複数チップに分割保存するSharding手法
- 分散行列演算時に発生する通信とボトルネックの扱い方
-
-
パート2: Transformers
-
- Transformerにおける行列積が具体的にどのような形になるか
- パラメータ数、FLOPs、KVキャッシュサイズなどの計算方法
- Attention演算がFeed-Forwardブロックに比べてどれほど多くの計算を要するかを把握する
-
- Data parallel, Tensor parallel, Pipeline parallel, Expert parallel 手法の紹介
- ZeRO(FSDP), Rematerialisation, Gradient accumulation, Host offload などのメモリ節約策
- 特定モデルサイズとチップ数に合わせて並列化を構成する概念を整理する
-
- 実際のTPU環境でLLaMA 3モデルを学習すると仮定した場合の所要時間とコストの見積もり
- バッチサイズ、並列化方式、メモリ使用量などの具体例を提示
-
- 推論時には遅延(latency)が重要な新要素として現れる
- KVキャッシュなどによるメモリ使用と通信の問題
- モデルサービングのために複数チップをどう割り当て、どう接続するかの議論
-
- TPU v5eでLLaMA 3をサービングすると仮定した場合のおおよそのコスト、遅延、スループットのトレードオフ分析
-
-
パート3: Practical Tutorials
-
- JAX+XLAスタックの理解
- 実際の性能低下問題の把握と解決策
- JAX/TensorBoardプロファイラの使い方
-
- JAXの並列化API(primitives)の活用法
- 例題と課題を通じて並列演算の概念を学ぶ
-
- TPUとLLMに関する追加の読み物
- 全体を簡潔に締めくくり、今後の展望に言及する
-
1件のコメント
Hacker Newsのコメント