1 ポイント 投稿者 GN⁺ 2024-07-14 | まだコメントはありません。 | WhatsAppで共有
  • AlphaFold3は、単一タンパク質を超えて、タンパク質・核酸・低分子が共存する複合体を配列だけから予測しようとしており、このためAF2よりも入力表現とトークン化がはるかに複雑になっている
  • 入力はトークンレベルのsingle/pair表現、原子レベル表現、MSA、templateに分かれ、標準アミノ酸・ヌクレオチドは1トークン、非標準残基とその他の分子は原子ごとに1トークンとして処理される
  • 表現学習trunkは、template module、MSA module、Pairformerを通じて、pair-bias attention、triangle演算、recyclingによりsingle表現sとpair表現zを反復的に改善する
  • 構造予測では、AF2のInvariant Point Attentionの代わりに、原子座標に対する条件付き拡散モデルを使用し、回転・並進の拡張とdenoisingによって全原子の座標更新を生成する
  • 学習では、distogram、diffusion、confidence lossを組み合わせ、AF2・AF-Multimerの結果を活用したcross-distillationにより、低信頼領域のunfolded表現まで再学習する

AlphaFold3の入力範囲と全体パイプライン

  • AlphaFold3の目標は、AF2のように個々のタンパク質配列だけを予測したり、AF-Multimerのようにタンパク質複合体だけを扱ったりすることにとどまらず、タンパク質と、必要に応じて他のタンパク質、核酸、低分子が結合した構造を配列だけから予測することにある
  • 「トークン」の意味は入力の種類によって異なる
    • タンパク質: 標準アミノ酸1個が1トークン
    • DNA/RNA: 標準ヌクレオチド1個が1トークン
    • 非標準アミノ酸・ヌクレオチド: 原子1個が1トークン
    • その他の分子: 原子1個が1トークン
  • 標準アミノ酸35個からなるタンパク質は、実際には600個を超える原子を持ちうるが、35トークンで表現され、原子35個のligandは35トークンで表現される
  • モデルは大きく3段階で構成される
    • Input Preparation: ユーザー入力の配列と検索された関連配列・構造を数値テンソルに変換
    • Representation Learning: single表現とpair表現を複数のattention変種で更新
    • Structure Prediction: 条件付き拡散で構造を予測
  • タンパク質複合体は主に2つの表現に保存される
    • single representation: 複合体の全トークンそのものを表現
    • pair representation: 全トークン対の間の距離や潜在的相互作用のような関係を表現
  • 主なチャネル次元は c_z=128, c_m=64, c_atom=128, c_atompair=16, c_token=768, c_s=384 である

入力準備: 配列を6個のテンソルに変換する過程

  • ユーザーが提供した入力は、モデルtrunkに入る6個のテンソルに変換される
    • s: token-level single representation
    • z: token-level pair representation
    • q: atom-level single representation
    • p: atom-level pair representation
    • m: MSA representation
    • t: template representation
  • MSAとtemplate検索

    • AF3はタンパク質とRNA配列に対して類似配列を探し、これをMSAとして構成し、関連構造はtemplateとして含める
    • MSAは、複数の種で見つかる類似タンパク質配列を整列し、特定位置の保存パターンと、異なる位置間の変化の相関関係をモデルに提供する
    • 類似タンパク質の既知構造は、homology modelingのようにqueryタンパク質の構造推定に使われる
    • 検索には学習は含まれず、HMMベースの手法が使われる
    • jackhmmer, HHBlits, nhmmer で複数のタンパク質・RNAデータベースを検索し、hmmsearch でProtein Data Bank内の類似配列を探す
    • MSAサイズは計算複雑性のため N_MSA < 2^14 に制限される
    • 各タンパク質chainでは品質の高い構造を選び、最大4個をtemplateとしてサンプリングする
    • AF-Multimerと比べて新たに追加された検索要素は、RNA配列も検索対象に含める点である
  • template表現方式

    • templateの3D構造から、各トークン対の間のユークリッド距離を計算する
    • 複数の原子を持つトークンには代表となる「center atom」を使う
      • アミノ酸: 原子
      • 標準ヌクレオチド: C1' 原子
    • 距離値は連続値ではなくdistogramとして離散化される
      • 3.15Åから50.75Åまで38個のbin
      • それより大きい距離のための追加binが1個
    • distogramにはchain情報、crystal structureで該当トークンがresolvedされているかどうか、各アミノ酸内部のlocal distance情報が追加される
    • template matrixは同一chain内の距離だけを見るようにmaskingされ、template選択によってinter-chain interaction情報を得ようとはしていない

原子レベル表現とAtom Transformer

  • reference conformerとatom-level表現

    • 原子レベルのsingle表現 q を作るために、各アミノ酸、ヌクレオチド、ligandについて reference conformer を計算する
    • conformerは、単結合の周りの回転をサンプリングして生成される分子の3D原子配置である
    • 標準アミノ酸にはlookupで得られる低エネルギーconformerを使い、小分子には RDKit’s ETKDGv3 で3D conformerを生成する
    • conformerの相対位置、原子電荷、原子番号、識別子などを結合して、atom-level single representation c を作る
    • c でatom-level pair representation p を初期化し、reference conformerで計算された原子間距離だけを含むようにmask v を使う
    • qc のコピーから始まり、その後Atom Transformerで更新される
  • Atom Transformerの役割

    • Atom Transformerは原子レベルのattentionを行うモジュールで、p と元の表現 c を使って q を更新する
    • c は更新されず、初期表現へ向かうresidual connectionのように使われる
    • 基本構造はtransformerと似ており、LayerNorm、attention、MLP transitionを含むが、各段階は cp の追加入力によって調整される
  • Adaptive LayerNorm

    • Adaptive LayerNormは、固定の gammabeta を学習する代わりに、補助入力から gammabeta を生成する
    • Atom Transformerでは、再スケーリング対象は q であり、再スケーリングパラメータは補助入力 c から予測される
  • Attention with Pair Bias

    • Atom-level attention with pair biasはself-attentionの拡張である
    • query、key、valueはすべてsingle representation q から出るが、query-key dot productの後にpair representation p の線形projectionをbiasとして加える
    • pair representationから q へ情報は流れるが、この段階では q の情報で p を更新しない
    • 追加projectionをsigmoidに通して作った gate がattention結果に掛け合わされ、residual streamにどの情報を残すかを調整する
    • 原子数はトークン数よりはるかに多くなり得るため、full attentionの代わりに Sequence-local atom attention を使う
    • 32個の原子単位のlocal groupが、128個の別の原子にattendできる
  • Conditioned GatingとTransition

    • Conditioned Gatingは、元のatom-level single matrix c から生成したgateをデータに適用する
    • Conditioned TransitionはtransformerのMLPに相当し、Adaptive LayerNormとConditional Gatingが c に依存するためconditionedと呼ばれる
    • AF3はtransition blockでReLUの代わりに SwiGLU を使う
    • AF2のReLUベースのtransitionは、4倍のup-projection、ReLU、down-projectionという構造である
    • AF3のSwiGLUは、2つのup-projectionのうち一方にswish非線形を適用してから掛け合わせ、down-projectする

原子表現をトークン表現へ集約

  • 表現学習段階はこの後token-levelで動作するため、atom-level表現をtoken-level表現へ集約する
  • atom-level representationをより大きい次元へprojectionした後、同じトークンに属する原子の平均を取る
  • この平均集約は、標準アミノ酸やヌクレオチドのように複数の原子が1つのトークンに結び付いている場合に適用され、原子あたり1トークンの入力はそのまま維持される
  • token-level single入力には、MSAから得た統計も結合される
    • アミノ酸タイプ
    • その位置のMSAアミノ酸分布
    • そのトークンのdeletion mean
  • ligand原子のようにMSAがないトークンでは、これらの値は0になる
  • このように作った s_inputs はprojectionを経て s_init となり、representation learning段階で更新される
  • pair representation z_init はtoken pairごとの関係を保存する3次元テンソルで、各 z_i,jc_z=128 次元のベクトルである
  • z_i,j の初期化には、s_is_j のprojection、relative positional encoding、ユーザーが指定したtoken間のbond情報が加えられる

表現学習: Template、MSA、Pairformer

  • representation learningはモデル計算の大半を占める trunk であり、token-level single表現 s とpair表現 z を改善することが目的である
  • single sequence representationは、単一のタンパク質配列だけでなく、構造内のすべての原子またはトークンをつなぎ合わせたsequenceを指す
  • Template Module

    • 各templateは線形projectionを通り、pair representation z の線形projectionと加え合わされる
    • 結合されたmatrixはPairformer Stackを通過する
    • 複数のtemplate結果は平均を取り、その後もう一度線形layerを通る
    • 最後の線形layerにはReLUが使われ、AF3でReLUが非線形として使われる数少ない箇所の1つである
  • MSA Module

    • MSA ModuleはAF2のEvoformerと非常によく似ており、MSA representation m とpair representation z を同時に改善する
    • MSAの全rowをすべて使うのではなくsubsamplingしたうえで、single representationのprojectionをMSAに加える
    • Outer Product Mean はMSA情報をpair representationに入れる演算である
      • token index i,j ごとに、すべてのevolutionary sequenceについて m_s,im_s,j のouter productを計算する
      • これをsequence全体で平均し、flattenしてprojectionした後、z_i,j に加える
      • モデル内でevolutionary sequence間の情報が共有される唯一の地点である
    • Row-wise gated self-attention using only pair bias はpair representationを使ってMSAを更新する
      • queryとkeyでattention scoreを作る代わりに、pair representation z をmatrixへprojectionしてtoken間のattention scoreとして使う
      • 各MSA rowに独立して適用されるため、この段階ではevolutionary sequence間の情報は共有されない
    • MSA moduleの最後では、triangle updateとtriangle attentionでpair representationを再び更新する

Pairformerとtriangle演算

  • TemplateとMSAでzを更新した後は、templateとMSAはもはや使わず、szのみがPairformerに入力される
  • Pairformerは48個のblockの反復を通じて最終的なs_trunkz_trunkを生成する
  • triangle演算の直感

    • triangle updateとtriangle attentionは、三角不等式の直感をモデルに反映しようとする構造である
    • pair tensorのz_i,jは物理的距離そのものではないが、token ijの関係を含んでいるため、i-jj-ki-kの3つの関係が互いに一貫するように更新する
    • 三角不等式はモデル内で直接強制されるのではなく、すべてのtriplet (i,j,k)を見ながらz_i,jを更新する方式で誘導される
    • zはdirected adjacency matrixのように見なせるため、outgoing edgeとincoming edgeの方向を分けて処理する
  • Triangle Updates

    • outgoing updateでは、各z_i,jを同じrowの別の要素z_i,kと3本目のedgeであるz_j,kを使って更新する
    • 実装上はzの3つのprojection abgを作り、row iとrow jのelement-wise multiplicationをkについて合算した後、gate gを適用する
    • incoming updateはrowとcolumnを入れ替えた形で、z_i,jを同じcolumnの別の要素z_k,jz_k,iを通じて更新する
  • Triangle Attention

    • triangle attentionは、2D matrixのrowとcolumnに独立したattentionを適用するaxial attentionにtriangleの原理を加えた形である
    • “starting node” caseでは、z_i,jz_i,kのquery-key比較にz_j,kをbiasとして加える
    • “ending node” caseでは、column基準で動作し、z_i,jz_k,iのattention scoreにz_k,jでbiasを加える
  • Single Attention with Pair Bias

    • triangle stepとtransition blockの後、single representation sは、更新されたpair representation zを使うsingle attention with pair biasで更新される
    • token-levelで動作するため、atom-levelで使っていたblock-wise sparse attentionではなくfull attentionを使う

構造予測: 原子座標を拡散でdenoising

  • 拡散モデルの基本方式

    • AF3は最終的な構造予測をatom-level diffusionで行う
    • diffusion modelは実データにrandom noiseを段階的に加え、モデルがどのnoiseが加えられたかを予測するよう学習する
    • inferenceでは完全なrandom noiseから始め、各stepでモデルが予測したnoiseを取り除きながらdenoised datapointを生成する
    • 条件付き拡散は、現在のnoisy generation、現在のtimestep表現、条件ベクトルを入力として受け取り、条件に合う結果を生成する
    • AF3でdenoisingの対象となるのは、すべての原子のx,y,z座標を含むmatrix xである
  • AF2のIPAの代わりに回転・並進拡張

    • AF3はAF2のInvariant Point Attentionを使わず、各timestepで予測中の複合体全体をランダムに回転・並進させる
    • この拡張により、どの回転や並進でも同じ構造として有効であることをモデルが学習し、AF2のIPAより単純なアプローチになっている
    • 回転は現在のgenerationの全原子座標の平均を中心に適用され、translationは各次元でN(0,1) Gaussianからサンプリングされる
    • 座標には小さなnoiseも加えられ、より多様なgenerationを誘導する
    • inferenceでは複数のgenerationをconfidence headでスコア化し、最も高いスコアのgenerationを返すことができる
  • Diffusion Moduleの4段階

    • 各denoising stepは複数のconditioning representationを使う
      • trunk出力 s_trunkz_trunk
      • input embedderで作られた初期表現 s_inputsc_inputs
    • diffusion過程はtoken空間とatom空間を行き来しながら4段階で構成される
        1. token-level conditioning tensorの準備
        1. atom-level conditioning tensorの準備、Atom Transformerの適用、token-levelへの集約
        1. token-level attentionの適用
        1. atom-level attentionで原子ごとのnoise updateを予測
    • token-level conditioningではz_trunkとrelative positional encodingを結合し、transition blockを通過させる
    • single representationにはs_inputss_trunkを結合し、diffusion timestepに応じたFourier embeddingを加える
    • atom-level段階では初期cpを現在のtoken-level representationで更新し、現在の座標xをdata varianceでスケーリングしてdimensionless coordinate rを作る
    • 最後のatom-level段階でlinear layerがqR^3にmappingし、すべての原子のcoordinate update r_updateを生成する
    • updateはdata varianceとnoise scheduleを考慮してx_updateに再スケーリングされた後、現在の座標x_lに適用される

損失関数とconfidence head

  • 全体のlossは3項の加重和である

L_loss = L_distogram * α_distogram + L_diffusion * α_diffusion + L_confidence * α_confidence

  • L_distogram

    • L_distogramは、token-levelにおける予測distogramの正確性を評価する
    • 原子座標からtoken座標を作る際には、各tokenのcenter atom座標を使用する
    • distogram distanceはcategorical valueとして扱われ、予測distogramと実際のdistogramをcross entropyで比較する
  • L_diffusion

    • L_diffusionは、atom positionを対象とした複数項の重み付き和である
    • L_MSEは、center atomではなく全原子についてposition間のmean squared errorを計算し、DNA、RNA、ligand原子はupweightされる
    • L_bondは、protein-ligand bondに含まれるatom pairのbond length精度を高めるための追加のMSE項である
    • 初期のtraining stageではα_bond=0であり、後から導入される
    • L_smooth_LDDTは、local distance accuracyを滑らかで微分可能にしたlossである
      • thresholdは4Å、2Å、1Å、0.5Åの4つを使用する
      • nucleotide原子対は30Åより遠い場合は無視する
      • proteinまたはligand原子対は15Åより遠い場合は無視する
  • L_confidence

    • L_confidenceは、構造精度を直接高めるというより、モデルが自分の予測の精度を推定できるように学習させる
    • 4種類のconfidence metricに対応するlossで構成される
      • pLDDT: 近接した原子に対するlocal distance accuracy
      • PAE: token pairのpredicted alignment error
      • PDE: token pair間のpredicted distance error
      • experimentally resolved prediction: 各原子が実験構造でresolvedされていたかどうかの予測
    • 予測構造が不正確でPAEが高くても、モデルがPAEも高いと予測すれば、そのPAE lossは低くなりうる
    • confidence predictionはdiffusionの中間段階で生成される
    • confidence lossのgradientはconfidence prediction headのみを更新し、モデルの他の部分には影響しない

追加の学習手法と効率化

  • Recycling

    • AF3はAF2と同様にweight recyclingを使用する
    • モデルをより深くする代わりに、同じweightを何度も再利用してrepresentationを段階的に改善する
    • diffusionもinferenceでtimestep情報を使用し、同じweightを各timestepで再利用するため、recyclingを内包している
  • Cross-distillation

    • AF3は、自身が作成したsynthetic training dataだけでなく、AF2とAF-Multimerが作成したsynthetic dataも使用する
    • diffusionベースのgenerationへ移行した後、AF2で低信頼・無秩序領域を視覚的に区別できていた「spaghetti」形状が消えるという問題があった
    • AF2およびAF-MultimerのgenerationをAF3のtraining dataに含めることで、AF2が確信を持てない領域ではunfolded regionを出力する方法をAF3に学習させる
    • distillation datasetでは、AF2とAF-Multimerが扱えない核酸と低分子が除去される
    • 以前のモデルが予測構造を作成した後で元データとalignmentし、除去していた分子を再び追加する
    • 再追加した分子がatom clashを引き起こす場合は構造全体を除外し、モデルがclashを許容するよう学習してしまうのを防ぐ
  • Croppingとtraining stage

    • モデル自体には入力配列長に対する明示的な制限はないが、多くの演算がN_tokens^3で増加するため、memoryとcomputeの要求が大きくなる
    • 効率化のため、タンパク質はrandom cropされる
    • 複数chain間のinteractionをモデル化する必要があるため、cropにはchain群をまとめて含める必要がある
    • 3つのcropping手法が用いられる
      • contiguous cropping: 各chainから連続したアミノ酸sequenceを選択
      • spatial cropping: 基準原子までの距離に基づいてアミノ酸を選択
      • spatial interface cropping: binding interfaceの原子までの距離に基づいて選択
    • random crop 384で学習したモデルも、より長いsequenceに適用可能だが、長いsequenceを処理する能力を高めるため、より大きいsequence lengthで反復的にfine-tuningする
  • Clashingとbatch size

    • AF3のlossには、重なり合う原子に対するclash penaltyは含まれない
    • diffusion-based structure moduleは理論上、2つの原子を同じ位置に予測できるが、学習後はその問題は小さい
    • 生成構造のrankingにはclashing penaltyが使用される
    • diffusion processは複雑に見えるが、trunkより計算コストが低い
    • 学習効率のため、trunk以降のbatch sizeを拡張する
    • 各input structureはembeddingとtrunkを1回通過し、その後48個の独立したdata-augmented structureが並列に学習される

MLの観点から見たAF3の設計

  • Retrieval-Augmented Generationに類似した構造

    • AF3のMSAとtemplate検索は、言語モデルのRAGと似た性質を持つ
    • AlphaFold分野では、構造templateを使う方式はRAGという用語よりも以前からhomology modelingとして使われていた
    • AF3はAF2よりMSA処理の比重を減らしたが、MSAとtemplateは依然として含んでいる
    • ESMFoldのような一部のタンパク質予測モデルは、retrievalを取り除き、fully parametric inferenceを使用する
  • Pair-Bias Attention

    • AF2の主要構成要素だったPair-Bias Attentionは、AF3ではより広く使われている
    • query、key、valueは同じsourceから来るが、attention mapには別のsourceから来たbias termが加えられる
    • これはfull cross-attentionより軽量な情報共有方式である
    • pair representationがattention mapと自然に似ているため、この構造はタンパク質モデリングによく適している可能性がある
  • Self-supervised trainingの縮小

    • ESM系列モデルは、self-supervised pre-trainingによってMSA embeddingを置き換える方式で強みを示していた
    • AF2にはMSAのmasked tokenを予測する追加taskがあったが、AF3では削除された
    • AF3はMSA処理のcomputeを減らし、MSAに対するself-supervised language modeling pre-trainingを使用していない
    • 考えられる理由としては、massive pre-trainingがcompute使用の面で非効率だったか、小さなMSA moduleがpre-trained embeddingより優れていたか、あるいはアミノ酸・DNA/RNA・ligandが混在するhybrid atom-token構造とpre-trained embeddingの組み合わせが合わなかった可能性がある
  • ClassificationとRegressionの混合

    • AF3はAF2と同様にMSEとbinned classification lossを併用する
    • distogram binを1つだけ外しても、大きく外した場合と同じようにcreditが与えられない点がclassification lossの特徴である
    • この設計選択の根拠は明確ではないが、複数のMSE lossよりgradientが安定していた可能性がある
  • recurrent architectureを思わせる要素

    • AF3には一般的なtransformerより、recurrent networkを連想させる要素が多い
    • gatingはresidual streamにおける情報の流れを制御し、LSTMやGRUのgateに似ている
    • recyclingとdiffusionは同じweightを繰り返し適用し、予測を段階的に改善する
    • adaptive compute timeと同様に、反復更新は難しい入力により多くの処理を適用できる構造と関係している
    • AF2のablationではrecyclingの重要性が示されたが、gatingの重要性についてはあまり議論されてこなかった

まだコメントはありません。

まだコメントはありません。