2 ポイント 投稿者 GN⁺ 2024-09-24 | 1件のコメント | WhatsAppで共有

Felafax BlogTune Llama3 405B on AMD MI300x (私たちの道のり)

紹介

  • オープンソースモデルの大規模化に伴い、大規模なAIトレーニングを処理できる強力なインフラの必要性が高まっている
  • FelafaxはAMD GPU上でLLaMA 3.1 405Bモデルをファインチューニングし、AMDハードウェアの効率性を実証した
  • すべての作業をGitHubでオープンソースとして公開した
  • AMD MI300X GPUはNVIDIAのAIハードウェアと比べて高い性能を提供する
  • TensorWaveの支援によりこのプロジェクトが実現した

JAXとは何か、そしてなぜ選んだのか

  • JAXはNumPyに似たAPI、自動微分、GoogleのXLAコンパイラを組み合わせた強力な機械学習ライブラリである
  • モデル並列処理のための優れたAPIを提供しており、大規模モデルの学習に最適である

JAXの利点

  • 純粋関数: JAXは純粋関数を書くことを推奨しており、コードの構成、デバッグ、可読性が向上する
  • 高度な並列処理: JAXの柔軟なJIT APIは、大規模学習に不可欠な高度なデータ並列およびモデル並列をサポートする
  • クリーンなコードベース: JAXの設計思想は、ハードウェアプラットフォーム間で移植可能なコードを書くことを促す

JAXが非NVIDIAハードウェアで優れている理由

  • ハードウェア非依存のアプローチ: JAXはXLAコンパイラを活用し、計算をハードウェアに依存しない中間表現へコンパイルする
  • プラットフォーム非依存の最適化: XLAコンパイラはハードウェアとは独立して最適化を行う
  • 容易な移植性: JAXを使えば、NVIDIAからAMDへ移行する際のコード変更を最小限に抑えられる

AMD GPUでのJAXセットアップ

  • Dockerイメージを取得してコンテナを起動し、インストールを確認した
  • AMD MI300x GPU 8基を使用してLLaMA 405Bモデルを学習した

LLaMA 405Bの学習: 性能とスケーラビリティ

  • JAXを使ってAMD GPU上でLLaMA 405Bモデルを学習した
  • LoRAファインチューニングにより、モデル重みとLoRAパラメータをbfloat16精度で調整した
  • モデルサイズ: 約800GBのVRAMを使用
  • LoRA重みとオプティマイザ状態: 約400GBのVRAMを使用
  • 総VRAM使用量: 約1200GB
  • 学習速度: 1秒あたり約35トークン
  • メモリ効率: 約70%を維持
  • スケーラビリティ: JAXにより8基のGPUでほぼ線形にスケールした

私たちの学習設定

  • LLaMA 3.1をPyTorchからJAXへ変換した
  • モデルのロードとパラメータのシャーディングによって効率的に分散した

JAXでのパラメータシャーディング

  • JAXのデバイスメッシュ機能を使用して、8基のAMD GPUにモデルを効率的に分散した
  • パラメータシャーディングのルールを定義し、各テンソルの次元をメッシュ軸に沿ってシャーディングした

LoRA学習の実装

  • LoRAは重み更新を低ランク行列に分解することで、学習可能なパラメータ数を削減する
  • LoRADenseレイヤーを実装してLoRAパラメータを含めた
  • LoRAパラメータを効率的に分散し、メモリ使用量と計算効率を最適化した

結論

  • AMD GPUとJAXを用いたLLaMA 3.1 405Bモデルのファインチューニング体験は非常に良好だった
  • JAXの強力な並列処理機能とハードウェア非依存のアプローチを活用して、モデルを効率的に分散した
  • AMD GPUが大規模AIトレーニングの強力な代替手段であることを実証した
  • GitHubリポジトリで全コードを確認し、自分で実行できる

GN⁺のまとめ

  • この記事は、AMD GPUとJAXを使って大規模AIモデルを効率的に学習する方法を説明している
  • AMDハードウェアがNVIDIAに対するコスト効率の高い代替手段であることを強調している
  • JAXのハードウェア非依存アプローチがコードの移植性を高め、保守を容易にする
  • 大規模モデルの学習に関心のある人にとって有用な情報と実践的なコードを提供している
  • 類似機能を持つプロジェクトとして、NVIDIAのCUDAやPyTorchがある

1件のコメント

 
GN⁺ 2024-09-24
Hacker Newsの意見
  • JAXを使って、Llama 3.1 405Bモデルを8台のAMD MI300x GPUでファインチューニングした成果を共有

    • JAXの高度なシャーディングAPIのおかげで、優れた性能を達成
    • ブログ記事とオープンソースコードへのリンクを提供: GitHubリンク
    • NVIDIAハードウェアではなく、TPU、AMD、TrainiumでLLMをファインチューニングし、提供するAIインフラを構築するスタートアップ
    • 多くの企業がAMD GPUでPyTorchを動かそうとしているが、それは困難な道だと考えている
    • PyTorchはNVIDIAエコシステムと深く結びついており、非NVIDIAハードウェアで動かすには多くの修正が必要
    • JAXのほうが非NVIDIAハードウェアに適していると考えている
    • JAXでは、MLモデルのコードはハードウェア非依存のHLOグラフにコンパイルされ、XLAコンパイラがハードウェア固有の最適化を行う
    • 同じJAXコードをGoogle TPUとAMD GPUで変更なしに実行可能
    • 会社の戦略は、JAXへモデルを移植し、XLAカーネルを活用して非NVIDIAバックエンドから最大性能を引き出すこと
    • Llama 3.1をPyTorchからJAXへ最初に移植し、今では同じJAXモデルがTPUとAMD GPUでうまく動作している
    • ビジョンとリポジトリについて意見を聞きたいとのこと
  • メモリ制約を克服し、JITコンパイル版を実行する方法を探る提案

    • さらなる性能向上につながる可能性がある
  • AMD GPUとROCmサポートに関する経験の共有

    • 1年前にAMD GPUとROCmサポートを試したが、AMDがNVIDIAに追いつくにはまだ遠いと感じた
    • JAXを選んだのは興味深いアプローチだが、PyTorchから離れる際にどんな困難があったのか気になる
  • 405Bモデルの推論面で実験した経験の共有

    • torch.cudaはそこまで悪くないと考えている
    • AMD版のPyTorchがこれを変換してくれるため、名前の問題にすぎないと見ている
    • rocm:pytorchコンテナを使うのは、rocm:jaxコンテナを使うのと同じくらい簡単
    • 性能データがあまり公開されていない点を指摘
    • MFU(モデル利用率)の数値が気になる
  • 性能データが不足していることへの質問

    • AMD GPUの大量発注によって価値を引き出せる可能性があるのか疑問を呈する
    • 「ノー」という印象を受ける
  • Obsidian(ノートアプリ)がなぜこれをしているのかという疑問

    • 最初はObsidianの投稿だと思った
    • GitHub.comとGitHub.ioをいまだに区別していない理由を不思議に思う
  • @dangに、URLへユーザー名を含めるよう要望

    • この投稿はObsidian自体ではなく、ユーザー生成ブログに関するもの**