- トランスフォーマーの推論過程を Hello World → Hola Mundo の翻訳例に縮小し、トークン化からエンコーダー・デコーダー・次トークン確率の計算までを手で追えるようにしている
- 元論文の大きな設定の代わりに 4次元埋め込み、2つのアテンションヘッド、8次元フィードフォワード層を使い、行列積と softmax の流れを小さくしている
- エンコーダーはトークン埋め込みに 位置エンコーディング を加えた後、マルチヘッド self-attention とフィードフォワード層を経て、入力シーケンスの文脈表現を作る
- デコーダーは
SOS から開始し、これまでに生成したトークンとエンコーダー出力を一緒に使い、encoder-decoder attention では query はデコーダー、key/value はエンコーダー出力から計算する
- 最後のデコーダー埋め込みは線形層と softmax を通って次トークン確率になるが、例では ランダムな重み を使っているため実際の翻訳品質は期待しない
目標と前提
- トランスフォーマーモデル内部で 推論時の数学 がどのようにつながるかを end-to-end の例で確認する
- 計算を手で追いやすくするため、モデルサイズを大幅に縮小している
- 元論文の埋め込み次元 512 の代わりに、例では 4次元 を使用
- アテンションヘッドは元論文の 8 個の代わりに 2個 を使用
- フィードフォワード次元は元論文の 2048 の代わりに 8次元 を使用
- 必要な前提は基本的な線形代数であり、計算の大部分は行列積で進む
- トランスフォーマーが「何か」よりも、実際の計算が どう進むか に焦点を当てている
- 直感的な説明は The Illustrated Transformer と併せて読むとよく、元論文は Attention is all you need
エンコーダー入力の作成
-
トークン化
- 機械学習モデルはテキストではなく数値を処理するため、入力テキストを トークン ID に変換する
- 例では単純化のため
"Hello World" を "Hello" と "World" の 2 つの単語トークンに分ける
- 実際のトークン化方式は、単語ベース、文字ベース、subword ベースに分けられる
- 単語ベースは大きな vocabulary が必要で、
"dog" と "dogs" を別々のトークンとして扱う
- 文字ベースは vocabulary が小さい一方で、意味情報が乏しくなることがある
- subword トークン化は単語方式と文字方式の中間にあり、統計的プロセスでトークナイザーを学習する
-
トークン埋め込み
- トークン ID 自体には意味がないため、各トークンを固定サイズベクトルである 埋め込み に変換する
- 例の埋め込みは任意の値を使用
Hello -> [1, 2, 3, 4]
World -> [2, 3, 4, 5]
- 実際のトランスフォーマーでは埋め込みマッピングも学習され、モデルがタスクに適したトークン表現を学習する
- 2 つの埋め込みは 1 つの行列にまとめられ、その後の行列積に使われる
-
位置エンコーディング
- 埋め込みだけでは単語の 文中での位置 が分からないため、位置エンコーディングを加える
- 元論文は固定の sine/cosine 位置エンコーディングを使用しており、例でも同じ方式を採用する
- 例の位置エンコーディングは次のように計算される
Hello -> [0, 1, 0, 1]
World -> [0.84, 0.99, 0, 1]
- トークン埋め込みと位置エンコーディングを足し合わせて、エンコーダー入力行列を作る
Hello -> [1, 3, 3, 5]
World -> [2.84, 3.99, 4, 6]
Self-attention の計算
-
Q, K, V の作成
- self-attention は入力埋め込みから query(Q)、key(K)、value(V) を計算する
- 例では 2 つのアテンションヘッドを使い、各ヘッドは独立した
WQ, WK, WV 行列を持つ
- 各重み行列は 4 次元埋め込みを 3 次元の query/key/value に変換する
- 1 つ目のヘッドでは、入力行列と
WK1, WV1, WQ1 を掛けて K1, V1, Q1 を得る
-
Attention の公式
- アテンションスコアは 4 段階で計算される
- query と各 key の 内積 を計算する
- key 次元の平方根で割る
- softmax で正かつ総和が 1 の重みに変換する
- 重みで value ベクトルの加重和を取る
- この過程は元論文の公式に要約される
- [
- Attention(Q,K,V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d}}\right)V
- ]
- 例では小さな次元と任意の初期値のため、softmax の結果はほぼ 0 と 1 に偏る
- 大きな dot product 値は softmax でより強く増幅されうるため、key 次元の平方根で割るスケーリングが必要になる
- 説明のために一時的に
sqrt(3) の代わりに 30 で割る変形も使うが、長期的な解決策ではない
-
マルチヘッドアテンション出力
- 各ヘッドのアテンション結果を concatenate した後、学習される重み行列を掛けて再び埋め込み次元へ戻す
- 例では 2 つのヘッド結果を合わせて 6 次元行列を作り、これを 4 次元出力に変換する
- この出力はエンコーダーブロックの次の段階であるフィードフォワード層へ渡される
フィードフォワード層とエンコーダーブロック
-
フィードフォワード層
- self-attention の後には フィードフォワードニューラルネットワーク(FFN) が置かれる
- FFN は 2 つの線形変換と、その間の ReLU 活性化で構成される
- 1 つ目の線形層は次元を拡張し、2 つ目の線形層は次元を元のサイズに縮小する
- ReLU は負の値を 0 にし、正の値はそのまま維持して非線形性を加える
- 例では 4 次元入力を 8 次元に拡張した後、再び 4 次元に縮小する
- [
- \text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2
- ]
-
エンコーダーブロック
- 1 つのエンコーダーブロックはマルチヘッドアテンションと FFN で構成される
- 元論文では 6 個のエンコーダー を積み重ねており、例のコードでも
n=6 としてエンコーダーを繰り返す
- 複数のエンコーダーブロックを単純に通すと値が大きくなりすぎ、softmax 計算で overflow が発生して
nan になることがある
Residual connection と layer normalization
-
値の暴走問題
- 例で 6 個のエンコーダーを通したところ、
overflow encountered in exp と invalid value encountered in divide の警告が発生し、出力が nan になった
- 値が大きくなりすぎ、次の層でさらに増幅される現象は深いニューラルネットワークでよくある問題
- backpropagation 中に gradient が大きくなりすぎる場合は gradient explosion と呼ばれる
-
Residual connection
- residual connection は層の入力を層の出力に足し合わせる方式
- [
- \text{Residual}(x) = x + \text{Layer}(x)
- ]
- 例ではアテンション出力と FFN 出力にそれぞれ residual connection を適用する
- residual connection は vanishing gradient 問題の緩和に使われる
-
Layer normalization
- layer normalization は各埋め込み次元について平均 0、標準偏差 1 になるよう正規化する
- 式は次の通り
- [
- \text{LayerNorm}(x) = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \times \gamma + \beta
- ]
- (\epsilon) は標準偏差が 0 のときの 0 除算を避けるための小さな値
- (\gamma) と (\beta) は scaling と shifting を制御する学習パラメータ
- residual connection と layer normalization を追加した後は、6 個のエンコーダーを通しても
nan なしで正常な値が得られる
デコーダー構造
-
デコーダー入力と生成方式
- デコーダーはエンコーダー出力と、これまでに生成した出力シーケンスを入力として受け取る
- 推論中は SOS(start-of-sequence) トークンから始める
- デコーダーは autoregressive な方式で 1 回に 1 トークンずつ生成する
- 1 回目:
SOS を入力として受け取り "hola" を生成
- 2 回目:
SOS + hola を入力として受け取り "mundo" を生成
- 3 回目:
SOS + hola + mundo を入力として受け取り EOS を生成
EOS(end-of-sequence) トークンが生成されるとデコードを停止する
- エンコーダーは 1 回の forward pass で表現を作れるが、デコーダーは複数回の forward pass が必要なため遅い
-
デコーダーブロック構成
- デコーダーブロックはエンコーダーブロックより複雑で、次の順序で構成される
- masked self-attention
- residual connection と layer normalization
- encoder-decoder attention
- residual connection と layer normalization
- フィードフォワード層
- residual connection と layer normalization
- 推論例では
SOS 埋め込みに位置エンコーディングを加えて [1, 1, 0, 1] を使用する
- 学習中は未来のトークンを見られないよう、attention score を
-inf でマスクする masked self-attention を使う
Encoder-decoder attention
- encoder-decoder attention は、デコーダーが入力文の関連部分に集中できるようにする段階
- self-attention と計算方法は同じだが、Q/K/V を作る入力が異なる
- query は直前のデコーダー層出力から計算する
- key と value は エンコーダー出力 から計算する
- この構造により、デコーダーの各位置が入力シーケンスのすべての位置を参照できる
- 翻訳のように、出力トークンが入力文中の関連位置に依存する必要があるタスクで有用
出力トークン生成
-
Linear layer と softmax
- デコーダー出力はそのまま単語ではないため、最後の埋め込みを線形層に通して vocabulary サイズの logits ベクトルに変換する
- 例の vocabulary サイズは 10 で、次トークン候補は以下の通り
hello, mundo, world, how, ?, EOS, SOS, a, hola, c
- logits は softmax を経て各トークンの確率分布になる
- 例の確率では
"hola" が最も高い確率を持ち、次トークンとして選ばれる
- 常に最も高い確率のトークンを選ぶ方式は greedy decoding であり、常に最善とは限らない
- 生成手法は Hugging Face の記事 でさらに詳しく見られる
-
全体の生成ループ
- 全体の生成手順は次の流れに従う
- 入力シーケンスを埋め込みに変換する
- エンコーダーが入力全体の文脈表現を生成する
- デコーダーは
SOS から始め、以前に生成したトークンとエンコーダー出力を一緒に使う
- 最後のデコーダー埋め込みに linear layer と softmax を適用する
- 最も可能性の高い次トークンを選び、シーケンスに追加する
EOS が出るか最大長に到達するまで繰り返す
- 例の実行では
hello world 入力に対して SOS hola mundo world を生成する
- すべての重みと埋め込みをランダムに使っているため、結果は良い翻訳にはならず、これは想定通りの挙動
結論と範囲
- この例は、トランスフォーマーの中核構成要素である埋め込み、位置エンコーディング、self-attention、マルチヘッドアテンション、FFN、residual connection、layer normalization、encoder-decoder attention、softmax 出力を 1 つの流れとしてつなげている
- 最新のトランスフォーマーアーキテクチャはさまざまな手法を追加しているが、中核となる数学はこの例で扱った構造に基づいている
- タスクの種類によって使うスタックは異なる場合がある
- 分類のような理解中心のタスクでは、エンコーダースタックの上に linear layer を置ける
- 翻訳のような生成中心のタスクでは、エンコーダーとデコーダースタックを一緒に使える
- ChatGPT や Mistral のような自由生成タスクでは、デコーダースタックだけを使える
- 学習過程は扱わず、既存モデルを使う際の 推論の数学 を理解することに焦点を当てている
- より形式的な数学資料としては この PDF を参照できる
1件のコメント
Hacker News のコメント
Transformer の「ミステリー」は、各層で静的な重みと値を線形順序で掛け合わせるのではなく、同じ入力に学習済みの重みを掛けて得た 3つの行列を作り、その行列同士を掛け合わせる点にある。
並列性が高まるためうまく機能するが、Attention の式自体は固定されているので、非常に制約が大きい。
さらに発展するには、計算グラフそのものを学習可能なパラメータへ一般化する方法が必要に見える。小さな変化が性能の大きな変化につながるカオス効果のため、従来の勾配法で可能なのかは分からず、内部的には遺伝的アルゴリズムや粒子群最適化のような形が必要になるかもしれない。
RNN と比べた大きな理論的利点は、これを損失なしでサポートする点だ。各要素がシーケンス内の他のすべての要素、あるいは時間順では先行する要素の全情報にアクセスできるためだ。
一方で RNN や「線形 Transformer」は過去の値を圧縮するため、長いシーケンスの最後の要素が最初の要素の全情報にアクセスするのは通常難しく、内部状態が非常に大きく、情報をまったく捨てない場合でなければ不可能だ。
問題は、そこから得られるものがあまりないことだ。行列乗算ではない演算は、より遅いか、せいぜい同程度の速度である可能性が高い。
ただしフロー制御を入れると、事実上 Turing machine になってしまうリスクがあり、そうなると述べられている通り学習が問題になる。それでも完全に手に負えない問題ではないかもしれない。
もっと乾いた、形式的で簡潔な説明が欲しいなら、John Thickstun の “The Transformer Model in Equations” [0] がある。
標準的な数学表記で、全体が1ページに収まっている。
[0] https://johnthickstun.com/docs/transformers.pdf
機械学習研究者は数学をまったく勉強していないように見えることが多い。
「NaN が出る、値が大きすぎて次のエンコーダへ渡るときに爆発する、これが勾配爆発だ」という説明は、私の理解では間違っている。
ここではどの時点でも勾配を計算していないので、勾配爆発ではない。
問題は softmax の実装側にあるようで、数値的に安定した softmax の実装方法はここ [0] で説明されている。
[0]: https://jaykmody.com/blog/stable-softmax/
ただしニューラルネットワーク全体が大きな値に敏感なので、数値的に安定した softmax だけでは解決しない。ネットワークが機能するには正規化が重要だ。
Transformer のチュートリアルは、新しい Monad チュートリアルになるかもしれない。理解しにくい概念だが、例を練習しながら格闘して初めて分かる種類のものだ。
コンピュータサイエンスの多くがそうであるように。
6段落しか読んでいないのに、もう疑問が湧いた。
Hello -> [1,2,3,4] World -> [2,3,4,5]で、ベクトルはランダムだというが、パターンがあるように見える。両方のベクトルにある2に何か意味があるのか、それとも集合全体が一意性を作っているのか気になる。ここでは約 60度 離れていて、ある程度同じ方向だが、例に負の数を入れないようにしたため、実際よりベクトル同士が似てしまっている影響が大きい。
数字が再利用されているという事実自体には意味がない。1番目の位置の
1は、2番目の位置の1とほとんど関係がない。このベクトル上で畳み込みをしているわけでもないからだ。学習後には似た単語同士がある程度のコサイン類似度を持つようになるが、
[1,2,3,4]と[2,3,4,5]ほどコサイン類似度が高くなることはほとんどない。完全に関連した質問というわけではないが、Transformer が単なる「次トークン予測器」のように動作しながらも、次のような質問を処理できる理由を扱った記事や論文を探している
"sdsfs_ff","fsdf_value"を列に持つ表を作るよくある質問だと思うのだが、検索するキーワードが見つからない。位置埋め込みについて深く扱ったリンクもあるとうれしいし、サイン/コサインを使う理由や、乗算対加算についてもまだ納得のいく答えを得られていない
モデルが必要だと判断すれば、未知のシーケンスを単一文字トークンをコピーして再現したり、文脈上自然であれば新しく作り出したりできる
P(X_1=x_1, X_2=x_2, X_3=x_3) = P(X_3=x_3 | X_1=X_1, X_2=x_2) • P(X_1=x_1, X_2=x_2)= P(X_3=x_3 | X_1=X_1, X_2=x_2) • P(X_2=x_2 | X_1=x_1) • P(X_1=x_1)つまり、前のトークン群が与えられたときの次トークンに対する正しい条件付き確率分布があれば、トークン列全体に対する正しい確率分布も作られる
「トークン列に対する正しい確率分布」、または何らかの条件が与えられたトークン列の正しい条件付き確率分布は、実質的にほぼあらゆる種類の入出力動作をそういう言葉で説明できる
だから「次トークンを予測して動作する」ということは、原理的にはどんな入出力の振る舞いができるかに対する大きな制約ではない
どれほど印象的なことをしても、その出力が
P(X_{n+1}=x_{n+1} | X_1=x_1, ..., X_n=x_n)から出てくる、つまり「次トークン予測」であるという事実とは矛盾しない次トークン予測は、聞こえよりもはるかに知的な作業である
「複雑さはステップ数とパラメータ数から来る」という言葉に同意する
私たちが理解できるほど単純な Transformer モデルは面白いことができず、面白いことができるほど複雑な Transformer は私たちが理解するには複雑すぎるように見える
理解できるほど単純でありながら、面白いことができるほど複雑な中規模モデルを研究してみたい
概念を定義したり紹介したりせずに使うと理解しにくい。Encoder セクションは、それが何なのか、全体のプロセスのどこに位置するのかの説明なしにいきなり始まる
著者が何をしようとしているのかは分かるが、アイデアを先に紹介して説明してから使うという基本的な文章構成が抜けている
すでにこのテーマを勉強中で半分くらい理解している人でなければ、記事全体が混乱して感じられる
ANN を一から書いたことがあり、TensorFlow は使っていないにもかかわらず、この説明は依然として混乱する
ChatGPT に、
MatrixやVectorという言葉を使わずに、基本的な ANN を self-attention を実装するように変える方法を説明してもらったところ、かなり単純な説明を返してくれた。まだ実装はしていないすべてをノード、重み、層の観点で考えるほうが自分には合っている。行列とベクトルは、ANN 内部で実際に起きていることと結び付けるのを難しくする
慣れ親しんだ ANN の書き方では各入力ノードはスカラーだが、順伝播アルゴリズムはすべての入力ノードに重みを掛けて足し合わせるので、ベクトル-行列積のように見える。こうした説明に間違った心構えで向き合っている気がするし、必要な背景知識が不足しているのかもしれない