Consistency LLM: LLMを並列デコーダに変換して推論速度を3.5倍向上
(hao-ai-lab.github.io)Consistency Large Language Models (CLLM)
- LLMは従来、1つのトークンを連続的にデコードする方式で動作するシーケンシャルデコーダと見なされてきた
- 本ブログでは、事前学習済みLLMが効率的な並列デコーダへ容易に変換できることを示している
- Consistency Large Language Models(CLLM)は、推論ステップごとにn個のトークン系列を効率的にデコードし、推論レイテンシを低減できる新しい並列デコーダ群である
- CLLMは、人が単語を1つずつ話す前に頭の中で完全な文を組み立てる認知過程を模倣しており、事前学習済みLLMをファインチューニングするだけで効果的に学習できる
- CLLMは、ランダムに初期化されたn個のトークン系列を、可能な限り少ないステップで自己回帰(autoregressive, AR)デコーディング結果と同一に写像するよう並列デコーディングを行う方式で学習される
- 実験結果によれば、CLLMは既存のARデコーダと比べて2.4倍から3.4倍の生成速度向上を示し、Medusa2やEagleのような高速推論手法と同等またはそれ以上の性能を示した
- CLLMは、追加のメモリコストなしでこうした性能向上を達成できる
Jacobiデコーディングの背景と限界
- LLMはARデコーディング方式でトークンを1つずつ生成するため、長い応答では高いレイテンシが発生する
- JacobiデコーディングはJacobiおよびGauss-Seidelの非線形方程式解法に由来し、greedy samplingを用いたAR生成と等価であることが証明されている
- Jacobiデコーディングは、逐次生成過程をJacobi反復に基づくn変数のn本の非線形方程式系として再構成し、並列処理を可能にする
- 各反復ステップでは1個以上の正しいトークンを予測できるため、ARデコーディングを潜在的に高速化できる
- しかし実際には、ARで学習されたLLMは先行トークンに誤りがある場合、正しいトークンをほとんど生成できないため、ほとんどのJacobi反復ではn個のトークン系列に対して1つの修正しか得られず、より長いJacobi軌道(trajectory)を生んでしまう
- Lookaheadデコーディングやspeculativeデコーディングは、こうしたJacobiデコーディングの非効率性を緩和しようとするが、推論時に追加のメモリコストが発生する。一方、CLLMにはそれがない
CLLMの学習方法
- CLLMの学習は大きく、Jacobi軌道の準備と、consistency損失およびAR損失の最適化という2つの部分で構成される
- Jacobi軌道の準備段階では、応答全体の系列l個のトークンが生成されるまで、n個ずつ区切って順次Jacobiデコーディングを行い、各軌道で生成された系列を1つのデータ項目と見なす
- 学習時にはconsistency損失とAR損失を同時に最適化する。consistency損失は複数トークンを一度に予測することを保証し、AR損失はCLLMが対象LLMから逸脱しないようにして生成品質を維持する
- Global consistency(GC)損失は、Jacobi軌道上の任意の点と固定点の距離を最小化することで、CLLMがJacobi軌道のどの時点からでも固定点を予測するよう促す
- Local consistency(LC)損失は、Jacobi軌道の隣接する状態が同じ出力を出すよう誘導する
- AR損失は対象LLMの生成結果に基づく従来のAR損失を含み、その目的はCLLMが対象LLMの分布から外れないようにすることにある
実験結果
- 実験には、Spider(text-to-SQL)、Human-Eval(Pythonコード補完)、GSM8k(数学)などの特化ドメインタスクと、MT-benchのような幅広いオープンエンド対話チャレンジが含まれた
- CLLMは対象モデルに対して最も大きな速度向上を示し、推論時の追加コストなしでMedusa2と同等またはそれ以上の高速化を達成した
- MT-benchでは、CLLMはMedusa2と組み合わせた場合とほぼ同等の速度向上を達成したが、より高い適応性とメモリ効率を提供した
- CLLMのファインチューニングコストは適度な水準であり、データセット規模が大きい場合でも、Jacobi軌道生成にデータセットの約10%だけを使っても約2.5倍の速度向上が得られた
- CLLMはfast forwarding現象を通じて、複数の連続トークンを1回のJacobi反復で正しく予測できる
- CLLMは、先行トークンに誤りがあっても正しいトークンを先回りして予測し、変更されないまま維持するstationaryトークン能力を示す
- CLLMは学習を通じて、コロケーション(collocation)のような中核的な言語概念を習得し、これによりJacobi軌道のどの時点からでも構造を推測し、反復ステップを最小化するために複数の単語を同時に予測できるようになる
GN⁺の意見
-
CLLMは、既存LLMのARデコーディング方式が抱える長いレイテンシの問題を、Jacobiデコーディングを活用して効果的に解決したものに見える。特に、追加のメモリコストなしで並列化デコーディングにより高速化を達成した点は印象的だ
-
CLLMの学習方法は、既存LLMをconsistency損失によってファインチューニングする比較的シンプルなものに見えるが、それによって言語の重要な特性の1つであるコロケーション(collocation)を学習し、並列デコーディング性能を大きく向上させた点に意義があるように思える
-
ただし、CLLMはgreedy samplingを前提としているため、より多様なデコーディング戦略でもうまく動作するかどうかは追加研究が必要に思える。また、現時点では英語に限定された実験結果であり、多言語への一般化可能性についても検証が必要だろう
-
CLLMは、LLMの応答速度を高速化する方法として実用的なアプローチだと思われる。Web検索やチャットボットなど、リアルタイム性が求められるタスクにうまく適用できそうだ
-
個人的には、CLLMのconsistency学習方式がGPTなどのLLMだけでなく、画像生成モデルや音声合成モデルなど他の生成モデルにも適用できるのではないかと期待している。CLLMのアイデアが今後、さまざまな生成モデルの効率向上に寄与することを願う
1件のコメント
Hacker Newsの意見