DeepGEMM: 細粒度スケーリングによるクリーンで効率的な FP8 GEMM カーネル
(github.com/deepseek-ai)DeepGEMM
DeepGEMM は、FP8 一般行列積(GEMM)のためのライブラリで、DeepSeek-V3 で提案された細粒度スケーリングをサポートします。このライブラリは、通常の GEMM と Mix-of-Experts(MoE)のグループ化 GEMM をサポートし、CUDA で記述されているため、インストール時のコンパイルは不要です。NVIDIA Hopper Tensor Core をサポートし、FP8 Tensor Core 累積の不正確さを解消するために、CUDA Core による 2 段階累積を使用します。CUTLASS と CuTe の概念を一部活用しつつ、テンプレートや代数への依存を最小限に抑えてシンプルさを維持しています。約 300 行のコードからなる 1 つの中核カーネル関数で構成されており、Hopper FP8 行列積と最適化技法を学ぶのに適したリソースです。軽量な設計でありながら、さまざまな行列形状において、専門家が調整したライブラリと同等以上の性能を発揮します.
パフォーマンス
DeepSeek-V3/R1 推論で使用可能なすべての形状を、H800 SXM5 上で NVCC 12.8 を用いてテストしています。すべての高速化指標は、CUTLASS 3.6 をベースに内部で最適化した実装との比較により算出されています。一部の形状では性能が低い場合があり、最適化 PR を歓迎します。
通常の GEMM(高密度モデル)
- さまざまな行列サイズで DeepGEMM の性能を測定した結果、特定のサイズでは最大 2.7 倍の高速化を示しました。
MoE モデル向けグループ化 GEMM(連続レイアウト)
- グループ数と各グループの行列サイズに応じて、最大 1.2 倍の高速化を示します。
MoE モデル向けグループ化 GEMM(マスクレイアウト)
- マスクレイアウトを用いて最大 1.2 倍の高速化を示します。
クイックスタート
要件
- Hopper アーキテクチャ GPU、
sm_90aサポートが必要 - Python 3.8 以上
- CUDA 12.3 以上(最高の性能のため 12.8 以上を推奨)
- PyTorch 2.1 以上
- CUTLASS 3.6 以上
開発
- サブモジュールのクローン、シンボリックリンクの作成、JIT コンパイル、およびすべての GEMM 実装のテストを含む開発手順を説明しています。
インストール
deep_gemmを Python プロジェクトにインポートして使用できます。
インターフェース
注意事項
- このライブラリには GEMM カーネルのみが含まれており、NT 形式のみをサポートします。転置やその他の FP8 キャスト処理は独立して実装する必要があります。
通常の高密度 GEMM(非グループ化)
- 基本的な非グループ化 FP8 GEMM を実行するための関数を提供します。
グループ化 GEMM(連続レイアウト)
- MoE モデルにおいて、エキスパートが同一の形状を共有するシナリオ向けに設計されています。
グループ化 GEMM(マスクレイアウト)
- 推論のデコード段階でマスクテンソルを与え、有効な部分のみを計算します。
ユーティリティ
- 各種ユーティリティ関数と環境変数を提供し、性能最適化に役立ちます。
最適化
持続的なワープ特化
- CUTLASS の設計に従い、データ移動、Tensor Core MMA 命令、CUDA Core への昇格を重ね合わせます。
Hopper TMA 機能
- TMA を活用してデータ移動を高速化します。
共通の詳細最適化
- さまざまな最適化手法により性能を向上させます。
統合・最適化されたブロックスケジューラ
- すべての非グループ化およびグループ化カーネル向けのスケジューラを提供します。
完全な JIT 設計
- インストール時のコンパイルが不要な JIT 設計により性能を高めます。
非整列ブロックサイズ
- 特定の形状で SM 活用率を最大化するため、非整列ブロックサイズをサポートします。
FFMA SASS インターリービング
- 性能向上のために FFMA 命令を調整し、ワープレベル並列性を向上させます。
謝辞
- DeepGEMM は CUTLASS プロジェクトから着想を得ており、開発者たちに感謝と敬意を表します。
ライセンス
- MIT ライセンスで公開されています。
まだコメントはありません。