20 ポイント 投稿者 GN⁺ 2024-05-17 | 1件のコメント | WhatsAppで共有
  • Llama 3モデルの実際に動作する実装を通じて、正確な構造を理解する

概要

  • Metaが公開したLlama 3モデルが注目を集めている。
  • 24K GPUs、15Tの学習データ、10Mの命令データ、1.3M GPU時間など、圧倒的なスケールと性能を誇る。
  • モデル構造は大きく変わっていない。Llama 3はGQAを使用するが、これはLlama 2 70Bでもすでに実装されていた。
  • NumPyだけを使って、モデル構造を直感的に理解できるよう実装している。
  • Andrej KarpathyがLlama 2構造で学習したstories15MモデルをNumPy圧縮形式に変換して使用している。

構造

  • Llama 3モデルの構造は42dot LLMと同一。
  • モデルパラメータ:
    • dim: 288
    • n_layers: 6
    • n_heads: 6
    • vocab_size: 32000
    • max_seq_len: 256
    • max_new_tokens: 50

RoPE #1

  • RoPE埋め込みのためにcosとsinを事前計算する。
  • これらの値はQKに使われる。
  • 計算結果はnp.outerで掛け合わせられ、cosとsinが計算される。

RMSNorm

  • RMSNormは、従来のMini BatchやLayerの統計の代わりに、活性化値をRoot Mean Squareで正規化する。
  • 一貫した活性化スケーリングを提供する。

QKV

  • QKV計算は、GPTのように1つの重みをmatmulしてから分割する方式とは異なり、LlamaはQKVそれぞれに対する重みを持つ。
  • Multi-Head Attentionのために各値を再構成する。

RoPE #2

  • RoPEは絶対位置エンコーディングと相対位置エンコーディングの両方の特性を持つ。
  • QKにのみ適用され、入力を分割してcosとsinで掛けた後、結果を足し引きして再構成する。

KVキャッシュ

  • GPTスタイルの生成モデルはMasked Attentionを使用するため、KVキャッシュが可能。
  • 以前の結果は常に同一なので、KとVをキャッシュし、Qは最後の値だけを計算する。

GQA(Grouped-Query Attention)

  • GQAはLlama 2で導入された技術で、メモリ節約と性能向上をもたらす。
  • Llama 3では8B以上のすべてのモデルにGQAが適用される。

Scaled Dot-Product Attention

  • Multi-Head Attentionによって、それぞれのAttentionを計算する。
  • 結果はsoftmaxとmatmulで得られる。

Feed Forward

  • LlamaモデルのFeed Forwardは3つの線形層を使い、biasはない。
  • swish値を生成し、x_Vと掛け合わせた後、再びダウンスケーリングする。

SwiGLU

  • SwiGLUは複数のフィードフォワード層を独特に組み合わせ、モデル性能を向上させる。

Linear

  • 最終出力は最後のlogitだけをmatmulで計算して速度を高める。

生成

  • 抽出したlogitを使ってトークンを1つずつ生成する。
  • Prefill PhaseとDecode Phaseに分かれる。
  • Prefill Phaseではすべての入力を渡し、Decode Phaseでは最後のトークンIDだけを渡して結果を得る。

  • 次のように実行できる:
    $ python llama3.py "I have a dream"  
    

GitHub

参考文献

  1. Exploring and Building the Llama 3 Architecture
  2. Rotation Matrix
  3. Mastering LLM Techniques: Inference Optimization
  4. arXiv:2305.13245

GN⁺の意見

  • Llama 3モデルの構造と性能: Llama 3モデルは既存のLlama 2モデルの構造を維持しつつ、性能を大きく向上させている。これはモデルの拡張性と効率性を同時に考慮した結果である。
  • NumPyで実装した理由: NumPyを使ってモデルを実装することで、モデルの構造と動作をより直感的に理解できる。これは学習者や研究者にとって大きな助けになる。
  • GQAの導入: GQAはメモリ節約と性能向上を同時に実現する技術であり、Llama 3ですべてのモデルに適用されることでモデルの効率性を最大化している。
  • KVキャッシュの重要性: KVキャッシュはGPTスタイルの生成モデルで重要な役割を果たし、これによってモデルの計算効率を大きく高められる。
  • 実際の利用例: サンプルコードを通じてモデルを実際に動かしてみることができ、これはモデルの性能を直接確認できる良い機会となる。

1件のコメント

 
xguru 2024-05-17

Hacker News に投稿されたものは英語ですが、原著者の Likejazz さんが韓国語で作成しておいたリンクに差し替えました。