Felafax BlogTune Llama3 405B on AMD MI300x (私たちの道のり)
紹介
- オープンソースモデルの大規模化に伴い、大規模なAIトレーニングを処理できる強力なインフラの必要性が高まっている
- FelafaxはAMD GPU上でLLaMA 3.1 405Bモデルをファインチューニングし、AMDハードウェアの効率性を実証した
- すべての作業をGitHubでオープンソースとして公開した
- AMD MI300X GPUはNVIDIAのAIハードウェアと比べて高い性能を提供する
- TensorWaveの支援によりこのプロジェクトが実現した
JAXとは何か、そしてなぜ選んだのか
- JAXはNumPyに似たAPI、自動微分、GoogleのXLAコンパイラを組み合わせた強力な機械学習ライブラリである
- モデル並列処理のための優れたAPIを提供しており、大規模モデルの学習に最適である
JAXの利点
- 純粋関数: JAXは純粋関数を書くことを推奨しており、コードの構成、デバッグ、可読性が向上する
- 高度な並列処理: JAXの柔軟なJIT APIは、大規模学習に不可欠な高度なデータ並列およびモデル並列をサポートする
- クリーンなコードベース: JAXの設計思想は、ハードウェアプラットフォーム間で移植可能なコードを書くことを促す
JAXが非NVIDIAハードウェアで優れている理由
- ハードウェア非依存のアプローチ: JAXはXLAコンパイラを活用し、計算をハードウェアに依存しない中間表現へコンパイルする
- プラットフォーム非依存の最適化: XLAコンパイラはハードウェアとは独立して最適化を行う
- 容易な移植性: JAXを使えば、NVIDIAからAMDへ移行する際のコード変更を最小限に抑えられる
AMD GPUでのJAXセットアップ
- Dockerイメージを取得してコンテナを起動し、インストールを確認した
- AMD MI300x GPU 8基を使用してLLaMA 405Bモデルを学習した
LLaMA 405Bの学習: 性能とスケーラビリティ
- JAXを使ってAMD GPU上でLLaMA 405Bモデルを学習した
- LoRAファインチューニングにより、モデル重みとLoRAパラメータをbfloat16精度で調整した
- モデルサイズ: 約800GBのVRAMを使用
- LoRA重みとオプティマイザ状態: 約400GBのVRAMを使用
- 総VRAM使用量: 約1200GB
- 学習速度: 1秒あたり約35トークン
- メモリ効率: 約70%を維持
- スケーラビリティ: JAXにより8基のGPUでほぼ線形にスケールした
私たちの学習設定
- LLaMA 3.1をPyTorchからJAXへ変換した
- モデルのロードとパラメータのシャーディングによって効率的に分散した
JAXでのパラメータシャーディング
- JAXのデバイスメッシュ機能を使用して、8基のAMD GPUにモデルを効率的に分散した
- パラメータシャーディングのルールを定義し、各テンソルの次元をメッシュ軸に沿ってシャーディングした
LoRA学習の実装
- LoRAは重み更新を低ランク行列に分解することで、学習可能なパラメータ数を削減する
- LoRADenseレイヤーを実装してLoRAパラメータを含めた
- LoRAパラメータを効率的に分散し、メモリ使用量と計算効率を最適化した
結論
- AMD GPUとJAXを用いたLLaMA 3.1 405Bモデルのファインチューニング体験は非常に良好だった
- JAXの強力な並列処理機能とハードウェア非依存のアプローチを活用して、モデルを効率的に分散した
- AMD GPUが大規模AIトレーニングの強力な代替手段であることを実証した
- GitHubリポジトリで全コードを確認し、自分で実行できる
GN⁺のまとめ
- この記事は、AMD GPUとJAXを使って大規模AIモデルを効率的に学習する方法を説明している
- AMDハードウェアがNVIDIAに対するコスト効率の高い代替手段であることを強調している
- JAXのハードウェア非依存アプローチがコードの移植性を高め、保守を容易にする
- 大規模モデルの学習に関心のある人にとって有用な情報と実践的なコードを提供している
- 類似機能を持つプロジェクトとして、NVIDIAのCUDAやPyTorchがある
1件のコメント
Hacker Newsの意見
JAXを使って、Llama 3.1 405Bモデルを8台のAMD MI300x GPUでファインチューニングした成果を共有
メモリ制約を克服し、JITコンパイル版を実行する方法を探る提案
AMD GPUとROCmサポートに関する経験の共有
405Bモデルの推論面で実験した経験の共有
torch.cudaはそこまで悪くないと考えているrocm:pytorchコンテナを使うのは、rocm:jaxコンテナを使うのと同じくらい簡単性能データが不足していることへの質問
Obsidian(ノートアプリ)がなぜこれをしているのかという疑問
@dangに、URLへユーザー名を含めるよう要望