Google Jax - 高性能な機械学習ライブラリ
(github.com)「使いやすいものを高速に作って機械学習に適用」
- PythonとNumpyだけを組み合わせ
→ XLA を使って Numpy を GPU/TPU 上でコンパイルして実行
→ Python 関数を 1 つの API で JIT コンパイルし、XLA 最適化されたカーネルに簡単に載せられる
→ 複数の GPU/TPU での実行も簡単(vmap, pmap)
- 既存の Python+Numpy の性能をはるかに上回る
「使いやすいものを高速に作って機械学習に適用」
→ XLA を使って Numpy を GPU/TPU 上でコンパイルして実行
→ Python 関数を 1 つの API で JIT コンパイルし、XLA 最適化されたカーネルに簡単に載せられる
→ 複数の GPU/TPU での実行も簡単(vmap, pmap)
1件のコメント
DeepMindはJaxベースで全体をリファクタリングしたとのこと。
https://deepmind.com/blog/article/using-jax-to-accelerate-our-research