
I've been playing around with LLMs lately, but I'm still not confident writing and running LLM fine-tuning code myself. The workload is heavy and requires a GPU, so it costs a fair amount, and a silly bug stings all the more. So I started wishing I could just get it done quickly with a command.
That's when I found a tool called torchtune. Apparently it lets you fine-tune an LLM just by writing a command and a config file, so I decided to give it a try. This post is my notes from trying out torchtune.
torchtune
The documentation is here.
Its features are described as follows:
torchtune is a PyTorch library for easily authoring, fine-tuning and experimenting with LLMs. The library emphasizes 4 key aspects:
Simplicity and Extensibility. Native-PyTorch, componentized design and easy-to-reuse abstractions
Correctness. High bar on proving the correctness of components and recipes
Stability. PyTorch just works. So should torchtune
Democratizing LLM fine-tuning. Works out-of-the-box on different hardware
In short, it's a library that makes LLM fine-tuning easy to run. You write the configuration in YAML and just run a command; sample YAML configs are provided, so you edit one of those and go.
install
Per the docs, the install commands look like this:
# Install stable version of PyTorch libraries using pip
pip install torch torchvision torchao

# Install torchtune
pip install torchtune

# try tune
tune
Running them looks like this:
root@C.16123445:/$ pip install torch torchvision torchao
Requirement already satisfied: torch in /opt/conda/lib/python3.11/site-packages (2.5.1+cu124)
Requirement already satisfied: torchvision in /opt/conda/lib/python3.11/site-packages (0.20.1+cu124)
Collecting torchao
Downloading torchao-0.8.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl.metadata (14 kB)
Requirement already satisfied: filelock in /opt/conda/lib/python3.11/site-packages (from torch) (3.16.1)
Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.11/site-packages (from torch) (4.12.2)
Requirement already satisfied: networkx in /opt/conda/lib/python3.11/site-packages (from torch) (3.4.2)
Requirement already satisfied: jinja2 in /opt/conda/lib/python3.11/site-packages (from torch) (3.1.4)
Requirement already satisfied: fsspec in /opt/conda/lib/python3.11/site-packages (from torch) (2024.10.0)
Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /opt/conda/lib/python3.11/site-packages (from torch) (12.4.127)
Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /opt/conda/lib/python3.11/site-packages (from torch) (12.4.127)
Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /opt/conda/lib/python3.11/site-packages (from torch) (12.4.127)
Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /opt/conda/lib/python3.11/site-packages (from torch) (9.1.0.70)
Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /opt/conda/lib/python3.11/site-packages (from torch) (12.4.5.8)
Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /opt/conda/lib/python3.11/site-packages (from torch) (11.2.1.3)
Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /opt/conda/lib/python3.11/site-packages (from torch) (10.3.5.147)
Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /opt/conda/lib/python3.11/site-packages (from torch) (11.6.1.9)
Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /opt/conda/lib/python3.11/site-packages (from torch) (12.3.1.170)
Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /opt/conda/lib/python3.11/site-packages (from torch) (2.21.5)
Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /opt/conda/lib/python3.11/site-packages (from torch) (12.4.127)
Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /opt/conda/lib/python3.11/site-packages (from torch) (12.4.127)
Requirement already satisfied: triton==3.1.0 in /opt/conda/lib/python3.11/site-packages (from torch) (3.1.0)
Requirement already satisfied: sympy==1.13.1 in /opt/conda/lib/python3.11/site-packages (from torch) (1.13.1)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /opt/conda/lib/python3.11/site-packages (from sympy==1.13.1->torch) (1.3.0)
Requirement already satisfied: numpy in /opt/conda/lib/python3.11/site-packages (from torchvision) (2.1.2)
Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /opt/conda/lib/python3.11/site-packages (from torchvision) (10.2.0)
Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.11/site-packages (from jinja2->torch) (3.0.2)
Downloading torchao-0.8.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl (4.7 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.7/4.7 MB 62.0 MB/s eta 0:00:00
Installing collected packages: torchao
Successfully installed torchao-0.8.0
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.
root@C.16123445:/$ pip install torchtune
Collecting torchtune
Downloading torchtune-0.5.0-py3-none-any.whl.metadata (23 kB)
Collecting datasets (from torchtune)
Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting huggingface_hub[hf_transfer] (from torchtune)
Downloading huggingface_hub-0.27.1-py3-none-any.whl.metadata (13 kB)
Collecting safetensors (from torchtune)
Downloading safetensors-0.5.2-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)
Collecting kagglehub (from torchtune)
Downloading kagglehub-0.3.6-py3-none-any.whl.metadata (30 kB)
Collecting sentencepiece (from torchtune)
Downloading sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.7 kB)
Collecting tiktoken (from torchtune)
Downloading tiktoken-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.6 kB)
Collecting blobfile>=2 (from torchtune)
Downloading blobfile-3.0.0-py3-none-any.whl.metadata (15 kB)
Requirement already satisfied: numpy in /opt/conda/lib/python3.11/site-packages (from torchtune) (2.1.2)
Requirement already satisfied: tqdm in /opt/conda/lib/python3.11/site-packages (from torchtune) (4.66.5)
Collecting omegaconf (from torchtune)
Downloading omegaconf-2.3.0-py3-none-any.whl.metadata (3.9 kB)
Requirement already satisfied: psutil in /opt/conda/lib/python3.11/site-packages (from torchtune) (6.1.0)
Requirement already satisfied: Pillow>=9.4.0 in /opt/conda/lib/python3.11/site-packages (from torchtune) (10.2.0)
Collecting pycryptodomex>=3.8 (from blobfile>=2->torchtune)
Downloading pycryptodomex-3.21.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.4 kB)
Requirement already satisfied: urllib3<3,>=1.25.3 in /opt/conda/lib/python3.11/site-packages (from blobfile>=2->torchtune) (2.2.3)
Collecting lxml>=4.9 (from blobfile>=2->torchtune)
Downloading lxml-5.3.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (3.8 kB)
Requirement already satisfied: filelock>=3.0 in /opt/conda/lib/python3.11/site-packages (from blobfile>=2->torchtune) (3.16.1)
Collecting pyarrow>=15.0.0 (from datasets->torchtune)
Downloading pyarrow-19.0.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets->torchtune)
Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting pandas (from datasets->torchtune)
Downloading pandas-2.2.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (89 kB)
Requirement already satisfied: requests>=2.32.2 in /opt/conda/lib/python3.11/site-packages (from datasets->torchtune) (2.32.3)
Collecting xxhash (from datasets->torchtune)
Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets->torchtune)
Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets->torchtune)
Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Collecting aiohttp (from datasets->torchtune)
Downloading aiohttp-3.11.11-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.7 kB)
Requirement already satisfied: packaging in /opt/conda/lib/python3.11/site-packages (from datasets->torchtune) (24.1)
Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.11/site-packages (from datasets->torchtune) (6.0.2)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/lib/python3.11/site-packages (from huggingface_hub[hf_transfer]->torchtune) (4.12.2)
Collecting hf-transfer>=0.1.4 (from huggingface_hub[hf_transfer]->torchtune)
Downloading hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.7 kB)
Collecting antlr4-python3-runtime==4.9.* (from omegaconf->torchtune)
Downloading antlr4-python3-runtime-4.9.3.tar.gz (117 kB)
Preparing metadata (setup.py) ... done
Collecting regex>=2022.1.18 (from tiktoken->torchtune)
Downloading regex-2024.11.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)
Collecting aiohappyeyeballs>=2.3.0 (from aiohttp->datasets->torchtune)
Downloading aiohappyeyeballs-2.4.4-py3-none-any.whl.metadata (6.1 kB)
Collecting aiosignal>=1.1.2 (from aiohttp->datasets->torchtune)
Downloading aiosignal-1.3.2-py2.py3-none-any.whl.metadata (3.8 kB)
Requirement already satisfied: attrs>=17.3.0 in /opt/conda/lib/python3.11/site-packages (from aiohttp->datasets->torchtune) (24.2.0)
Collecting frozenlist>=1.1.1 (from aiohttp->datasets->torchtune)
Downloading frozenlist-1.5.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Collecting multidict<7.0,>=4.5 (from aiohttp->datasets->torchtune)
Downloading multidict-6.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.0 kB)
Collecting propcache>=0.2.0 (from aiohttp->datasets->torchtune)
Downloading propcache-0.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.2 kB)
Collecting yarl<2.0,>=1.17.0 (from aiohttp->datasets->torchtune)
Downloading yarl-1.18.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (69 kB)
Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.11/site-packages (from requests>=2.32.2->datasets->torchtune) (3.4.0)
Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.11/site-packages (from requests>=2.32.2->datasets->torchtune) (3.10)
Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.11/site-packages (from requests>=2.32.2->datasets->torchtune) (2024.8.30)
Requirement already satisfied: python-dateutil>=2.8.2 in /opt/conda/lib/python3.11/site-packages (from pandas->datasets->torchtune) (2.9.0.post0)
Requirement already satisfied: pytz>=2020.1 in /opt/conda/lib/python3.11/site-packages (from pandas->datasets->torchtune) (2024.2)
Collecting tzdata>=2022.7 (from pandas->datasets->torchtune)
Downloading tzdata-2024.2-py2.py3-none-any.whl.metadata (1.4 kB)
Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas->datasets->torchtune) (1.16.0)
Downloading torchtune-0.5.0-py3-none-any.whl (810 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 810.3/810.3 kB 34.6 MB/s eta 0:00:00
Downloading blobfile-3.0.0-py3-none-any.whl (75 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
Downloading kagglehub-0.3.6-py3-none-any.whl (51 kB)
Downloading omegaconf-2.3.0-py3-none-any.whl (79 kB)
Downloading safetensors-0.5.2-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (461 kB)
Downloading sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 66.4 MB/s eta 0:00:00
Downloading tiktoken-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.2/1.2 MB 81.4 MB/s eta 0:00:00
Downloading dill-0.3.8-py3-none-any.whl (116 kB)
Downloading fsspec-2024.9.0-py3-none-any.whl (179 kB)
Downloading aiohttp-3.11.11-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.7 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.7/1.7 MB 75.9 MB/s eta 0:00:00
Downloading hf_transfer-0.1.9-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.6 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.6/3.6 MB 92.0 MB/s eta 0:00:00
Downloading huggingface_hub-0.27.1-py3-none-any.whl (450 kB)
Downloading lxml-5.3.0-cp311-cp311-manylinux_2_28_x86_64.whl (5.0 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 5.0/5.0 MB 95.8 MB/s eta 0:00:00
Downloading multiprocess-0.70.16-py311-none-any.whl (143 kB)
Downloading pyarrow-19.0.0-cp311-cp311-manylinux_2_28_x86_64.whl (42.1 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 42.1/42.1 MB 103.5 MB/s eta 0:00:00
Downloading pycryptodomex-3.21.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.3 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.3/2.3 MB 88.8 MB/s eta 0:00:00
Downloading regex-2024.11.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (792 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 792.7/792.7 kB 72.3 MB/s eta 0:00:00
Downloading pandas-2.2.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.1 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.1/13.1 MB 99.6 MB/s eta 0:00:00
Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
Downloading aiohappyeyeballs-2.4.4-py3-none-any.whl (14 kB)
Downloading aiosignal-1.3.2-py2.py3-none-any.whl (7.6 kB)
Downloading frozenlist-1.5.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (274 kB)
Downloading multidict-6.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (129 kB)
Downloading propcache-0.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (231 kB)
Downloading tzdata-2024.2-py2.py3-none-any.whl (346 kB)
Downloading yarl-1.18.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (344 kB)
Building wheels for collected packages: antlr4-python3-runtime
Building wheel for antlr4-python3-runtime (setup.py) ... done
Created wheel for antlr4-python3-runtime: filename=antlr4_python3_runtime-4.9.3-py3-none-any.whl size=144554 sha256=e8a782328b03264b5537818566cf6f9612c56d14489ee28653824acbcf0cde69
Stored in directory: /root/.cache/pip/wheels/1a/97/32/461f837398029ad76911109f07047fde1d7b661a147c7c56d1
Successfully built antlr4-python3-runtime
Installing collected packages: sentencepiece, antlr4-python3-runtime, xxhash, tzdata, safetensors, regex, pycryptodomex, pyarrow, propcache, omegaconf, multidict, lxml, hf-transfer, fsspec, frozenlist, dill, aiohappyeyeballs, yarl, tiktoken, pandas, multiprocess, kagglehub, huggingface_hub, blobfile, aiosignal, aiohttp, datasets, torchtune
Attempting uninstall: fsspec
Found existing installation: fsspec 2024.10.0
Uninstalling fsspec-2024.10.0:
Successfully uninstalled fsspec-2024.10.0
Successfully installed aiohappyeyeballs-2.4.4 aiohttp-3.11.11 aiosignal-1.3.2 antlr4-python3-runtime-4.9.3 blobfile-3.0.0 datasets-3.2.0 dill-0.3.8 frozenlist-1.5.0 fsspec-2024.9.0 hf-transfer-0.1.9 huggingface_hub-0.27.1 kagglehub-0.3.6 lxml-5.3.0 multidict-6.1.0 multiprocess-0.70.16 omegaconf-2.3.0 pandas-2.2.3 propcache-0.2.1 pyarrow-19.0.0 pycryptodomex-3.21.0 regex-2024.11.6 safetensors-0.5.2 sentencepiece-0.2.0 tiktoken-0.8.0 torchtune-0.5.0 tzdata-2024.2 xxhash-3.5.0 yarl-1.18.3
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.
root@C.16123445:/$ tune
usage: tune [-h] {download,ls,cp,run,validate} ...
Welcome to the torchtune CLI!
options:
-h, --help show this help message and exit
subcommands:
{download,ls,cp,run,validate}
download Download a model from the Hugging Face Hub or Kaggle Model Hub.
ls List all built-in recipes and configs
cp Copy a built-in recipe or config to a local path.
run Run a recipe. For distributed recipes, this supports all torchrun arguments.
validate Validate a config and ensure that it is well-formed.
root@C.16123445:/$
Installation itself is straightforward.
Basic usage
Since this is my first time with the tool, I'll follow the tutorial to check the basic usage.
First, downloading the model to fine-tune is done with a command like this:
tune download meta-llama/Llama-2-7b-hf \
    --output-dir /tmp/Llama-2-7b-hf \
    --hf-token <ACCESS TOKEN>
Running the fine-tune looks like this; you really do just run the command and the fine-tuning happens on its own:
tune run lora_finetune_single_device --config llama2/7B_lora_single_device epochs=1
The detailed settings for the fine-tune run above are defined in a YAML file like the one below.
Take a look at the YAML itself for the details, but the dataset used for fine-tuning is defined like this, and you can see it's set up to fine-tune on alpaca_cleaned_dataset:
# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  packed: False  # True increases speed
seed: null
shuffle: True
batch_size: 2
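About that trailing `epochs=1` on the `tune run` command: anything passed as `key=value` overrides the corresponding YAML entry (torchtune depends on OmegaConf, which merges such dot-list overrides into the config, so nested keys like `dataset.packed=True` work too). As a rough stdlib-only sketch of the idea, not torchtune's actual implementation:

```python
# Rough illustration (NOT torchtune's real code) of how "key=value"
# command-line overrides like "epochs=1" get merged into a YAML config.

def apply_overrides(config: dict, overrides: list[str]) -> dict:
    """Merge dot-list overrides such as "dataset.packed=True" into a nested dict."""
    for item in overrides:
        key, _, raw = item.partition("=")
        # crude scalar parsing: try int, then float, then bool, else keep string
        try:
            value = int(raw)
        except ValueError:
            try:
                value = float(raw)
            except ValueError:
                value = {"true": True, "false": False}.get(raw.lower(), raw)
        node = config
        *parents, leaf = key.split(".")
        for p in parents:
            node = node.setdefault(p, {})
        node[leaf] = value
    return config

config = {"epochs": 3, "dataset": {"packed": False}}
apply_overrides(config, ["epochs=1", "dataset.packed=True"])
print(config)  # {'epochs': 1, 'dataset': {'packed': True}}
```

The real resolver is richer (interpolation, type checking against the recipe), but the mental model of "CLI flags win over YAML" carries over.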
The log should look something like this:
root@C.16123445:/workspace$ tune run lora_finetune_single_device --config llama2/7B_lora_single_device epochs=1
INFO:torchtune.utils._logging:Running LoRAFinetuneRecipeSingleDevice with resolved config:
batch_size: 2
checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
adapter_checkpoint: null
checkpoint_dir: /tmp/Llama-2-7b-hf
checkpoint_files:
- pytorch_model-00001-of-00002.bin
- pytorch_model-00002-of-00002.bin
model_type: LLAMA2
output_dir: /tmp/torchtune/llama2_7B/lora_single_device
recipe_checkpoint: null
compile: false
dataset:
_component_: torchtune.datasets.alpaca_cleaned_dataset
packed: false
device: cuda
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: false
epochs: 1
gradient_accumulation_steps: 8
log_every_n_steps: 1
log_peak_memory_stats: true
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
lr_scheduler:
_component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
num_warmup_steps: 100
max_steps_per_epoch: null
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: /tmp/torchtune/llama2_7B/lora_single_device/logs
model:
_component_: torchtune.models.llama2.lora_llama2_7b
apply_lora_to_mlp: true
apply_lora_to_output: false
lora_alpha: 16
lora_attn_modules:
- q_proj
- v_proj
- output_proj
lora_dropout: 0.0
lora_rank: 8
optimizer:
_component_: torch.optim.AdamW
fused: true
lr: 0.0003
weight_decay: 0.01
output_dir: /tmp/torchtune/llama2_7B/lora_single_device
profiler:
_component_: torchtune.training.setup_torch_profiler
active_steps: 2
cpu: true
cuda: true
enabled: false
num_cycles: 1
output_dir: /tmp/torchtune/llama2_7B/lora_single_device/profiling_outputs
profile_memory: false
record_shapes: true
wait_steps: 5
warmup_steps: 5
with_flops: false
with_stack: false
resume_from_checkpoint: false
save_adapter_weights_only: false
seed: null
shuffle: true
tokenizer:
_component_: torchtune.models.llama2.llama2_tokenizer
max_seq_len: null
path: /tmp/Llama-2-7b-hf/tokenizer.model
DEBUG:torchtune.utils._logging:Setting manual seed to local seed 2311292768. Local seed is seed + rank = 2311292768 + 0
INFO:torchtune.utils._logging:Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. Enabling activation offloading should reduce memory further.
Writing logs to /tmp/torchtune/llama2_7B/lora_single_device/logs/log_1737033758.txt
INFO:torchtune.utils._logging:Model is initialized with precision torch.bfloat16.
INFO:torchtune.utils._logging:Memory stats after model init:
GPU peak memory allocation: 13.03 GiB
GPU peak memory reserved: 13.04 GiB
GPU peak memory active: 13.03 GiB
INFO:torchtune.utils._logging:Tokenizer is initialized from file.
INFO:torchtune.utils._logging:Optimizer and loss are initialized.
INFO:torchtune.utils._logging:Loss is initialized.
README.md: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11.6k/11.6k [00:00<00:00, 104MB/s]
alpaca_data_cleaned.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 44.3M/44.3M [00:00<00:00, 100MB/s]
Generating train split: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 51760/51760 [00:00<00:00, 157965.71 examples/s]
INFO:torchtune.utils._logging:Dataset and Sampler are initialized.
INFO:torchtune.utils._logging:Learning rate scheduler is initialized.
WARNING:torchtune.utils._logging: Profiling disabled.
INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False}
1|7|Loss: 1.6794713735580444: 0%|▏ | 7/3235 [00:15<1:58:37, 2.20s/it]
Fine-tuning on a custom dataset
Now for the main topic: fine-tuning on a custom dataset. I'll give it a try using this part of the docs as a reference.
It's all configured in YAML: write any of the parameters listed at the link below into the YAML and torchtune will pick them up and run accordingly.
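For a feel of what the `_component_:` entries in these YAML files mean: the dotted path names a Python callable, which gets imported and called with the remaining keys as keyword arguments. A toy illustration of that idea (not torchtune's actual code; `datetime.timedelta` stands in for a real component so the sketch runs without torchtune installed):

```python
import importlib

def instantiate(node: dict):
    """Toy version of _component_ resolution: import the dotted path in
    _component_ and call it with the remaining keys as keyword arguments."""
    node = dict(node)  # don't mutate the caller's config
    module_name, _, attr = node.pop("_component_").rpartition(".")
    component = getattr(importlib.import_module(module_name), attr)
    return component(**node)

# Stand-in component; a torchtune config would name e.g.
# torchtune.datasets.chat_dataset here instead.
node = {"_component_": "datetime.timedelta", "days": 1, "hours": 2}
print(instantiate(node))  # 1 day, 2:00:00
```

This is why the parameter names in the YAML must match the signature of the component you point at.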
model download
First, download the model. This time I'll use Meta-Llama-3.1-8B-Instruct.
tune download meta-llama/Meta-Llama-3.1-8B-Instruct \
    --output-dir /tmp/Meta-Llama-3.1-8B-Instruct \
    --ignore-patterns "original/consolidated.00.pth" \
    --hf-token <HF_TOKEN>
system prompt
For the system prompt used during fine-tuning, I'll go with something like this:
The following input consists of a prompt and two responses: model_a and model_b. Please select the more appropriate response to the prompt from either model_a or model_b.
dataset
As a trial, I prepared a dataset along these lines:
[
  {
    "conversations": [
      {
        "from": "human",
        "value": "model_a: hoge \n model_b: fuga"
      },
      {
        "from": "gpt",
        "value": "model_a"
      }
    ]
  }
]
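If you generate this file from a script, writing out that structure takes only the standard library. A minimal sketch, assuming the `dataset/finetune_dataset_small.json` path referenced in the config; the record values here are placeholders:

```python
import json
from pathlib import Path

# ShareGPT-style records: one "human" turn and one "gpt" turn per example.
records = [
    {
        "conversations": [
            {"from": "human", "value": "model_a: hoge \n model_b: fuga"},
            {"from": "gpt", "value": "model_a"},
        ]
    },
]

out_path = Path("dataset/finetune_dataset_small.json")
out_path.parent.mkdir(parents=True, exist_ok=True)
out_path.write_text(json.dumps(records, ensure_ascii=False, indent=2))

# Round-trip check that the file parses back as expected
loaded = json.loads(out_path.read_text())
print(len(loaded), loaded[0]["conversations"][1]["value"])  # 1 model_a
```

Note that the file must be valid JSON (no trailing commas), since torchtune loads it via the Hugging Face `datasets` json loader.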
Preparing the config
First, copy one of the built-in configs:
$ tune cp llama3_1/8B_lora_single_device custom_3_1_8B_lora_single_device.yaml
Copied file to custom_3_1_8B_lora_single_device.yaml
Now edit this config file a little. I change only the dataset section, to something like this:
# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.chat_dataset
  source: json
  data_files: dataset/finetune_dataset_small.json
  new_system_prompt: "The following input consists of a prompt and two responses: model_a and model_b. Please select the more appropriate response to the prompt from either model_a or model_b."
  split: train
  conversation_column: conversations
  conversation_style: sharegpt
Running the fine-tune with this config looks like:
$ tune run lora_finetune_single_device --config config/custom_3_1_8B_lora_single_device.yaml epochs=1
INFO:torchtune.utils._logging:Running LoRAFinetuneRecipeSingleDevice with resolved config:
batch_size: 2
checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
checkpoint_files:
- model-00001-of-00004.safetensors
- model-00002-of-00004.safetensors
- model-00003-of-00004.safetensors
- model-00004-of-00004.safetensors
model_type: LLAMA3
output_dir: /tmp/torchtune/llama3_1_8B/lora_single_device
recipe_checkpoint: null
compile: false
dataset:
_component_: torchtune.datasets.chat_dataset
conversation_column: conversations
conversation_style: sharegpt
data_files: dataset/finetune_dataset_small.json
new_system_prompt: 'The following input consists of a prompt and two responses:
model_a and model_b. Please select the more appropriate response to the prompt
from either model_a or model_b.'
source: json
split: train
device: cuda
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: false
epochs: 1
gradient_accumulation_steps: 8
log_every_n_steps: 1
log_peak_memory_stats: true
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
lr_scheduler:
_component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
num_warmup_steps: 100
max_steps_per_epoch: null
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: /tmp/torchtune/llama3_1_8B/lora_single_device/logs
model:
_component_: torchtune.models.llama3_1.lora_llama3_1_8b
apply_lora_to_mlp: true
apply_lora_to_output: false
lora_alpha: 16
lora_attn_modules:
- q_proj
- v_proj
- output_proj
lora_dropout: 0.0
lora_rank: 8
optimizer:
_component_: torch.optim.AdamW
fused: true
lr: 0.0003
weight_decay: 0.01
output_dir: /tmp/torchtune/llama3_1_8B/lora_single_device
profiler:
_component_: torchtune.training.setup_torch_profiler
active_steps: 2
cpu: true
cuda: true
enabled: false
num_cycles: 1
output_dir: /tmp/torchtune/llama3_1_8B/lora_single_device/profiling_outputs
profile_memory: false
record_shapes: true
wait_steps: 5
warmup_steps: 3
with_flops: false
with_stack: false
resume_from_checkpoint: false
save_adapter_weights_only: false
seed: null
shuffle: true
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
max_seq_len: null
path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
DEBUG:torchtune.utils._logging:Setting manual seed to local seed 760940270. Local seed is seed + rank = 760940270 + 0
INFO:torchtune.utils._logging:Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. Enabling activation offloading should reduce memory further.
Writing logs to /tmp/torchtune/llama3_1_8B/lora_single_device/logs/log_1737039608.txt
INFO:torchtune.utils._logging:Model is initialized with precision torch.bfloat16.
INFO:torchtune.utils._logging:Memory stats after model init:
GPU peak memory allocation: 15.06 GiB
GPU peak memory reserved: 15.18 GiB
GPU peak memory active: 15.06 GiB
INFO:torchtune.utils._logging:Tokenizer is initialized from file.
INFO:torchtune.utils._logging:Optimizer and loss are initialized.
INFO:torchtune.utils._logging:Loss is initialized.
INFO:torchtune.utils._logging:Dataset and Sampler are initialized.
INFO:torchtune.utils._logging:Learning rate scheduler is initialized.
WARNING:torchtune.utils._logging: Profiling disabled.
INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False}
0it [00:00, ?it/s]INFO:torchtune.utils._logging:Starting checkpoint save...
The saved model can then be loaded and used like this, following the reference below:
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# TODO: update it to your chosen epoch
trained_model_path = "/tmp/torchtune/llama3_1_8B/lora_single_device/epoch_0"

# Define the model and adapter paths
original_model_name = "/tmp/Meta-Llama-3.1-8B-Instruct"

model = AutoModelForCausalLM.from_pretrained(original_model_name)

# huggingface will look for adapter_model.safetensors and adapter_config.json
peft_model = PeftModel.from_pretrained(model, trained_model_path)

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(original_model_name)

# Function to generate text
def generate_text(model, tokenizer, prompt, max_length=50):
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_length=max_length)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

prompt = "<your prompt here>"
print("Fine-tuned model output:", generate_text(peft_model, tokenizer, prompt))
That lets you run the fine-tuned model.
With vLLM it looks like this:
from vllm import LLM, SamplingParams

def print_outputs(outputs):
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    print("-" * 80)

# TODO: update it to your chosen epoch
llm = LLM(
    model="/tmp/torchtune/llama3_1_8B/lora_single_device/epoch_0",
    load_format="safetensors",
    kv_cache_dtype="auto",
)
sampling_params = SamplingParams(max_tokens=16, temperature=0.5)

prompt = "<your prompt here>"  # was left undefined in the original snippet
conversation = [
    {
        "role": "system",
        "content": "The following input consists of a prompt and two responses: model_a and model_b. Please select the more appropriate response to the prompt from either model_a or model_b.",
    },
    {"role": "user", "content": prompt},
]
outputs = llm.chat(conversation, sampling_params=sampling_params, use_tqdm=False)
print_outputs(outputs)
Impressions
That's my record of trying out torchtune. I couldn't find any write-ups covering the custom-dataset workflow, which is why I wrote this one.
Since it's just a matter of tweaking the config, fine-tuning seems easy as long as you have a dataset; even as a beginner I found it very approachable.