llm.cを使って、GPT-2 (124M) モデルを90分で20ドルかけて再現する方法を説明
- GPT-2 (124M) は、OpenAIが2019年に発表した最小のモデル
- Lambdaで8X A100 80GB SXMノードを使うと、1時間あたり約14ドル、総コストは約20ドル
- 単一GPUでも学習可能だが、より時間がかかる(4〜24時間)
結果比較
- FineWeb検証データセットで、OpenAIが公開したチェックポイントより良い性能を示す
- ただし、GPT-2はWebTextで学習しているため、完全に公平な比較ではない
- HellaSwag精度も測定しており、GPT-3 Small (124M) の33.7に近い29.9を達成
- GPT-2 (124M) の29.4はすでに上回っている
- ただし、ここでは10Bトークンで学習しており、GPT-3は300Bトークンで学習している
最小環境設定
- GPUが必要(Lambda Labs推奨)
- Linux x86 64bit Ubuntu 22.04 with CUDA 12基準のガイド
- minicondaをインストールした後、PyTorch nightly版をインストール(任意)
- tokenizerのために必要なパッケージをインストール
- 速度向上のためにcuDNNをインストール(任意)
- 複数GPUを使う場合はMPIをインストール(任意)
- FineWeb 10Bトークンデータセットを前処理(約1時間所要)
- llm.cをコンパイル(混合精度、cuDNN FlashAttention使用)
学習実行
- 単一GPU使用時のサンプルコマンド
- マルチGPU(8基)使用時はmpirunで実行
- 主な引数の説明
- -i, -j : 学習/検証データのパス
- -o : ログ、チェックポイント保存パス
- -e : モデル初期化(depth 12 GPT-2)
- -b : マイクロバッチサイズ(メモリ不足時は減らす)
- -t : 最大シーケンス長
- -d : 総バッチサイズ(GPT-3論文参照)
- -r : Recompute設定(メモリ節約)
- -z : ZeRO-1(オプティマイザ状態のsharding)
- その他、weight decay、学習率、チェックポイント周期などを設定
学習過程
- 10B学習トークン、0.5Mバッチサイズ基準で約20Kステップを想定
- A100 40GB PCIe GPU基準で、ステップごとの所要時間、MFU、トークン処理量が出力される
- 学習初期にgradient exploding現象があるが、clippingで解決
可視化
- ログファイルをパースして学習曲線を可視化するJupyterノートブックを提供
Tokenizer
- 整数トークンを文字列に変換するために必要
- PyTorchスクリプトで生成可能
Sampling
- 現時点では推論に最適化されていない
- コードを少し修正することでunconditional/conditional samplingが可能
コード構成
train_gpt2.cuファイルに実装の大部分が含まれている
- 最初の500行はMPI、NCCL、cuDNN、cuBLASなどの設定
- 続く1500行はTransformerのforward/backward
- その次の1000行はGPT-2モデル実装
- 最後の1000行は学習ループ、引数パースなど
350Mモデル
- 10Bトークンでは不足しており、30Bトークンを使用
- 8X A100 80GBで14時間かかり、コストは約200ドル
FAQ
- サンプリング可能か: 可能だが非効率的。
- チャット可能か: 現時点では事前学習のみ可能で、チャット向けのファインチューニングは不可。
- マルチノード分散学習: 可能だが、まだテストされていない。
- ビット単位で決定論的か: ほぼ決定論的だが、一部カーネルのパッチが必要。
- FP8学習は可能か: 現在はBF16で学習しており、FP8はまもなく対応予定。
- NVIDIA以外のGPU対応: 現時点ではC/CUDAのみ対応。
GN⁺の意見
- GPT-2は現代的なLLMの出発点として非常に重要なモデルであり、その後のGPT-3や他のLLMもGPT-2と大きくは変わらない。
- このプロジェクトは、GPT-2クラスのモデルを誰でも妥当なコストで直接学習させられるようにしてくれる。LLMへの理解を深めるのに大いに役立ちそうだ。
- ただし、まだ推論に最適化されていないため、実サービスで活用するには制約がある。対話モデルとしてファインチューニングすることもサポートされていない。
- 現在はNVIDIA GPUのみをサポートしているが、今後はAMDやApple Siliconなど多様なプラットフォーム対応が期待される。
- 類似目的のオープンソースプロジェクトとしては、Megatron-LM、DeepSpeed、FairSeqなどがある。それぞれ長所と短所があるため、用途に応じて選ぶとよい。
- LLM開発エコシステム活性化の観点から、非常に心強いプロジェクトだ。今後がさらに楽しみだ。
まだコメントはありません。