In this experiment, we will fine-tune a causal language model using QLoRA and the SFTTrainer from trl. We will use Llama3.1 as our large language model (LLM) and train it on the openassistant-guanaco dataset, but you can easily choose another model or dataset from the HuggingFace Hub.
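As a rough illustration of why QLoRA matters for a model of this size, here is a back-of-envelope estimate of weight memory alone (the 8B parameter count matches Llama-3.1-8B; real usage adds activations, LoRA adapter optimizer state, and framework overhead, so these are lower bounds, not measurements):

```python
# Back-of-envelope weight-memory estimate for an 8B-parameter model.
# Assumption: ~2 bytes/param in bf16, ~0.5 bytes/param with 4-bit (NF4) quantization.
PARAMS = 8e9

bf16_gb = PARAMS * 2 / 1e9   # unquantized bf16 base weights
nf4_gb = PARAMS * 0.5 / 1e9  # 4-bit quantized base weights (QLoRA)

print(f"bf16 weights: ~{bf16_gb:.0f} GB")   # ~16 GB
print(f"4-bit weights: ~{nf4_gb:.0f} GB")   # ~4 GB
```

This is why the quantized base model plus small LoRA adapters fits comfortably on accelerators where full fine-tuning would not.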

Connect to GitHub (if needed)

If you haven’t already connected FlexAI to GitHub, you’ll need to set up a code registry connection:
flexai code-registry connect
This allows FlexAI to pull repositories directly from GitHub via the --repository-url (-u) flag in training commands.

Preparing the Dataset

In this experiment, we will use a pre-processed version of the openassistant-guanaco dataset that has been set up for the Llama3.1 model.
DATASET_NAME=llama-tokenized-oag && curl -L -o ${DATASET_NAME}.zip "https://bucket-docs-samples-99b3a05.s3.eu-west-1.amazonaws.com/${DATASET_NAME}.zip" && unzip ${DATASET_NAME}.zip && rm ${DATASET_NAME}.zip
If you’d like to reproduce the pre-processing steps yourself to use a different dataset or simply to learn more about the process, you can refer to the Manual Dataset Pre-processing section below.
Next, push the contents of the llama-tokenized-oag/ directory as a new FlexAI dataset:
flexai dataset push llama-tokenized-oag --file llama-tokenized-oag

Create Secrets

To access the Llama-3.1-8B model, you need to accept its license with your HuggingFace account. Your code will then authenticate using your HuggingFace token. Use the flexai secret create command to store your HuggingFace token as a secret, replacing <HF_AUTH_TOKEN_SECRET_NAME> with your desired name for the secret:
flexai secret create <HF_AUTH_TOKEN_SECRET_NAME>
Then paste your HuggingFace token value.
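At training time, the --secret HF_TOKEN=... mapping in the training command exposes this secret to your job as the HF_TOKEN environment variable, which HuggingFace libraries pick up automatically. A minimal sketch of how code could read it explicitly (the helper name and error message are illustrative, not part of the blueprint):

```python
import os
from typing import Optional


def get_hf_token(env: Optional[dict] = None) -> str:
    """Return the HuggingFace token that FlexAI injects as HF_TOKEN."""
    env = os.environ if env is None else env
    token = env.get("HF_TOKEN")
    if not token:
        raise RuntimeError(
            "HF_TOKEN is not set; did you pass --secret HF_TOKEN=<secret-name>?"
        )
    return token
```

In practice you rarely need this: transformers and huggingface_hub read HF_TOKEN on their own, but failing fast with a clear message makes misconfigured jobs easier to debug.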

Training

To start the Training Job, run the following command:
flexai training run llama3-1-training-ddp --repository-url https://github.com/flexaihq/blueprints --requirements-path code/causal-language-modeling-qlora/requirements.txt \
  --dataset llama-tokenized-oag \
  --secret HF_TOKEN=<HF_AUTH_TOKEN_SECRET_NAME> --secret WANDB_API_KEY=<WANDB_API_KEY_SECRET_NAME> --env WANDB_PROJECT=<YOUR_PROJECT_NAME> \
  --nodes 1 --accels 2 \
  -- code/causal-language-modeling-qlora/train.py \
    --model_name_or_path meta-llama/Meta-Llama-3.1-8B \
    --dataset_name timdettmers/openassistant-guanaco \
    --tokenized_dataset_load_dir /input/llama-tokenized-oag \
    --dataset_text_field text \
    --load_in_4bit \
    --use_peft \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 2 \
    --output_dir /output-checkpoint \
    --log_level info
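With the flags above, the effective global batch size is the per-device batch size times the gradient-accumulation steps times the number of accelerators. A quick sanity check using the values from the command:

```python
# Values taken from the training command above.
per_device_train_batch_size = 4
gradient_accumulation_steps = 2
num_accelerators = 2  # --nodes 1 --accels 2

effective_batch = (
    per_device_train_batch_size * gradient_accumulation_steps * num_accelerators
)
print(effective_batch)  # 16 sequences per optimizer step
```

If you change --accels or either batching flag, keep this product in mind: it is the batch size that actually determines your learning-rate scaling and steps per epoch.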

Optional Extra Steps

Manual Dataset Pre-processing

If you’d prefer to perform the dataset pre-processing step yourself, you can follow these instructions. You can run them in a FlexAI Interactive Session or in a local environment (e.g. pipenv install --python 3.10), provided your hardware is capable of running inference.

Clone this repository

If you haven’t already, clone this repository on your host machine:
git clone https://github.com/flexaihq/blueprints.git blueprints --depth 1 --branch main && cd blueprints

Install the dependencies

Depending on your environment, you may need to install the experiment’s dependencies (if they aren’t already installed) by running:
pip install -r requirements.txt

Dataset preparation

Prepare the dataset by running the following command:
python dataset/prepare_save_sft.py \
  --model_name_or_path meta-llama/Meta-Llama-3.1-8B \
  --dataset_name timdettmers/openassistant-guanaco \
  --dataset_text_field text \
  --log_level info \
  --tokenized_dataset_save_dir llama-tokenized-oag \
  --output_dir ./.sft.output # This argument is not used but is required to use the SFT argument parser.
The prepared dataset will be saved to the llama-tokenized-oag/ directory.
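Conceptually, the preparation script tokenizes each record's text field and saves the result to disk, so that training can load it with datasets.load_from_disk and skip tokenization entirely. The sketch below mimics that map step with a toy whitespace "tokenizer" standing in for the real Llama tokenizer; the function and variable names are illustrative, not the script's actual internals:

```python
# Illustrative only: a toy stand-in for the tokenizer-based map step.
def toy_tokenize(example: dict, vocab: dict) -> dict:
    # The real script uses AutoTokenizer(...)(example["text"]) instead.
    ids = [vocab.setdefault(tok, len(vocab)) for tok in example["text"].split()]
    return {"input_ids": ids, "attention_mask": [1] * len(ids)}


vocab: dict = {}
record = {"text": "### Human: hi ### Assistant: hello"}
tokenized = toy_tokenize(record, vocab)
print(tokenized["input_ids"])  # one integer id per whitespace token
```

The saved directory then contains these input_ids/attention_mask columns, which is why train.py can set skip_prepare_dataset when --tokenized_dataset_load_dir is given.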

Code

code/causal-language-modeling-qlora/train.py

# Adapted from: https://github.com/huggingface/trl/blob/2cad48d511fab99ac0c4b327195523a575afcad3/examples/scripts/sft.py
# flake8: noqa
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
from dataclasses import dataclass, field
from typing import Optional

import datasets
from datasets import load_dataset
from transformers import AutoTokenizer
from trl import (
    ModelConfig,
    SFTConfig,
    SFTTrainer,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
from trl.commands.cli_utils import SFTScriptArguments, TrlParser

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(SCRIPT_DIR))

from utils.experiment_tracking import set_wandb


@dataclass
class AdditionalArguments:
    """
    Additional arguments that are not part of the TRL arguments.
    """

    tokenized_dataset_load_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Directory to load the tokenized dataset using the "
                "`datasets.load_from_disk` method. No tokenization will be done;"
                " the dataset will be loaded from this directory."
            )
        },
    )


if __name__ == "__main__":
    parser = TrlParser(
        (SFTScriptArguments, SFTConfig, ModelConfig, AdditionalArguments)
    )
    args, training_args, model_config, additional_args = parser.parse_args_and_config()
    set_wandb(training_args)
    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        trust_remote_code=model_config.trust_remote_code,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=model_config.torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    training_args.model_init_kwargs = model_kwargs
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path,
        trust_remote_code=model_config.trust_remote_code,
        use_fast=True,
    )
    tokenizer.pad_token = tokenizer.eos_token

    if additional_args.tokenized_dataset_load_dir:
        dataset = datasets.load_from_disk(additional_args.tokenized_dataset_load_dir)
        if training_args.dataset_kwargs is None:
            training_args.dataset_kwargs = {}
        # The dataset is already tokenized; tell SFTTrainer not to re-process it.
        training_args.dataset_kwargs["skip_prepare_dataset"] = True
    else:
        dataset = load_dataset(args.dataset_name)

    trainer = SFTTrainer(
        model=model_config.model_name_or_path,
        args=training_args,
        train_dataset=dataset[args.dataset_train_split],
        eval_dataset=dataset[args.dataset_test_split],
        tokenizer=tokenizer,
        peft_config=get_peft_config(model_config),
    )

    trainer.train()
    trainer.save_model(training_args.output_dir)

code/causal-language-modeling-qlora/requirements.txt

accelerate>=1.8.1
bitsandbytes>=0.43.2
datasets>=2.21.0
evaluate>=0.4.3
peft>=0.15.0
transformers>=4.43.3
trl==0.11.4
wandb>=0.18.1