PyTorch Internals (2019)
(blog.ezyang.com)- PyTorchの内部構造に関する解説で、PyTorchのC++コードベースに貢献したい人のためのガイド
- この記事の目的は、PyTorchのテンソルライブラリ構造と自動微分(autograd)の仕組みを理解し、コードベース内で道筋を見つける助けになること
PyTorchテンソルの基本構造
- PyTorchではテンソルが最も基本的なデータ構造
- テンソルはn次元のデータ構造で、浮動小数点数(float)、整数(int)などのスカラー値を格納できる
- テンソルは次のようなメタデータを含む:
- サイズ(size): テンソルの次元情報
- dtype: 格納されるデータ型(例: float32、int64など)
- device: データが格納される場所(CPU、CUDAなど)
- stride: データの物理メモリ上でのオフセット情報
-
strideの役割
- strideは論理インデックスを物理メモリ位置に変換するために使われる
- strideは各次元ごとにオフセットを設定し、インデックスにstride値を掛けて物理メモリ位置を決定する
- strideにより、新しいテンソルを生成せずにviewとして同じデータを別の方法で見ることができる
テンソルとストレージ(Storage)の概念
- PyTorchではテンソルが実データを直接保持しない → データはストレージ(Storage)で管理される
- Tensor = サイズ + dtype + device + stride + offset 情報
- 複数のテンソルが1つのストレージを共有可能 → Viewの概念をサポート
- ストレージとテンソルの分離により、メモリを効率的に使える
テンソル演算のディスパッチ(Dispatch)プロセス
- PyTorchでは演算は2段階のディスパッチを経る:
- デバイスタイプおよびレイアウトベースのディスパッチ
- CPUテンソルかCUDAテンソルかによって異なる実装コードが実行される
- dtypeベースのディスパッチ
- floatかintかなど、データ型に応じて異なるカーネルが呼び出される
- デバイスタイプおよびレイアウトベースのディスパッチ
PyTorchテンソルの拡張モデル
-
テンソルの3つの主要な拡張要素:
- Device: CPU、GPU、TPUなどでのメモリ割り当て方式を定義
- Layout: テンソルがメモリに格納される方式を定義(例: 連続格納、疎(sparse)格納など)
- dtype: テンソルの各要素に格納されるデータ型を定義
-
拡張オプション:
- PyTorchコードを直接修正してテンソルを拡張できる
- 既存のテンソルを包むラッパークラスを書くこともできる
- 自動微分の途中でラッパーが必要なら、直接拡張が必要
自動微分(Autograd)の動作原理
- PyTorchは**逆伝播(reverse-mode differentiation)**に基づいて自動微分を行う
- 順伝播(forward)演算時にグラフを生成し、逆伝播時にそのグラフをたどって微分を行う
- Autogradは次のような追加情報を管理する:
- AutogradMeta: テンソルに紐づくメタデータで、逆伝播に使われる
- 演算結果を記録し、逆伝播時に微分を実行する
PyTorchのコード構造とファイル配置
- PyTorchコードベースの主要ディレクトリ:
torch/→ Pythonモジュール(Pythonコード)torch/csrc/→ PythonとC++のバインディングコード、自動微分エンジン、JITコンパイラなどaten/→ テンソル演算の定義(主要なコア演算の大半を含む)c10/→ テンソルやストレージのようなコア構造体の定義
PyTorch演算の実行プロセス
- 例:
torch.add()呼び出し時の実行プロセス:- PythonからC++コードへ引数を変換
- VariableTypeでディスパッチを実行
- Device/レイアウトベースのディスパッチを実行
- 最終カーネルを実行
カーネル作成プロセスとツール
- PyTorchでカーネルは次のような段階で作成される:
- 演算メタデータの作成: 関数シグネチャ、対応デバイスおよびデータ型を定義
- 入力検証: 次元や型などの入力検証を実行
- 出力テンソルの割り当て
- dtypeディスパッチ: データ型に応じてカーネルを実行
- 並列処理: CPUではOpenMP、CUDAでは組み込み並列化を使用
- データアクセスと計算: TensorAccessor、TensorIteratorなどを使用
主なディスパッチマクロ
- AT_DISPATCH_ALL_TYPES → dtypeに応じたディスパッチを実行
- 多様なデータ型向けのマクロが用意されており、性能最適化が可能
性能最適化と作業効率向上のヒント
- ヘッダーファイルの修正を最小限にする → 修正時にコード全体の再ビルドが発生
- ローカル開発環境を整える → CI利用時の時間消費を最小化
- ccacheを使う → 再コンパイル時間を節約可能
- 高性能なサーバーを使う → C++コンパイルやCUDAビルドの時間短縮が可能
PyTorchへの貢献ガイド
- 始めやすい貢献対象:
- triagedラベル付きのIssue → PyTorch開発者が確認済みのIssue
- ドキュメント改善やバグ再現の支援
- PyTorchのRFC(機能提案)への意見提示
- PyTorchはオープンソース貢献者によって成長してきており、コミュニティ参加を歓迎している
まだコメントはありません。