2 ポイント 投稿者 GN⁺ 2025-03-06 | まだコメントはありません。 | WhatsAppで共有
  • GPT式のデコーダー専用Transformerにおける学習可能なセルフアテンションは、各トークンが先行する入力のどのトークンに注目すべきかを計算し、コンテキストベクトルを作る
  • 核となるのはスケールド・ドットプロダクト・アテンションで、入力埋め込みをquery、key、value空間へ送る3つの学習行列 WqWkWv を使う
  • 入力行列 XQ=XWqK=XWkV=XWv に変換され、Ω=QKᵀ√c で割った後、行単位のsoftmaxでアテンション重み A を得る
  • コンテキストベクトルは C=AV という1回の行列積で作られ、全体の計算は5回の行列積と1回の転置ですべてのトークンに適用できる
  • この段階は、入力埋め込み同士で直接dot productを取っていたおもちゃの例を超え、PyTorchの nn.Modulenn.Linear で実装できる訓練可能なアテンションへとつながる

LLMの処理フローにおけるセルフアテンションの位置

  • GPT式のデコーダー専用TransformerベースのLLMは、これまでのトークンを見て次のトークンを予測する構造である
  • 処理フローは、文字列をトークンに分割し、各トークンをトークン埋め込みに変換した後、位置情報を表す位置埋め込みを加えて入力埋め込みを作る、という順序で進む
  • セルフアテンションは、各入力埋め込みについて、ほかのトークンにどれだけ注目するかを表すアテンションスコアのリストを生成する
    • 例文 "the fat cat sat on the mat""cat" を見るとき、"fat" は重要になり得る
    • "mat" を見るときには、"fat" の重要度は相対的に低い場合がある
  • アテンションスコアはsoftmaxを経て、合計が1になるアテンション重みとなり、この重みで入力埋め込みを加重和してコンテキストベクトルを作る
  • コンテキストベクトルは、各トークンの意味を入力全体の文脈の中で表現するベクトルとして扱われる

学習可能なセルフアテンションの目的

  • 前の段階までは、入力埋め込み同士で直接dot productを計算するおもちゃのセルフアテンションを使っていた
  • 今回の段階の目的は、入力ベクトルからアテンションスコアを作れる学習可能なアテンションメカニズムを構成することである
  • Sebastian RaschkaのBuild a Large Language Model (from Scratch) 3.4節では、これをscaled dot product attentionとして実装している
  • 焦点は、なぜこの構造が効果的なのかではなく、どのような計算で動作するのかに置かれている

Query、Key、Value行列と空間への射影

  • 入力シーケンス長を n、入力埋め込み次元を d、コンテキストベクトル次元を c とする
  • 入力埋め込みのシーケンスは x1, x2, x3, ... xn と表され、各入力埋め込みは d 次元ベクトルである
  • 3つの学習可能な重み行列を定義する
    • query weights matrix: Wq
    • key weights matrix: Wk
    • value weights matrix: Wv
  • 各行列は d×c サイズで、d 次元の入力ベクトルを c 次元空間へ射影する
  • 入力ベクトル xm をquery空間へ送る計算は qm=xmWq である
  • key空間とvalue空間も同じ方法で、入力埋め込みをそれぞれ異なる c 次元空間へ射影する

行列を射影として見る方法

  • 行列は点を回転させるなどの幾何学的変換に使える
  • 正方行列は同じ次元の中で変換を行い、正方でない行列はベクトルを別の次元空間へ送ることができる
  • 例えば 3×2 行列は、3次元ベクトルを2次元ベクトルに変換できる
  • 3Dグラフィックスで3D点を2D画面上の点に変換するfrustum行列も、このような射影の例として使われる
  • セルフアテンションは、入力埋め込みをquery、key、valueという3つの互いに異なる射影空間へ送ったうえで、射影されたベクトルを使って計算を進める
  • これらの射影行列は訓練中に学習されるため、単純なdot productアテンションにはなかった間接性が生じる

アテンションスコアの計算

  • 特定の入力 xm を考えるとき、別の入力 xp に対するアテンションスコアは、query射影とkey射影のdot productとして定義される
  • 計算式は次のとおり
    • qm=xmWq
    • kp=xpWk
    • ωm,p=qm·kp
  • すべての入力についてこの計算をループで処理することもできるが、行列積を使えば一度に計算できる
  • 入力埋め込み全体を行列 X とすると、Xn×d サイズである
  • key行列は K=XWk として一度に計算される
    • 結果 Kn×c サイズ
    • 各行は、その入力埋め込みをkey空間へ射影したベクトルである
  • query行列も同じ方法で Q=XWq として計算される
  • すべてのqueryとすべてのkeyの間のdot productは QKᵀ で得られる
    • Qn×c
    • Kᵀc×n
    • 結果 Ωn×n
  • Ωm,p は、xm のコンテキストベクトルを作るときに xp にどれだけ注目するかを表すアテンションスコアである

スケーリングとsoftmax正規化

  • アテンションスコアは以前の例と同様にsoftmaxを経て、合計が1の重みに変換される
  • softmaxは大きな値をさらに大きくし、小さな値を低くしながら、リスト全体の合計が1になるように調整する
  • 実際のLLMでは dc が数千単位になることがあり、純粋なsoftmaxだけを使うと小さなgradientが生じる可能性がある
  • この場合、softmaxが「step functionのように」振る舞うことがある
    • 最大値が支配し、残りの値が非常に小さくなる状況として解釈される
  • これを緩和するため、アテンションスコアを射影空間の次元 c の平方根で割った後にsoftmaxを適用する
  • 行列表現は次のとおり
    • A=softmax(Ω/√c, axis=1)
  • axis=1 はPyTorch式の表記で、softmaxを行単位で適用するという意味である
  • 結果 A は正規化されたアテンションスコア、つまりアテンション重み行列である

コンテキストベクトルの生成

  • value空間への射影は V=XWv で計算する
  • An×n サイズのアテンション重み行列である
    • Am,p は、xm のコンテキストベクトルを作るときに入力 p に適用するアテンション重みである
  • Vn×c サイズで、各行は入力埋め込みをvalue空間へ射影したベクトルである
  • コンテキストベクトル行列は C=AV で計算される
    • 結果 Cn×c
    • Cm 番目の行は、入力 xm に対するコンテキストベクトルである
  • この計算は、各トークンについてvalueベクトルをアテンション重みで掛けて足し合わせる作業を、1回の行列積で実行する

全体計算のまとめ

  • 入力行列 X はトークンシーケンスの入力埋め込みを含んでおり、サイズは n×d である
  • 3つの学習可能な行列で、入力をそれぞれquery、key、value空間へ射影する
    • Q=XWq
    • K=XWk
    • V=XWv
  • queryとkeyのdot productでアテンションスコアを計算する
    • Ω=QKᵀ
  • スコアをスケーリングした後、行単位のsoftmaxを適用してアテンション重みを作る
    • A=softmax(Ω/√c, axis=1)
  • value射影とアテンション重みを掛けてコンテキストベクトルを生成する
    • C=AV
  • セルフアテンションメカニズム全体は、5回の行列積と1回の転置で、すべての入力トークンのコンテキストベクトルを作ることができる

PyTorch実装と次のステップ

  • 書籍の3.4節では、上記の計算をPyTorchコードで実装し、同じ行列演算を行う簡単な nn.Module サブクラスを作る
  • 最初のバージョンでは、3つの重み行列に通常の nn.Parameter オブジェクトを使う
  • 2番目のバージョンでは、より効果的な訓練のために nn.Linear を使う
  • 以後扱うテーマは2つである
    • causal self-attention: 特定のトークンを見るとき、それ以降のトークンには注目しない方式
    • multi-head attention: 当初思っていたほど複雑ではないテーマとして予告されている
  • バッチ処理は別の検討事項として残っている
    • 単一の入力シーケンスでもアテンションスコア行列を使う
    • 複数の入力シーケンスを並列処理するには、行列より高い階数のテンソルが必要になる場合がある
  • 次の記事はWriting an LLM from scratch, part 9 -- causal attentionへ続く

まだコメントはありません。

まだコメントはありません。