- RustGPT は、外部の機械学習フレームワークを使わず、純粋なRustと
ndarrayだけで実装された トランスフォーマーベースの言語モデル
- 事前学習(Pre-training) と 指示チューニング(Instruction tuning) を通じて、事実ベースの知識と対話パターンを学習するよう設計されている
- 構造は トークナイザー → 埋め込み → トランスフォーマーブロック → 出力射影 へと続く典型的な LLMアーキテクチャ に従う
- モジュール化されたソース構造 と テストコード を備え、学習・推論・最適化の過程を細部まで理解できる
- Rustエコシステムで フレームワークに依存せずLLMをゼロから実装 してみたい開発者や学習者にとって重要な参考資料
プロジェクト概要
- RustGPTは、外部の機械学習フレームワークや複雑な依存関係なしに、純粋な Rust言語と線形代数演算ライブラリ(ndarray) だけでLLMを実装したオープンソースプロジェクト
- 主な目標は、現代的なLLMの中核構成要素(トランスフォーマー、アテンション、埋め込み、最適化など)を自ら実装し、学習過程を理解すること
- 他の主流LLMと異なり、トランスフォーマー構造、バックプロパゲーション、トークナイザー、オプティマイザなどをすべてRustコードで設計しており、Rust開発者や研究者がディープラーニングの原理をゼロから理解し拡張できる点が大きな強み
- ndarray で行列演算を処理し、PyTorchやTensorFlowのような外部機械学習パッケージに依存しないことが差別化要因
- モジュール化とテストカバレッジがしっかりしており、多様な実験や改善に適していて、「ゼロから自分で作るLLM(From Scratch)」の教育用途にも向いている
主な特徴と実装方式
- トランスフォーマーアーキテクチャ: 入力テキスト → トークン化 → 埋め込み → トランスフォーマーブロック → 最終予測
- 入力テキストはトークン化過程を経て埋め込みベクトルに変換される
- 埋め込みは Transformer Block(マルチヘッドアテンション + フィードフォワードネットワーク)を通過する
- 最後に Output Projection Layer で語彙の確率分布を生成し、予測を行う
実装構造
main.rs: 学習パイプライン、データ準備、インタラクティブモードの実行
llm.rs: LLM全体の順伝播・逆伝播および学習ロジック
transformer.rs, self_attention.rs, feed_forward.rs: 中核となるトランスフォーマーブロック
embeddings.rs, output_projection.rs: 埋め込みおよび最終出力層
adam.rs: Adamオプティマイザの実装
- 各モジュールには対応する テストコード(
tests/)が含まれており、機能検証が可能
学習・テスト方法とデータフロー
- 学習過程
- 語彙集の生成 → 事前学習(100epoch、事実文データ) → Instructionチューニング(100epoch、対話データ)
- 事前学習の例: "The sun rises in the east and sets in the west"
- Instructionチューニングの例: "User: How do mountains form? Assistant: ..."
- インタラクティブモード対応
- 学習完了後、プロンプト-応答ベースの対話テストが可能
- 例: "How do mountains form?" → "Mountains are formed through tectonic forces or volcanism..."
技術的な詳細構成
- 語彙サイズ: トレーニングデータに基づいて動的に設定
- 埋め込み次元: 128、隠れ層: 256
- 最大シーケンス長: 80トークン
- アーキテクチャ: 3つのトランスフォーマーブロック + 埋め込み + 出力層
- 学習アルゴリズム: Adamオプティマイザ、gradient clipping(L2 norm 5.0制限)
- 学習率: pre-training 0.0005、instruction tuning 0.0001
- 損失関数: cross-entropy loss
モデルとコードの特徴
- カスタムトークナイザー(句読点処理)
- グリーディデコーディング ベースのテキスト生成
- モジュール型の階層構造 と明快なインターフェース
- テストカバレッジ: 各層・各機能ごとの単体テストコードを搭載
- 依存関係: ndarray(行列演算)、rand / rand_distr(乱数初期化)のみ使用(PyTorch / TensorFlowなど外部MLは未使用)
- 教育的価値: 代表的な現代LLMの内部構造・訓練原理の学習に最適
発展可能性
- 高度なアーキテクチャ導入: マルチヘッドアテンション、RoPE、位置エンコーディングなど
- 性能最適化: SIMD、並列学習、メモリ効率の改善
- モデル保存 / 読み込み対応
- 改良されたサンプリング(ビームサーチ、Top-k / Top-p)および評価指標の追加
意義
- PythonベースのPyTorch、TensorFlowフレームワークに依存せずとも、RustだけでLLMを直接実装できる ことを示す学習用・実験用プロジェクト
- LLMの内部動作原理を理解し、Rust環境でMLシステムを作りたい開発者にとって有用なリファレンス
まだコメントはありません。