カスタムモデル最適化のためのLlama-2ファインチューニング事例研究
(anyscale.com)- 汎用LLMが過剰になりがちな特化タスクでは、Llama-2を直接ファインチューニングすることで、より小さく低コストなモデルで品質・コスト・レイテンシを同時に改善できる
- Llama-2 13Bはファインチューニング後、ViGGO関数表現の精度が58%→98%、SQL生成が42%→89%、GSM8kが28%→47%へ向上した
- ViGGOやSQL生成のように出力形式が重要なタスクでは、小型のLlama-2モデルがGPT-4より良い結果を示した一方、数学推論ではGPT-4水準には到達しなかった
- 実験はRay Train、Ray Data、DeepSpeed、Accelerateベースのスクリプトで実施され、7B・13Bは16xA10G、70Bは32xA10Gで学習した
- 性能向上の鍵はモデルサイズよりもデータ品質と評価パイプラインであり、プロンプトエンジニアリングとファインチューニングのコスト・品質 trade-off をタスクごとに比較する必要がある
3つのタスクで見たファインチューニング効果
- GPT-4やClaude-2のような大規模な汎用モデルは迅速なプロトタイピングに有用だが、サポートチケットの要約・分類のように範囲の狭い要求には、コストと性能の面で過剰になりうる
- 実験では、Llama-2モデルを3つの実務寄りタスクに合わせてフルパラメータファインチューニングしたときの改善幅を比較した
- ViGGO: 非構造テキストから関数表現を抽出
- SQL-create-context: 自然言語とCREATE TABLEコンテキストからSQLを生成
- GSM8k: 小学生レベルの数学問題を解く
- Llama-2 13Bにおける精度変化は次の通り
- ViGGO関数表現: 58% → 98%
- SQL生成: 42% → 89%
- GSM8k: 28% → 47%
- ViGGOとSQL生成では小型のLlama-2モデルがGPT-4より良い結果を出したが、GSM8kのような数学推論タスクでは、ファインチューニング後でもGPT-4の性能には及ばなかった
ファインチューニング方式と学習インフラ
- 3つのタスクすべてで標準的なフルパラメータファインチューニングを使用
- 次トークン予測方式で学習
- モデルの全パラメータが勾配更新の対象
- LoRAや一部transformer blockを固定する方式は実験対象外
- 実験スクリプトはRay Train、Ray Data、DeepSpeed、Accelerateの上に構築された
- Llama-2 7B、13B、70Bの実行をサポート
- Ray TrainのTorchTrainerが複数ワーカープロセスとGPUリソースに学習ループを分散する
- データシャーディングはRay Trainが処理し、各ワーカーは
session.get_dataset_shard("train")、session.get_dataset_shard("valid")で割り当てられたデータ断片にアクセスする
- モデルシャーディングはDeepSpeed ZeRO stage 3とoptimizer state offloadingで処理
- モデル断片が複数ワーカーに分散されているため、チェックポイント保存のように完全なモデルへのアクセスが必要な場合は
accelerator.unwrap_model(model)でモデルを展開する必要がある
- モデル断片が複数ワーカーに分散されているため、チェックポイント保存のように完全なモデルへのアクセスが必要な場合は
- 計算資源は次の通り
- 7B・13B: 16xA10G
- 70B: 32xA10G、4台の
g5.48xlargeインスタンス - Rayを使えば、フルパラメータファインチューニングに必ずしもA100は必要ない
- 学習は最大10 epochまで実行し、検証セットでperplexityが最も低いチェックポイントを選択した
特殊トークンで入力・出力構造を固定
- ファインチューニングデータでは、命令文プロンプトの代わりに特殊トークンでタスク構造を表現した
- 例:
<START_Q>{question}<END_Q><START_A>{answer}<END_A>
- 例:
- 特殊トークンは、モデルが入力区間と出力区間を区別し、出力停止位置を明確に学習するのに役立つ
- 例では
<END_A>をstopping tokenとして定義し、タスク完了時に出力を停止させる
- 例では
- Llama tokenizerはデフォルトで32,000個のトークンIDを出力する
- 4つの特殊トークンを追加すると32,004個のIDを出力する
<START_Q>には32000、<END_Q>には32001というように新しいIDが割り当てられる
- スクリプトでは
tokenizer.add_tokens(special_tokens, special_tokens=True)で特殊トークンを追加し、model.resize_token_embeddings(len(tokenizer))で新しい学習パラメータを作成する
ViGGO: 非構造テキストを関数表現に変換
- ViGGOは元々、属性-値ベースの関数表現を自然言語テキストへ変換する英語データセットだが、実験では方向を逆にして、非構造テキストを構造化された関数表現へ変換した
- ドメインはビデオゲームの意見
- 結果表現はインデキシングや後続アプリケーションに利用できる
- モデルは文に対応する関数と属性値を生成する必要がある
- 関数候補には
inform、request、give_opinion、confirm、verify_attribute、suggest、request_explanation、recommend、request_attributeが含まれる - 属性候補には
name、release_year、esrb、genres、platforms、available_on_steam、has_linux_release、has_mac_release、specifier、rating、player_perspective、has_multiplayer、developer、exp_release_dateなどが含まれる
- 関数候補には
- 入力例
What's a really fast-paced game with multiplayer that you like to play?の期待出力はrequest(has_multiplayer[yes], specifier[fast-paced]) - 汎用モデルは意図した出力形式にうまく従えず、入力コンテキストが長いため、出力生成より入力処理時間の方が大きくなる問題があった
- このタスクは複雑な論理推論よりもパターン認識と基本的な言語理解が中心
- 必要な事実がすべて入力に含まれている grounded task である
- few-shotプロンプトが有効であることは、小型Llama-2モデルもファインチューニングで改善できるシグナルとして扱える
ViGGOの評価と結果
- 評価では完全一致の文字比較だけは使わない
- 出力関数が正しいかを確認
- 属性タイプが正しいかを確認
- 関数内の属性が定められた優先順位順に従っているかを確認
- GPT、Llama-2-chatのようなinstruction-followingモデルには、プロンプト内で属性順序ルールを明示していたため、そのルールに従う条件で評価した
- 評価速度を高めるため、Rayのbatch inference APIとAnyscaleのAviaryを併用した
- LLM生成と後処理をつなぎ、複数マシンへ分散する
- 7Bと13Bモデルはファインチューニング後に精度が大きく向上した
- GPT-4は属性優先順位を評価に含めると精度が大きく低下した
- ファインチューニング済みモデルは常に優先順位に従い、この制約を追加しても精度は変わらなかった
- ViGGOの結果は、構造化形式が必要なタスクでファインチューニングが安定かつ効率的な手段になりうることを示している
- 単なるregexやJSON形式合わせではなく、含める引数の判断と、その引数の順序まで守る必要があるタスクである
- 7B・13Bモデルで得た結果なので、GPT-4 endpoint呼び出しよりサービングコストを抑えられる可能性がある
SQL生成: 自然言語とテーブル文脈からクエリを作る
- SQL生成タスクは、自然言語クエリとSQL
CREATE TABLE文を入力として受け取り、実行可能なSQLクエリを生成するもの - 使用データセットb-mc2/sql-create-contextは、WikiSQLとSpiderを組み合わせたHugging Faceデータセット
- 各データポイントは自然言語クエリ、SQL
CREATE TABLE文、対応するSQLクエリで構成される - 全体で78,577件のデータポイントがある
- 各データポイントは自然言語クエリ、SQL
- データセットには正解SQLの問題があった
CREATE TABLEでは整数属性がVARCHARと表示されているのに、SQLクエリでは整数として扱われるケースが多かった- 整数属性を前提としたSQLクエリをすべて除去し、データセットを約70kから45kへ縮小した
- このタスクも自然言語をSQLという構造化表現へ変換する問題であり、ファインチューニングに適している
- ViGGOと異なり、正しい実行結果を返すSQLが複数ありうるため、より曖昧さがある
SQLの評価と結果
- SQL生成の評価では単純な文字列比較は不適切
- 文字単位の比較はfalse negativeを多く生みうる
- AST比較も変数名の順序のような要素に敏感になりうる
- 最も信頼できる方法は、ダミーデータセット上でコードを実行し、出力が同じか比較すること
- 実験ではOpenAIのGPT-3.5 endpointを使い、数百の例に対する単体テスト用ダミーテーブルを生成した
- GPT-3.5が質問、テーブルスキーマ、正解を見て、10件のデータポイントを持つダミーテーブルを作成する
sqlglot.executor.executeで正解SQLとモデルSQLを実行し、結果を比較する
- GPT-3.5生成のデータテーブル品質を確認するため、まず正解SQLを実行した
- 結果テーブルが空、または元のテーブルと同じ長さなら、その例は破棄した
- この過程でGPTが作成したデータテーブルの約**50%**がフィルタリングされた
- ファインチューニングしたLlama-2 7Bと13Bは70B-chatやGPT-4より高い性能を示した
- Llama chatモデルの典型的なエラーは、プロンプト指示に反してSQLを
<SQL>タグ内へ一貫して入れないことだった - この問題は7B・13B chatモデルで70Bより頻繁に見られた
- Llama chatモデルの典型的なエラーは、プロンプト指示に反してSQLを
- SQLデータセットの一部自然言語クエリは完全な英語ではなく、こうしたノイズがGPT-4の結果に影響した可能性がある
- ファインチューニングモデルは、データセット固有の癖にも素早く適応した
GSM8k: 構造学習より難しい数学推論
- GSM8kは数学推論と理解能力を評価する標準的な学術ベンチマーク
- 前の2タスクが主に構造学習だったのに対し、GSM8kはモデルが数学問題を解くための推論過程をどこまで改善できるかを見るタスク
- 例題は、4月に48個売り、5月にその半分を売った場合の総販売数を問う形式で、正答は途中計算とともに
#### 72形式で終わる - 現在のLLMは、最終回答だけを内部で計算して即座に出すよりも、思考過程を出力の一部として生成することで、その後のトークン生成が論理的過程に基づくようにする必要がある
- このタスクでは単純計算だけでなく、前提から中間結論を経て最終答えへ至る論理的な chain of thought が必要となる
GSM8kの評価方法とベースライン
- 評価には、モデル出力から最終解答を安定して抽出する方法が必要
- 汎用言語モデルは望ましい出力形式に一貫して従えないことがあり、自動評価が難しい
- そのためOpenAI function calling APIを使用
gpt-3.5-turbo-0613が他モデルの生成結果から最終的な整数解答を抽出するため、report_answer関数を呼び出す- たとえばモデルが “The answer is four” と返しても、
4としてパースできる
- この方法はデータセット正答で検証して有効性を確認したが、評価にOpenAIのトークンコストがかかる欠点がある
- ファインチューニング済みモデルは目標の回答パターンをすぐに学習し、誤答時でも出力構造が予測しやすい
- ファインチューニングモデルの評価は
#### {answer}正規表現で処理し、OpenAI endpointによる後処理を避けた
- ファインチューニングモデルの評価は
- ベースラインは次の通り
- 論文で公開されたbase pre-trainedモデルの8-shot prompting結果
- MetaがRLHFで汎用assistant化したLlama-2 chat-tuned派生モデル向けの、複数のprompt-engineeredテンプレート
GSM8kの結果と2段階ファインチューニング
- baseモデルのファインチューニングはGSM8k性能を一貫して高めたが、chat-tunedモデルを常に大きく上回るわけではなかった
- chatモデルはchat-tuning過程で数学の例題を学習している可能性があり、baseモデルより精度が高かった
- ファインチューニングモデルにプロンプトを加える方式が、常にbaseモデルより良い結果になるわけではない
- たとえばLlama-2-70B-chatは、8-shot例示プロンプトを加えたbaseモデルより低いことがある
- ファインチューニングモデルは8-shot prompted baseモデルより一貫して優れていた
- サービングコストの面では、ファインチューニングモデルが有利になりうる
- プロンプトベースの方式では、リクエストごとにプロンプトトークンのコストが発生する
- ファインチューニングモデルでは、実質的に質問トークン数だけがコストに反映される
- GSM8kの学習データは約8k件と比較的小さく、Llama-13Bの潜在能力を十分に引き出せないと判断された
- Llama-13B baseモデルをまずMathQAでファインチューニングし、その後GSM8kで再度ファインチューニングする2段階方式が追加の改善をもたらした
- GSM8kのみのファインチューニングではbase比で10ポイント改善
- MathQA後にGSM8kを用いた2段階ファインチューニングでは、初回ファインチューニング結果からさらに10ポイント、base比で合計20ポイント改善した
- MathQAは30,000件の質問/回答ペアで構成されるが、GSM8kよりノイズが多く構造も異なる
- 回答品質は低く、最終解答はmultiple choice形式
- それでも2段階ファインチューニングは、MathQAを活用してGSM8kの最終結果を改善するのに有効だった
実務適用で見るべき基準
- GPT-4やClaude-2のようなクローズドモデルは、プロトタイピングや初期の価値検証には強いが、本番LLMアプリ運用で常に十分とは限らない
- niche task向けのLLMファインチューニングは、プライバシーだけでなくレイテンシ、コスト、品質の面でも価値がありうる
- ViGGOやSQLの例では、品質面でもGPT-4より良い結果が出ている
- ファインチューニングで重要なのは、インフラの細かな実装よりもデータ収集と評価パイプライン構築
- 評価パイプラインは、複数の解法の trade-off をビジネス要件に合わせて比較する基盤になる
- 実験はAnyscaleのファインチューニングおよびサービングプラットフォームとAnyscale Endpointsを用いて実施された
- 同じプロセスは、自前データと自前クラウド上で再現できるよう、Ray上のAnyscaleファインチューニングおよびサービングソリューションとして構成されている
1件のコメント
Hacker Newsのコメント
数週間前のコーディング配信で、独自データセットでLlama 2をファインチューニングする内容をかなり扱っていて、Colabの単一GPUで進めていた
私の場合、データセットは自分のコードだった。
Fine-tuning Llama stream: https://www.youtube.com/watch?v=TYgtG2Th6fI&t=2282s
QLoRAのファインチューニング配信もいくつかあり、ソフトウェアエンジニア8年目で最近機械学習に移って独学してきた立場から概念を説明している
QloRa fine-tuning stream: https://www.youtube.com/watch?v=LitybCiLhSc&t=4584s
個人プロジェクトや、現在進めているAIベースのスタートアップでどう取り組んでいるかを、できるだけわかりやすく解きほぐそうとしている。最小のWeb開発向けLLMをファインチューニングするシリーズも反応が良さそうで、配信は1か月ほど続けており、今後さらに増やす予定
ファインチューニングしたモデルを分けて持つ考え方もよくわからない。Terraform LLM、SQL LLM、Python LLMを別々に用意すべきなのか、それとも単に「コード」LLMが1つあればいいのか気になる
実装の細部が多すぎて、意味のある用途でない限り敷居が高い。privateGPTはゆっくりだがその地点に近づいている気はする
他のチュートリアルがかなり飛ばしがちな部分だ。特に、安全性や正確性など異なる目標に応じてどう準備するのかが気になる
Llama 2でも同じ問題に悩まされている。欲しいテキストだけを出力させるのがほぼ不可能で、いつも応答の前後に何か付け足してくる
この問題を直せるプロンプト手法があるのか知りたい
airoborosはバッククォートや説明などを避けてコードだけを出力させるPLAINFORMATトークンをサポートしている
https://huggingface.co/TheBloke/airoboros-l2-70B-GPT4-2.0-GG...
保証したいなら、小さなデータセット、だいたい1,000件程度でファインチューニングして、そこから改善していくのがいちばん良い
私の用途は創作ライティングではなく、テキストから情報を抽出・要約する単純な作業だった。ベースモデルがすべての作業に合うとは限らない
content文字列やJSONの中に出力するようにプロンプトすればよいJSONなら開始と終了を識別できるので、JSON外の内容は削除できる
こういう記事が出てきてうれしい。オンラインではモデルのカスタマイズに関する議論が多すぎたが、この記事はノイズをかなりうまく取り除いている
評価方法論も気に入ったし、文章もよく書けているように見える
LoRAと量子化学習がもっと真面目に扱われていないのは不思議だ。はるかに安く、時間もかからず、かなり良いという証拠も多い
後から試す付加オプションのように後回しにする対象ではないと思う
NERに近いタスクが最も良い性能を出していたのを見てうれしい。ファインチューニングしたBERTモデルと比較するため、ちょうど似たようなテストをやろうとしていたところだった
このタスクの学習コストがどの程度なのか気になる
ブロックサイズは下げることもできるが、コードを変えないほうが簡単だったのでそのままにした。7Bは16xA10Gで1エポックあたり約15分、13Bは約25分かかった。したがってオンデマンド費用は1エポックあたり7Bが約**$7.2**、13Bが約**$12**。この値は学習に使った時間だけを基準にしており、クラスターの起動・停止時間は含まない
7Bと13Bには16xA10G、70Bには32xA10Gを4つのg5.48xlargeインスタンスに分散して使ったとある。Rayを使えば、この種のモデルのフルパラメータファインチューニングのためにA100を確保する必要がなく、同じプロセスを各タスクで繰り返している。GSM8kデータセットでは、コンテキスト長512、1エポックあたり有効トークン370万個を基準にした実行例を示している
最大10エポックまで学習し、検証セットで最小パープレキシティを示したチェックポイントを選択したとのこと
ひとつ難しいのは、十分に大きいカスタムデータセットを作るには、小さな軍隊のような人員か、非常に強力な既存モデルが必要だという点だ
結局OpenAIを使う可能性が高いが、OpenAIで別モデルの学習データを生成するのは利用規約違反だ。これをめぐって訴訟にまで至ったことがあるのか気になる。単に不公正だとして無視しているのだろうか?
最近NERの例をよく見るが、そういうタスクでなぜspaCyを使わないのか気になる
Anyscaleで働いている
このブログはかなり良い反響を得たようなので、Ray Summitに入れる予定だ: https://raysummit.anyscale.com/agenda
Ray Summitでどんな種類のコンテンツをもっと見たいか、アイデアがあれば教えてほしい
350万トークン基準で7Bは1エポック約14分、13Bは1エポック約26分と書かれている
7Bと13Bのどちらも、ヘッドノードとして最低1xg5.16xlarge、ワーカーノードとして15xg5.4xlargeが必要とのことだが、AWSでは費用がどの程度になるのか気になる
us-east-1で動かすなら、時間あたり約$30と見てよい
https://instances.vantage.sh/?selected=g5.16xlarge,g5.4xlarg...
M1 Ultra 64GBでLlama-2をローカルファインチューニングできるのか気になる。ほとんどがクラウドか、LinuxでNvidia CUDAを使う方法ばかりなので、参考になる資料があるとうれしい
学習はRunPodのクレジットを少し買ってやるつもりで、数十ドルもあればいけそうだ