LLM開発プロセス第8部 - 学習可能なセルフアテンション技術
(gilesthomas.com)- GPT式のデコーダー専用Transformerにおける学習可能なセルフアテンションは、各トークンが先行する入力のどのトークンに注目すべきかを計算し、コンテキストベクトルを作る
- 核となるのはスケールド・ドットプロダクト・アテンションで、入力埋め込みをquery、key、value空間へ送る3つの学習行列
Wq、Wk、Wvを使う - 入力行列
XはQ=XWq、K=XWk、V=XWvに変換され、Ω=QKᵀを√cで割った後、行単位のsoftmaxでアテンション重みAを得る - コンテキストベクトルは
C=AVという1回の行列積で作られ、全体の計算は5回の行列積と1回の転置ですべてのトークンに適用できる - この段階は、入力埋め込み同士で直接dot productを取っていたおもちゃの例を超え、PyTorchの
nn.Moduleとnn.Linearで実装できる訓練可能なアテンションへとつながる
LLMの処理フローにおけるセルフアテンションの位置
- GPT式のデコーダー専用TransformerベースのLLMは、これまでのトークンを見て次のトークンを予測する構造である
- 処理フローは、文字列をトークンに分割し、各トークンをトークン埋め込みに変換した後、位置情報を表す位置埋め込みを加えて入力埋め込みを作る、という順序で進む
- セルフアテンションは、各入力埋め込みについて、ほかのトークンにどれだけ注目するかを表すアテンションスコアのリストを生成する
- 例文
"the fat cat sat on the mat"で"cat"を見るとき、"fat"は重要になり得る "mat"を見るときには、"fat"の重要度は相対的に低い場合がある
- 例文
- アテンションスコアはsoftmaxを経て、合計が1になるアテンション重みとなり、この重みで入力埋め込みを加重和してコンテキストベクトルを作る
- コンテキストベクトルは、各トークンの意味を入力全体の文脈の中で表現するベクトルとして扱われる
学習可能なセルフアテンションの目的
- 前の段階までは、入力埋め込み同士で直接dot productを計算するおもちゃのセルフアテンションを使っていた
- 今回の段階の目的は、入力ベクトルからアテンションスコアを作れる学習可能なアテンションメカニズムを構成することである
- Sebastian RaschkaのBuild a Large Language Model (from Scratch) 3.4節では、これをscaled dot product attentionとして実装している
- 焦点は、なぜこの構造が効果的なのかではなく、どのような計算で動作するのかに置かれている
Query、Key、Value行列と空間への射影
- 入力シーケンス長を
n、入力埋め込み次元をd、コンテキストベクトル次元をcとする - 入力埋め込みのシーケンスは
x1, x2, x3, ... xnと表され、各入力埋め込みはd次元ベクトルである - 3つの学習可能な重み行列を定義する
- query weights matrix:
Wq - key weights matrix:
Wk - value weights matrix:
Wv
- query weights matrix:
- 各行列は
d×cサイズで、d次元の入力ベクトルをc次元空間へ射影する - 入力ベクトル
xmをquery空間へ送る計算はqm=xmWqである - key空間とvalue空間も同じ方法で、入力埋め込みをそれぞれ異なる
c次元空間へ射影する
行列を射影として見る方法
- 行列は点を回転させるなどの幾何学的変換に使える
- 正方行列は同じ次元の中で変換を行い、正方でない行列はベクトルを別の次元空間へ送ることができる
- 例えば
3×2行列は、3次元ベクトルを2次元ベクトルに変換できる - 3Dグラフィックスで3D点を2D画面上の点に変換するfrustum行列も、このような射影の例として使われる
- セルフアテンションは、入力埋め込みをquery、key、valueという3つの互いに異なる射影空間へ送ったうえで、射影されたベクトルを使って計算を進める
- これらの射影行列は訓練中に学習されるため、単純なdot productアテンションにはなかった間接性が生じる
アテンションスコアの計算
- 特定の入力
xmを考えるとき、別の入力xpに対するアテンションスコアは、query射影とkey射影のdot productとして定義される - 計算式は次のとおり
qm=xmWqkp=xpWkωm,p=qm·kp
- すべての入力についてこの計算をループで処理することもできるが、行列積を使えば一度に計算できる
- 入力埋め込み全体を行列
Xとすると、Xはn×dサイズである - key行列は
K=XWkとして一度に計算される- 結果
Kはn×cサイズ - 各行は、その入力埋め込みをkey空間へ射影したベクトルである
- 結果
- query行列も同じ方法で
Q=XWqとして計算される - すべてのqueryとすべてのkeyの間のdot productは
QKᵀで得られるQはn×cKᵀはc×n- 結果
Ωはn×n
Ωm,pは、xmのコンテキストベクトルを作るときにxpにどれだけ注目するかを表すアテンションスコアである
スケーリングとsoftmax正規化
- アテンションスコアは以前の例と同様にsoftmaxを経て、合計が1の重みに変換される
- softmaxは大きな値をさらに大きくし、小さな値を低くしながら、リスト全体の合計が1になるように調整する
- 実際のLLMでは
dとcが数千単位になることがあり、純粋なsoftmaxだけを使うと小さなgradientが生じる可能性がある - この場合、softmaxが「step functionのように」振る舞うことがある
- 最大値が支配し、残りの値が非常に小さくなる状況として解釈される
- これを緩和するため、アテンションスコアを射影空間の次元
cの平方根で割った後にsoftmaxを適用する - 行列表現は次のとおり
A=softmax(Ω/√c, axis=1)
axis=1はPyTorch式の表記で、softmaxを行単位で適用するという意味である- 結果
Aは正規化されたアテンションスコア、つまりアテンション重み行列である
コンテキストベクトルの生成
- value空間への射影は
V=XWvで計算する Aはn×nサイズのアテンション重み行列であるAm,pは、xmのコンテキストベクトルを作るときに入力pに適用するアテンション重みである
Vはn×cサイズで、各行は入力埋め込みをvalue空間へ射影したベクトルである- コンテキストベクトル行列は
C=AVで計算される- 結果
Cはn×c Cのm番目の行は、入力xmに対するコンテキストベクトルである
- 結果
- この計算は、各トークンについてvalueベクトルをアテンション重みで掛けて足し合わせる作業を、1回の行列積で実行する
全体計算のまとめ
- 入力行列
Xはトークンシーケンスの入力埋め込みを含んでおり、サイズはn×dである - 3つの学習可能な行列で、入力をそれぞれquery、key、value空間へ射影する
Q=XWqK=XWkV=XWv
- queryとkeyのdot productでアテンションスコアを計算する
Ω=QKᵀ
- スコアをスケーリングした後、行単位のsoftmaxを適用してアテンション重みを作る
A=softmax(Ω/√c, axis=1)
- value射影とアテンション重みを掛けてコンテキストベクトルを生成する
C=AV
- セルフアテンションメカニズム全体は、5回の行列積と1回の転置で、すべての入力トークンのコンテキストベクトルを作ることができる
PyTorch実装と次のステップ
- 書籍の3.4節では、上記の計算をPyTorchコードで実装し、同じ行列演算を行う簡単な
nn.Moduleサブクラスを作る - 最初のバージョンでは、3つの重み行列に通常の
nn.Parameterオブジェクトを使う - 2番目のバージョンでは、より効果的な訓練のために
nn.Linearを使う - 以後扱うテーマは2つである
- causal self-attention: 特定のトークンを見るとき、それ以降のトークンには注目しない方式
- multi-head attention: 当初思っていたほど複雑ではないテーマとして予告されている
- バッチ処理は別の検討事項として残っている
- 単一の入力シーケンスでもアテンションスコア行列を使う
- 複数の入力シーケンスを並列処理するには、行列より高い階数のテンソルが必要になる場合がある
- 次の記事はWriting an LLM from scratch, part 9 -- causal attentionへ続く
まだコメントはありません。