FlashAttention-3: 非同期性と低精度で高速かつ高精度なAttention
(together.ai)-
Attentionの重要性
- AttentionはTransformerアーキテクチャの中核レイヤーであり、大規模言語モデルや長文コンテキストのアプリケーションでボトルネックを引き起こす。
- FlashAttentionとFlashAttention-2は、GPU上でメモリの読み書きを最小化することでAttentionを高速化するアプローチを切り開いた。
- これにより、LLMのコンテキスト長は大幅に拡大した。
-
FlashAttention-3の主要技術
- 非同期性の活用: Tensor CoresとTMAの非同期性を活用し、計算全体とデータ移動を重ね合わせる。
- ブロック単位演算: ブロック単位の行列乗算とsoftmax演算を交互に実行する。
- 低精度処理: FP8の低精度サポートを活用して性能を向上させる。
-
FlashAttention-3の性能向上
- GPU活用効率: H100 GPUの最大性能を75%まで活用し、前バージョンより1.5〜2倍高速。
- 低精度性能: FP8を使用して処理速度を高め、メモリ使用量を削減する。
- 長文コンテキスト処理: Attentionメカニズムを高速化し、より長いテキストを効率よく処理可能。
-
FlashAttentionの要約
- FlashAttentionは、Attention計算を再配置し、タイル化と再計算を活用することで速度を大幅に高め、メモリ使用量を削減する。
- タイル化によって入力ブロックを読み込み、そのブロックに対してAttentionを実行した後に出力を更新する。
- 中間のAttention行列をメモリに書き込まないことで、メモリの読み書き量を削減する。
-
Hopper GPUの新しいハードウェア機能
- WGMMA: 新しいTensor Coresを活用して高いスループットを提供する。
- TMA: グローバルメモリと共有メモリ間のデータ転送を高速化するハードウェアユニット。
- FP8低精度: FP8を使用してTensor Coreのスループットを2倍に高める。
-
非同期性: GEMMとSoftmaxの重ね合わせ
- 重ね合わせの必要性: GEMMとsoftmaxを並列に実行して性能を最大化する。
- ピンポンスケジューリング: 2つのwarp groupが交互にGEMMとsoftmaxを実行して性能を向上させる。
- warp group内での重ね合わせ: 同一のwarp group内でGEMMとsoftmaxを並列実行し、スループットを高める。
-
低精度: 非一貫処理による量子化誤差の低減
- 非一貫処理: Hadamard変換を用いて量子化誤差を減らす。
- 実験結果: 非一貫処理により量子化誤差を2.6倍低減した。
-
Attentionベンチマーク
- FP16: FlashAttention-2より約1.6〜1.8倍高速。
- FP8: 最大1.2 PFLOPSに到達。
GN⁺のまとめ
- FlashAttention-3は、GPUの新しいハードウェア機能を活用してAttentionメカニズムの性能を大幅に向上させる。
- 長文コンテキストを効率よく処理できるため、大規模言語モデルの性能を最大化する。
- PyTorchのような主要フレームワークに統合される可能性が高く、今後のAI研究と応用に大きな影響を与える見込み。
- 類似機能を提供するプロジェクトとして、TritonとcuDNNがある。
1件のコメント
Hacker Newsのコメント
Tri Dao は 2022年4月から FA3 の作業を始めていたようだ
Flash Attention アルゴリズムがハードウェアにどれほど依存しているのか気になる
コンパイラが FlashAttention のような最適化を自力で見つけられるのか気になる
ROCm/AMD MI300x への移植を望む人は連絡してほしいとのこと
TMA (Tensor Memory Accelerator) はグローバルメモリと共有メモリの間のデータ転送を高速化するハードウェアユニットである
FlashAttention-3 は Hopper GPU(例: H100)向けに最適化されている
現代の LLM では sigmoid のような活性化関数が非常に遅いと言われている
可変マスキングがない場合と比べて、ある場合に Flash Attention が5倍遅い理由が気になる
FlashAttention が LLM の attention 演算を置き換えられるのか気になる
高価なハードウェアが必要