Re:ゼロから始めるML生活

If I had to pick, I'm on Team Emilia

Fine-tuning Llama 3.1 on a custom dataset with torchtune

I've been tinkering with LLMs recently, but I'm still not confident writing and running finetuning code by hand. Training is heavy, requires a GPU, and costs a fair amount, so a careless bug really hurts. Ideally, I'd like to be able to get it done quickly with just a command.

That's when I came across a tool called torchtune. Apparently it lets you finetune an LLM just by writing a config file and running a command, so I decided to give it a try. This post is my notes from trying it out.

torchtune

The documentation is here:

pytorch.org

Its key features are described as follows:

torchtune is a PyTorch library for easily authoring, fine-tuning and experimenting with LLMs. The library emphasizes 4 key aspects:

Simplicity and Extensibility. Native-PyTorch, componentized design and easy-to-reuse abstractions

Correctness. High bar on proving the correctness of components and recipes

Stability. PyTorch just works. So should torchtune

Democratizing LLM fine-tuning. Works out-of-the-box on different hardware

In short, it's a library that makes it easy to run LLM finetuning: you describe the settings in a YAML file and run a single command, and sample YAMLs are provided, so in practice you just tweak one of those and use it.

install

The installation instructions are here:

pytorch.org

The commands are as follows:

# Install stable version of PyTorch libraries using pip
pip install torch torchvision torchao

# Install torchtune
pip install torchtune

# try tune
tune

Running them looks like this:

root@C.16123445:/$ pip install torch torchvision torchao
Requirement already satisfied: torch in /opt/conda/lib/python3.11/site-packages (2.5.1+cu124)
Requirement already satisfied: torchvision in /opt/conda/lib/python3.11/site-packages (0.20.1+cu124)
Collecting torchao
  Downloading torchao-0.8.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl.metadata (14 kB)
Requirement already satisfied: filelock in /opt/conda/lib/python3.11/site-packages (from torch) (3.16.1)
Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.11/site-packages (from torch) (4.12.2)
Requirement already satisfied: networkx in /opt/conda/lib/python3.11/site-packages (from torch) (3.4.2)
Requirement already satisfied: jinja2 in /opt/conda/lib/python3.11/site-packages (from torch) (3.1.4)
Requirement already satisfied: fsspec in /opt/conda/lib/python3.11/site-packages (from torch) (2024.10.0)
Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /opt/conda/lib/python3.11/site-packages (from torch) (12.4.127)
Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /opt/conda/lib/python3.11/site-packages (from torch) (12.4.127)
Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /opt/conda/lib/python3.11/site-packages (from torch) (12.4.127)
Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /opt/conda/lib/python3.11/site-packages (from torch) (9.1.0.70)
Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /opt/conda/lib/python3.11/site-packages (from torch) (12.4.5.8)
Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /opt/conda/lib/python3.11/site-packages (from torch) (11.2.1.3)
Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /opt/conda/lib/python3.11/site-packages (from torch) (10.3.5.147)
Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /opt/conda/lib/python3.11/site-packages (from torch) (11.6.1.9)
Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /opt/conda/lib/python3.11/site-packages (from torch) (12.3.1.170)
Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /opt/conda/lib/python3.11/site-packages (from torch) (2.21.5)
Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /opt/conda/lib/python3.11/site-packages (from torch) (12.4.127)
Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /opt/conda/lib/python3.11/site-packages (from torch) (12.4.127)
Requirement already satisfied: triton==3.1.0 in /opt/conda/lib/python3.11/site-packages (from torch) (3.1.0)
Requirement already satisfied: sympy==1.13.1 in /opt/conda/lib/python3.11/site-packages (from torch) (1.13.1)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /opt/conda/lib/python3.11/site-packages (from sympy==1.13.1->torch) (1.3.0)
Requirement already satisfied: numpy in /opt/conda/lib/python3.11/site-packages (from torchvision) (2.1.2)
Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /opt/conda/lib/python3.11/site-packages (from torchvision) (10.2.0)
Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.11/site-packages (from jinja2->torch) (3.0.2)
Downloading torchao-0.8.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl (4.7 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.7/4.7 MB 62.0 MB/s eta 0:00:00
Installing collected packages: torchao
Successfully installed torchao-0.8.0
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.
root@C.16123445:/$ pip install torchtune
Collecting torchtune
  Downloading torchtune-0.5.0-py3-none-any.whl.metadata (23 kB)
Collecting datasets (from torchtune)
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting huggingface_hub[hf_transfer] (from torchtune)
  Downloading huggingface_hub-0.27.1-py3-none-any.whl.metadata (13 kB)
Collecting safetensors (from torchtune)
  Downloading safetensors-0.5.2-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)
Collecting kagglehub (from torchtune)
  Downloading kagglehub-0.3.6-py3-none-any.whl.metadata (30 kB)
Collecting sentencepiece (from torchtune)
  Downloading sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.7 kB)
Collecting tiktoken (from torchtune)
  Downloading tiktoken-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.6 kB)
Collecting blobfile>=2 (from torchtune)
  Downloading blobfile-3.0.0-py3-none-any.whl.metadata (15 kB)
Requirement already satisfied: numpy in /opt/conda/lib/python3.11/site-packages (from torchtune) (2.1.2)
Requirement already satisfied: tqdm in /opt/conda/lib/python3.11/site-packages (from torchtune) (4.66.5)
Collecting omegaconf (from torchtune)
  Downloading omegaconf-2.3.0-py3-none-any.whl.metadata (3.9 kB)
Requirement already satisfied: psutil in /opt/conda/lib/python3.11/site-packages (from torchtune) (6.1.0)
Requirement already satisfied: Pillow>=9.4.0 in /opt/conda/lib/python3.11/site-packages (from torchtune) (10.2.0)
Collecting pycryptodomex>=3.8 (from blobfile>=2->torchtune)
  Downloading pycryptodomex-3.21.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.4 kB)
Requirement already satisfied: urllib3<3,>=1.25.3 in /opt/conda/lib/python3.11/site-packages (from blobfile>=2->torchtune) (2.2.3)
Collecting lxml>=4.9 (from blobfile>=2->torchtune)
  Downloading lxml-5.3.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (3.8 kB)
Requirement already satisfied: filelock>=3.0 in /opt/conda/lib/python3.11/site-packages (from blobfile>=2->torchtune) (3.16.1)
Collecting pyarrow>=15.0.0 (from datasets->torchtune)
  Downloading pyarrow-19.0.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets->torchtune)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting pandas (from datasets->torchtune)
  Downloading pandas-2.2.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (89 kB)
Requirement already satisfied: requests>=2.32.2 in /opt/conda/lib/python3.11/site-packages (from datasets->torchtune) (2.32.3)
Collecting xxhash (from datasets->torchtune)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets->torchtune)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets->torchtune)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Collecting aiohttp (from datasets->torchtune)
  Downloading aiohttp-3.11.11-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.7 kB)
Requirement already satisfied: packaging in /opt/conda/lib/python3.11/site-packages (from datasets->torchtune) (24.1)
Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.11/site-packages (from datasets->torchtune) (6.0.2)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/lib/python3.11/site-packages (from huggingface_hub[hf_transfer]->torchtune) (4.12.2)
Collecting hf-transfer>=0.1.4 (from huggingface_hub[hf_transfer]->torchtune)
  Downloading hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.7 kB)
Collecting antlr4-python3-runtime==4.9.* (from omegaconf->torchtune)
  Downloading antlr4-python3-runtime-4.9.3.tar.gz (117 kB)
  Preparing metadata (setup.py) ... done
Collecting regex>=2022.1.18 (from tiktoken->torchtune)
  Downloading regex-2024.11.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)
Collecting aiohappyeyeballs>=2.3.0 (from aiohttp->datasets->torchtune)
  Downloading aiohappyeyeballs-2.4.4-py3-none-any.whl.metadata (6.1 kB)
Collecting aiosignal>=1.1.2 (from aiohttp->datasets->torchtune)
  Downloading aiosignal-1.3.2-py2.py3-none-any.whl.metadata (3.8 kB)
Requirement already satisfied: attrs>=17.3.0 in /opt/conda/lib/python3.11/site-packages (from aiohttp->datasets->torchtune) (24.2.0)
Collecting frozenlist>=1.1.1 (from aiohttp->datasets->torchtune)
  Downloading frozenlist-1.5.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Collecting multidict<7.0,>=4.5 (from aiohttp->datasets->torchtune)
  Downloading multidict-6.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.0 kB)
Collecting propcache>=0.2.0 (from aiohttp->datasets->torchtune)
  Downloading propcache-0.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.2 kB)
Collecting yarl<2.0,>=1.17.0 (from aiohttp->datasets->torchtune)
  Downloading yarl-1.18.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (69 kB)
Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.11/site-packages (from requests>=2.32.2->datasets->torchtune) (3.4.0)
Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.11/site-packages (from requests>=2.32.2->datasets->torchtune) (3.10)
Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.11/site-packages (from requests>=2.32.2->datasets->torchtune) (2024.8.30)
Requirement already satisfied: python-dateutil>=2.8.2 in /opt/conda/lib/python3.11/site-packages (from pandas->datasets->torchtune) (2.9.0.post0)
Requirement already satisfied: pytz>=2020.1 in /opt/conda/lib/python3.11/site-packages (from pandas->datasets->torchtune) (2024.2)
Collecting tzdata>=2022.7 (from pandas->datasets->torchtune)
  Downloading tzdata-2024.2-py2.py3-none-any.whl.metadata (1.4 kB)
Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas->datasets->torchtune) (1.16.0)
Downloading torchtune-0.5.0-py3-none-any.whl (810 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 810.3/810.3 kB 34.6 MB/s eta 0:00:00
Downloading blobfile-3.0.0-py3-none-any.whl (75 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
Downloading kagglehub-0.3.6-py3-none-any.whl (51 kB)
Downloading omegaconf-2.3.0-py3-none-any.whl (79 kB)
Downloading safetensors-0.5.2-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (461 kB)
Downloading sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 66.4 MB/s eta 0:00:00
Downloading tiktoken-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.2/1.2 MB 81.4 MB/s eta 0:00:00
Downloading dill-0.3.8-py3-none-any.whl (116 kB)
Downloading fsspec-2024.9.0-py3-none-any.whl (179 kB)
Downloading aiohttp-3.11.11-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.7 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.7/1.7 MB 75.9 MB/s eta 0:00:00
Downloading hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.6 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.6/3.6 MB 92.0 MB/s eta 0:00:00
Downloading huggingface_hub-0.27.1-py3-none-any.whl (450 kB)
Downloading lxml-5.3.0-cp311-cp311-manylinux_2_28_x86_64.whl (5.0 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 5.0/5.0 MB 95.8 MB/s eta 0:00:00
Downloading multiprocess-0.70.16-py311-none-any.whl (143 kB)
Downloading pyarrow-19.0.0-cp311-cp311-manylinux_2_28_x86_64.whl (42.1 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 42.1/42.1 MB 103.5 MB/s eta 0:00:00
Downloading pycryptodomex-3.21.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.3 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.3/2.3 MB 88.8 MB/s eta 0:00:00
Downloading regex-2024.11.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (792 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 792.7/792.7 kB 72.3 MB/s eta 0:00:00
Downloading pandas-2.2.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.1 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.1/13.1 MB 99.6 MB/s eta 0:00:00
Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
Downloading aiohappyeyeballs-2.4.4-py3-none-any.whl (14 kB)
Downloading aiosignal-1.3.2-py2.py3-none-any.whl (7.6 kB)
Downloading frozenlist-1.5.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (274 kB)
Downloading multidict-6.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (129 kB)
Downloading propcache-0.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (231 kB)
Downloading tzdata-2024.2-py2.py3-none-any.whl (346 kB)
Downloading yarl-1.18.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (344 kB)
Building wheels for collected packages: antlr4-python3-runtime
  Building wheel for antlr4-python3-runtime (setup.py) ... done
  Created wheel for antlr4-python3-runtime: filename=antlr4_python3_runtime-4.9.3-py3-none-any.whl size=144554 sha256=e8a782328b03264b5537818566cf6f9612c56d14489ee28653824acbcf0cde69
  Stored in directory: /root/.cache/pip/wheels/1a/97/32/461f837398029ad76911109f07047fde1d7b661a147c7c56d1
Successfully built antlr4-python3-runtime
Installing collected packages: sentencepiece, antlr4-python3-runtime, xxhash, tzdata, safetensors, regex, pycryptodomex, pyarrow, propcache, omegaconf, multidict, lxml, hf-transfer, fsspec, frozenlist, dill, aiohappyeyeballs, yarl, tiktoken, pandas, multiprocess, kagglehub, huggingface_hub, blobfile, aiosignal, aiohttp, datasets, torchtune
  Attempting uninstall: fsspec
    Found existing installation: fsspec 2024.10.0
    Uninstalling fsspec-2024.10.0:
      Successfully uninstalled fsspec-2024.10.0
Successfully installed aiohappyeyeballs-2.4.4 aiohttp-3.11.11 aiosignal-1.3.2 antlr4-python3-runtime-4.9.3 blobfile-3.0.0 datasets-3.2.0 dill-0.3.8 frozenlist-1.5.0 fsspec-2024.9.0 hf-transfer-0.1.9 huggingface_hub-0.27.1 kagglehub-0.3.6 lxml-5.3.0 multidict-6.1.0 multiprocess-0.70.16 omegaconf-2.3.0 pandas-2.2.3 propcache-0.2.1 pyarrow-19.0.0 pycryptodomex-3.21.0 regex-2024.11.6 safetensors-0.5.2 sentencepiece-0.2.0 tiktoken-0.8.0 torchtune-0.5.0 tzdata-2024.2 xxhash-3.5.0 yarl-1.18.3
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.
root@C.16123445:/$ tune
usage: tune [-h] {download,ls,cp,run,validate} ...

Welcome to the torchtune CLI!

options:
  -h, --help            show this help message and exit

subcommands:
  {download,ls,cp,run,validate}
    download            Download a model from the Hugging Face Hub or Kaggle Model Hub.
    ls                  List all built-in recipes and configs
    cp                  Copy a built-in recipe or config to a local path.
    run                 Run a recipe. For distributed recipes, this supports all torchrun arguments.
    validate            Validate a config and ensure that it is well-formed.
root@C.16123445:/$ 

Installation itself is straightforward.

Basic usage

Since this is my first time using the tool, I'll follow the tutorial to check the basic usage.

First, the model to finetune can be downloaded with a command like this:

tune download meta-llama/Llama-2-7b-hf \
  --output-dir /tmp/Llama-2-7b-hf \
  --hf-token <ACCESS TOKEN>

Running the finetune looks like this; you really just run the command and the finetuning happens automatically:

tune run lora_finetune_single_device --config llama2/7B_lora_single_device epochs=1

The detailed settings for the finetune command above are written in a YAML file like this one:

github.com

See the YAML itself for details, but the dataset used for finetuning is defined as follows, and you can see it is set up to finetune on alpaca_cleaned_dataset.

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  packed: False  # True increases speed
seed: null
shuffle: True
batch_size: 2
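The `_component_` entries are dotted Python paths that torchtune resolves to a callable and instantiates with the sibling keys as keyword arguments. As a rough sketch of that idea (this is not torchtune's actual resolver, and a stdlib callable stands in for a real dataset component):

```python
import importlib

def instantiate(cfg: dict):
    """Resolve a '_component_' dotted path to a callable and invoke it
    with the remaining config keys as keyword arguments."""
    module_name, attr = cfg["_component_"].rsplit(".", 1)
    component = getattr(importlib.import_module(module_name), attr)
    kwargs = {k: v for k, v in cfg.items() if k != "_component_"}
    return component(**kwargs)

# Demo with a stdlib callable standing in for a torchtune component:
result = instantiate(
    {"_component_": "fractions.Fraction", "numerator": 1, "denominator": 3}
)
print(result)  # 1/3
```

In the real config, `torchtune.datasets.alpaca_cleaned_dataset` would be resolved the same way, with `packed=False` passed as an argument.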

The log should look something like this:

root@C.16123445:/workspace$ tune run lora_finetune_single_device --config llama2/7B_lora_single_device epochs=1
INFO:torchtune.utils._logging:Running LoRAFinetuneRecipeSingleDevice with resolved config:

batch_size: 2
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  adapter_checkpoint: null
  checkpoint_dir: /tmp/Llama-2-7b-hf
  checkpoint_files:
  - pytorch_model-00001-of-00002.bin
  - pytorch_model-00002-of-00002.bin
  model_type: LLAMA2
  output_dir: /tmp/torchtune/llama2_7B/lora_single_device
  recipe_checkpoint: null
compile: false
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  packed: false
device: cuda
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: false
epochs: 1
gradient_accumulation_steps: 8
log_every_n_steps: 1
log_peak_memory_stats: true
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100
max_steps_per_epoch: null
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: /tmp/torchtune/llama2_7B/lora_single_device/logs
model:
  _component_: torchtune.models.llama2.lora_llama2_7b
  apply_lora_to_mlp: true
  apply_lora_to_output: false
  lora_alpha: 16
  lora_attn_modules:
  - q_proj
  - v_proj
  - output_proj
  lora_dropout: 0.0
  lora_rank: 8
optimizer:
  _component_: torch.optim.AdamW
  fused: true
  lr: 0.0003
  weight_decay: 0.01
output_dir: /tmp/torchtune/llama2_7B/lora_single_device
profiler:
  _component_: torchtune.training.setup_torch_profiler
  active_steps: 2
  cpu: true
  cuda: true
  enabled: false
  num_cycles: 1
  output_dir: /tmp/torchtune/llama2_7B/lora_single_device/profiling_outputs
  profile_memory: false
  record_shapes: true
  wait_steps: 5
  warmup_steps: 5
  with_flops: false
  with_stack: false
resume_from_checkpoint: false
save_adapter_weights_only: false
seed: null
shuffle: true
tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  max_seq_len: null
  path: /tmp/Llama-2-7b-hf/tokenizer.model

DEBUG:torchtune.utils._logging:Setting manual seed to local seed 2311292768. Local seed is seed + rank = 2311292768 + 0
INFO:torchtune.utils._logging:Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. Enabling activation offloading should reduce memory further.
Writing logs to /tmp/torchtune/llama2_7B/lora_single_device/logs/log_1737033758.txt
INFO:torchtune.utils._logging:Model is initialized with precision torch.bfloat16.
INFO:torchtune.utils._logging:Memory stats after model init:
        GPU peak memory allocation: 13.03 GiB
        GPU peak memory reserved: 13.04 GiB
        GPU peak memory active: 13.03 GiB
INFO:torchtune.utils._logging:Tokenizer is initialized from file.
INFO:torchtune.utils._logging:Optimizer and loss are initialized.
INFO:torchtune.utils._logging:Loss is initialized.
README.md: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11.6k/11.6k [00:00<00:00, 104MB/s]
alpaca_data_cleaned.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 44.3M/44.3M [00:00<00:00, 100MB/s]
Generating train split: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 51760/51760 [00:00<00:00, 157965.71 examples/s]
INFO:torchtune.utils._logging:Dataset and Sampler are initialized.
INFO:torchtune.utils._logging:Learning rate scheduler is initialized.
WARNING:torchtune.utils._logging: Profiling disabled.
INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False}
1|7|Loss: 1.6794713735580444:   0%|▏                                                                                                            | 7/3235 [00:15<1:58:37,  2.20s/it]

Finetuning on a custom dataset

Now for the main topic: finetuning on a custom dataset. I'll give it a try with reference to this:

pytorch.org

The configuration is all done in YAML; apparently, any of the parameters listed at the link below can be written into the YAML, and they will be read and applied at run time.

pytorch.org

model download

First, download the model. This time I'll use Meta-Llama-3.1-8B-Instruct.

tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" --hf-token <HF_TOKEN>

system prompt

The system prompt for finetuning will look like this:

The following input consists of a prompt and two responses: model_a and model_b. Please select the more appropriate response to the prompt from either model_a or model_b.

dataset

As a test, I prepared a dataset like the following:

[
    {
        "conversations": [
            {
                "from": "human",
                "value": "model_a:  hoge \n model_b: fuga"
            },
            {
                "from": "gpt",
                "value": "model_a"
            }
        ]
    }
]
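A file in this format can be generated with a small Python sketch; the path matches the `data_files` value used in the config, and `hoge`/`fuga` are just placeholder responses:

```python
import json
import os

def make_record(human_text: str, label: str) -> dict:
    # One ShareGPT-style conversation: a human turn and the expected gpt turn
    return {
        "conversations": [
            {"from": "human", "value": human_text},
            {"from": "gpt", "value": label},
        ]
    }

# Placeholder examples (hoge/fuga stand in for real model responses)
records = [make_record("model_a:  hoge \n model_b: fuga", "model_a")]

os.makedirs("dataset", exist_ok=True)
path = os.path.join("dataset", "finetune_dataset_small.json")
with open(path, "w", encoding="utf-8") as f:
    json.dump(records, f, ensure_ascii=False, indent=4)

# Round-trip check: the file must load back as valid JSON
with open(path, encoding="utf-8") as f:
    loaded = json.load(f)
print(loaded[0]["conversations"][1]["value"])  # model_a
```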

Preparing the config

First, copy one of the built-in configs:

$ tune cp llama3_1/8B_lora_single_device custom_3_1_8B_lora_single_device.yaml
Copied file to custom_3_1_8B_lora_single_device.yaml

Next, edit this config file a little. I changed only the dataset section, like so (the edited file can also be checked with the `tune validate` subcommand before running):

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.chat_dataset
  source: json
  data_files: dataset/finetune_dataset_small.json
  new_system_prompt: "The following input consists of a prompt and two responses: model_a and model_b. Please select the more appropriate response to the prompt from either model_a or model_b."
  split: train
  conversation_column: conversations
  conversation_style: sharegpt

Running the finetune with this config works like this:

$ tune run lora_finetune_single_device --config config/custom_3_1_8B_lora_single_device.yaml epochs=1
INFO:torchtune.utils._logging:Running LoRAFinetuneRecipeSingleDevice with resolved config:

batch_size: 2
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  checkpoint_files:
  - model-00001-of-00004.safetensors
  - model-00002-of-00004.safetensors
  - model-00003-of-00004.safetensors
  - model-00004-of-00004.safetensors
  model_type: LLAMA3
  output_dir: /tmp/torchtune/llama3_1_8B/lora_single_device
  recipe_checkpoint: null
compile: false
dataset:
  _component_: torchtune.datasets.chat_dataset
  conversation_column: conversations
  conversation_style: sharegpt
  data_files: dataset/finetune_dataset_small.json
  new_system_prompt: 'The following input consists of a prompt and two responses:
    model_a and model_b. Please select the more appropriate response to the prompt
    from either model_a or model_b.'
  source: json
  split: train
device: cuda
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: false
epochs: 1
gradient_accumulation_steps: 8
log_every_n_steps: 1
log_peak_memory_stats: true
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100
max_steps_per_epoch: null
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: /tmp/torchtune/llama3_1_8B/lora_single_device/logs
model:
  _component_: torchtune.models.llama3_1.lora_llama3_1_8b
  apply_lora_to_mlp: true
  apply_lora_to_output: false
  lora_alpha: 16
  lora_attn_modules:
  - q_proj
  - v_proj
  - output_proj
  lora_dropout: 0.0
  lora_rank: 8
optimizer:
  _component_: torch.optim.AdamW
  fused: true
  lr: 0.0003
  weight_decay: 0.01
output_dir: /tmp/torchtune/llama3_1_8B/lora_single_device
profiler:
  _component_: torchtune.training.setup_torch_profiler
  active_steps: 2
  cpu: true
  cuda: true
  enabled: false
  num_cycles: 1
  output_dir: /tmp/torchtune/llama3_1_8B/lora_single_device/profiling_outputs
  profile_memory: false
  record_shapes: true
  wait_steps: 5
  warmup_steps: 3
  with_flops: false
  with_stack: false
resume_from_checkpoint: false
save_adapter_weights_only: false
seed: null
shuffle: true
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  max_seq_len: null
  path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model

DEBUG:torchtune.utils._logging:Setting manual seed to local seed 760940270. Local seed is seed + rank = 760940270 + 0
INFO:torchtune.utils._logging:Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. Enabling activation offloading should reduce memory further.
Writing logs to /tmp/torchtune/llama3_1_8B/lora_single_device/logs/log_1737039608.txt
INFO:torchtune.utils._logging:Model is initialized with precision torch.bfloat16.
INFO:torchtune.utils._logging:Memory stats after model init:
        GPU peak memory allocation: 15.06 GiB
        GPU peak memory reserved: 15.18 GiB
        GPU peak memory active: 15.06 GiB
INFO:torchtune.utils._logging:Tokenizer is initialized from file.
INFO:torchtune.utils._logging:Optimizer and loss are initialized.
INFO:torchtune.utils._logging:Loss is initialized.
INFO:torchtune.utils._logging:Dataset and Sampler are initialized.
INFO:torchtune.utils._logging:Learning rate scheduler is initialized.
WARNING:torchtune.utils._logging: Profiling disabled.
INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False}
0it [00:00, ?it/s]INFO:torchtune.utils._logging:Starting checkpoint save...

The saved model can then be loaded and used with reference to the following:

pytorch.org

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

#TODO: update it to your chosen epoch
trained_model_path = "/tmp/torchtune/llama3_1_8B/lora_single_device/epoch_0"

# Define the model and adapter paths
original_model_name = "/tmp/Meta-Llama-3.1-8B-Instruct"

model = AutoModelForCausalLM.from_pretrained(original_model_name)

# huggingface will look for adapter_model.safetensors and adapter_config.json
peft_model = PeftModel.from_pretrained(model, trained_model_path)

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(original_model_name)

# Function to generate text
def generate_text(model, tokenizer, prompt, max_length=50):
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_length=max_length)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

prompt = "<some prompt>"
print("Finetuned model output:", generate_text(peft_model, tokenizer, prompt))

With this, you can run the finetuned model.

Using vllm, it looks like this:

from vllm import LLM, SamplingParams

def print_outputs(outputs):
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    print("-" * 80)

#TODO: update it to your chosen epoch
llm = LLM(
    model="/tmp/torchtune/llama3_1_8B/lora_single_device/epoch_0",
    load_format="safetensors",
    kv_cache_dtype="auto",
)
sampling_params = SamplingParams(max_tokens=16, temperature=0.5)

prompt = "<some prompt>"  # placeholder; use the same format as the training data

conversation = [
    {"role": "system", "content": "The following input consists of a prompt and two responses: model_a and model_b. Please select the more appropriate response to the prompt from either model_a or model_b."},
    {"role": "user", "content": prompt},
]
outputs = llm.chat(conversation, sampling_params=sampling_params, use_tqdm=False)
print_outputs(outputs)

Impressions

That wraps up my notes on trying torchtune. I couldn't find any write-ups covering how to use a custom dataset, so I decided to write this one.

Since all you need to touch is the config, finetuning looks easy as long as you have a dataset; even a novice like me found it simple.