MLX - Appleシリコン向けのNumPy風配列フレームワーク
(github.com/ml-explore)- Appleの機械学習研究チームが開発した、Appleシリコン上で効率的かつ柔軟に機械学習を実行するために設計された配列フレームワーク
- NumPyとほぼ同様のPython APIを提供し、同等の機能を備えたC++ APIも用意
- NumPyとの違い
- Composable function transformations: MLXには、自動微分、自動ベクトル化、計算グラフ最適化のための構成可能な関数変換がある
- Lazy Computation: MLXの計算は遅延評価され、配列は必要なときにのみ具体化(Materialize)される
- マルチデバイス: 対応するすべてのデバイス(CPU、GPU、...)で演算を実行できる
- 動的グラフ構築: MLXの計算グラフは動的に構築される。関数引数の形状を変更してもコンパイル速度は低下せず、デバッグも簡単で直感的
- PyTorch、Jax、ArrayFireのようなフレームワークから着想を得ている
- これらのフレームワークとMLXの目立つ違いは、Unified Memory Model
- MLXの配列は共有メモリに保存される。MLX配列に対する操作は、データコピーを行わずに対応するすべてのデバイスタイプで実行できる
- 現在サポートされているデバイスタイプはCPUとGPU
- さまざまなサンプルを含む
- Transformer言語モデルの学習
- LLaMAによる大規模テキスト生成とLoRAによるファインチューニング
- Stable Diffusionによる画像生成
- OpenAI's Whisperによる音声認識
2件のコメント
これは本当に良さそうですね。動的なデータを多く扱うので、jaxを使うときはいつも大変だったのですが…
おお、これはいいですね。一度使ってみます。