
# Fine-Tune Llama 3.1 with QLoRA (4-bit) on FlexAI

> Fine-tune Llama 3.1 using QLoRA 4-bit quantization on FlexAI. Cut memory requirements by 75%, launch in under 60s, managed checkpoints included.

In this experiment, we will fine-tune a causal language model using QLoRA and the `SFTTrainer` from `trl`.

We will use `Llama3.1` as our large language model (LLM) and train it on the `openassistant-guanaco` dataset, but you can easily substitute another model or dataset from the HuggingFace hub.
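As a rough, back-of-the-envelope illustration of why 4-bit quantization helps: storing each weight in 4 bits instead of 16 bits shrinks the memory needed for the weights alone by a factor of four. The sketch below assumes roughly 8 billion parameters and deliberately ignores quantization constants, activations, optimizer state, and the LoRA adapter weights, so actual usage during training will be higher.

```python
def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Approximate memory needed to store the weights alone, in GB (10^9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


NUM_PARAMS = 8e9  # assumption: roughly 8 billion parameters

fp16_gb = weight_memory_gb(NUM_PARAMS, 16)  # weights kept in 16-bit precision
int4_gb = weight_memory_gb(NUM_PARAMS, 4)   # weights quantized to 4 bits
reduction = 1 - int4_gb / fp16_gb

print(f"16-bit weights: {fp16_gb:.1f} GB")
print(f"4-bit weights:  {int4_gb:.1f} GB")
print(f"Reduction:      {reduction:.0%}")
```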

<Steps>
  <Step title="Connect to GitHub (if needed)">
    If you haven't already connected FlexAI to GitHub, you'll need to set up a code registry connection:

    ```bash theme={null}
    flexai code-registry connect
    ```

    This will allow FlexAI to pull repositories directly from GitHub using the `-u` flag in training commands.
  </Step>

  <Step title="Preparing the Dataset">
    In this experiment, we will use a pre-processed version of the `openassistant-guanaco` dataset that has been set up for the `Llama3.1` model.

    ```bash theme={null}
    DATASET_NAME=llama-tokenized-oag && \
      curl -L -o "${DATASET_NAME}.zip" "https://bucket-docs-samples-99b3a05.s3.eu-west-1.amazonaws.com/${DATASET_NAME}.zip" && \
      unzip "${DATASET_NAME}.zip" && \
      rm "${DATASET_NAME}.zip"
    ```

    > If you'd like to reproduce the pre-processing steps yourself to use a different dataset or simply to learn more about the process, you can refer to the [Manual Dataset Pre-processing](#manual-dataset-pre-processing) section below.

    Next, push the contents of the `llama-tokenized-oag/` directory as a new FlexAI dataset:

    ```bash theme={null}
    flexai dataset push llama-tokenized-oag --file llama-tokenized-oag
    ```
  </Step>
</Steps>

## Create Secrets

To access the Llama-3.1-8B model, you need to [accept the license](https://huggingface.co/meta-llama/Llama-3.1-8B) with your HuggingFace account.

To authenticate from within your code, you will use your *HuggingFace Token*.

Use the [`flexai secret create` command](https://docs.flex.ai/cli/commands/secret/) to store your *HuggingFace Token* as a secret. Replace `<HF_AUTH_TOKEN_SECRET_NAME>` with your desired name for the secret:

```bash theme={null}
flexai secret create <HF_AUTH_TOKEN_SECRET_NAME>
```

Then paste the value of your *HuggingFace Token*.
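The training command in the next section maps this secret with `--secret HF_TOKEN=<HF_AUTH_TOKEN_SECRET_NAME>`. Assuming the secret is then exposed to the job as an environment variable named `HF_TOKEN`, a script could fail fast with a clear message when the token is missing, rather than partway through a model download. The helper name `require_env` is hypothetical and not part of the blueprint code.

```python
import os


def require_env(name: str) -> str:
    """Return the value of an environment variable, or raise a clear error.

    Hypothetical helper: assumes FlexAI secrets reach the job as
    environment variables (e.g. HF_TOKEN).
    """
    value = os.environ.get(name, "").strip()
    if not value:
        raise RuntimeError(f"Environment variable {name} is not set or is empty.")
    return value
```

Calling `require_env("HF_TOKEN")` at the top of a training script would surface a missing secret immediately. Avoid printing the returned value, since job logs may be visible to others.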

## Training

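To start the Training Job, run the following command. Replace `<HF_AUTH_TOKEN_SECRET_NAME>` with the name of the secret created above, and `<WANDB_API_KEY_SECRET_NAME>` and `<YOUR_PROJECT_NAME>` with the name of the secret holding your Weights & Biases API key and your Weights & Biases project name: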

```bash theme={null}
flexai training run llama3-1-training-ddp \
  --repository-url https://github.com/flexaihq/blueprints \
  --requirements-path code/causal-language-modeling-qlora/requirements.txt \
  --dataset llama-tokenized-oag \
  --secret HF_TOKEN=<HF_AUTH_TOKEN_SECRET_NAME> \
  --secret WANDB_API_KEY=<WANDB_API_KEY_SECRET_NAME> \
  --env WANDB_PROJECT=<YOUR_PROJECT_NAME> \
  --nodes 1 --accels 2 \
  -- code/causal-language-modeling-qlora/train.py \
    --model_name_or_path meta-llama/Meta-Llama-3.1-8B \
    --dataset_name timdettmers/openassistant-guanaco \
    --tokenized_dataset_load_dir /input/llama-tokenized-oag \
    --dataset_text_field text \
    --load_in_4bit \
    --use_peft \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 2 \
    --output_dir /output-checkpoint \
    --log_level info
```
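The batch-related flags in the command above combine into an effective batch size per optimizer step. The sketch below works through that arithmetic under the assumption that the job runs data-parallel with one process per accelerator and that `--accels` counts accelerators per node; with `--nodes 1` that second assumption does not change the result.

```python
def effective_batch_size(
    per_device_batch_size: int,
    gradient_accumulation_steps: int,
    accels_per_node: int,
    nodes: int = 1,
) -> int:
    """Number of samples contributing to each optimizer step in data-parallel training."""
    return per_device_batch_size * gradient_accumulation_steps * accels_per_node * nodes


# Values taken from the training command above.
batch = effective_batch_size(
    per_device_batch_size=4,
    gradient_accumulation_steps=2,
    accels_per_node=2,
    nodes=1,
)
print(f"Effective batch size: {batch}")
```

If you change `--accels` or `--per_device_train_batch_size`, adjusting `--gradient_accumulation_steps` so that this product stays constant is one way to keep runs comparable.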

## Optional Extra Steps

### Manual Dataset Pre-processing

If you'd prefer to perform the dataset pre-processing step yourself, you can follow these instructions.

You can run these steps in a FlexAI Interactive Session or in a local environment (e.g. one created with `pipenv install --python 3.10`), provided your hardware is capable of running inference.

#### Clone the blueprints repository

If you haven't already, clone the `flexaihq/blueprints` repository on your host machine:

```bash theme={null}
git clone https://github.com/flexaihq/blueprints.git blueprints --depth 1 --branch main && cd blueprints
```

#### Install the dependencies

Depending on your environment, you may need to install the experiment's dependencies, if they are not installed already:

```bash theme={null}
pip install -r requirements.txt
```

#### Dataset preparation

Prepare the dataset by running the following command:

```bash theme={null}
python dataset/prepare_save_sft.py \
  --model_name_or_path meta-llama/Meta-Llama-3.1-8B \
  --dataset_name timdettmers/openassistant-guanaco \
  --dataset_text_field text \
  --log_level info \
  --tokenized_dataset_save_dir llama-tokenized-oag \
  --output_dir ./.sft.output # This argument is not used but is required to use the SFT argument parser.
```

The prepared dataset will be saved to the `llama-tokenized-oag/` directory.
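Before pushing the directory to FlexAI, it can be worth checking that it contains the splits the training script will index. The sketch below uses only the standard library and rests on two assumptions: that the directory was written by `DatasetDict.save_to_disk` (which records split names in a `dataset_dict.json` file), and that the training script uses the split names `train` and `test`. The function names are hypothetical.

```python
import json
from pathlib import Path


def list_saved_splits(dataset_dir: str) -> list[str]:
    """Return the split names recorded by DatasetDict.save_to_disk."""
    index_file = Path(dataset_dir) / "dataset_dict.json"
    if not index_file.is_file():
        raise FileNotFoundError(f"{index_file} not found; is this a saved DatasetDict?")
    with index_file.open(encoding="utf-8") as f:
        return list(json.load(f)["splits"])


def missing_splits(dataset_dir: str, required=("train", "test")) -> list[str]:
    """Return the required splits that are not listed or have no directory on disk."""
    saved = set(list_saved_splits(dataset_dir))
    root = Path(dataset_dir)
    return [s for s in required if s not in saved or not (root / s).is_dir()]
```

For example, `missing_splits("llama-tokenized-oag")` returning an empty list would indicate that both assumed splits are present.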

## Code

### `code/causal-language-modeling-qlora/train.py`

```python theme={null}
# Adapted from: https://github.com/huggingface/trl/blob/2cad48d511fab99ac0c4b327195523a575afcad3/examples/scripts/sft.py
# flake8: noqa
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
from dataclasses import dataclass, field
from typing import Optional

import datasets
from datasets import load_dataset
from transformers import AutoTokenizer
from trl import (
    ModelConfig,
    SFTConfig,
    SFTTrainer,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
from trl.commands.cli_utils import SFTScriptArguments, TrlParser

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(SCRIPT_DIR))

from utils.experiment_tracking import set_wandb


@dataclass
class AdditionalArguments:
    """
    Additional arguments that are not part of the TRL arguments.
    """

    tokenized_dataset_load_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Directory to load the tokenized dataset using the "
                "`datasets.load_from_disk` method. No tokenization will be done;"
                " the dataset will be loaded from this directory."
            )
        },
    )


if __name__ == "__main__":
    parser = TrlParser(
        (SFTScriptArguments, SFTConfig, ModelConfig, AdditionalArguments)
    )
    args, training_args, model_config, additional_args = parser.parse_args_and_config()
    set_wandb(training_args)
    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        trust_remote_code=model_config.trust_remote_code,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=model_config.torch_dtype,
        use_cache=not training_args.gradient_checkpointing,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    training_args.model_init_kwargs = model_kwargs
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path,
        trust_remote_code=model_config.trust_remote_code,
        use_fast=True,
    )
    tokenizer.pad_token = tokenizer.eos_token

    if additional_args.tokenized_dataset_load_dir:
        # Load a dataset that was already tokenized and saved to disk, and
        # tell SFTTrainer to skip its own dataset preparation step.
        dataset = datasets.load_from_disk(additional_args.tokenized_dataset_load_dir)
        if training_args.dataset_kwargs is None:
            training_args.dataset_kwargs = {}
        training_args.dataset_kwargs["skip_prepare_dataset"] = True
    else:
        # Otherwise, download the raw dataset from the HuggingFace hub.
        dataset = load_dataset(args.dataset_name)

    trainer = SFTTrainer(
        model=model_config.model_name_or_path,
        args=training_args,
        train_dataset=dataset[args.dataset_train_split],
        eval_dataset=dataset[args.dataset_test_split],
        tokenizer=tokenizer,
        peft_config=get_peft_config(model_config),
    )

    trainer.train()
    trainer.save_model(training_args.output_dir)
```
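To give a sense of how few parameters are actually trained when `--use_peft` is set, the sketch below counts the weights a LoRA adapter adds. All numbers are assumptions rather than values read from the script: a LoRA rank of 16, adapters on the `q_proj` and `v_proj` attention projections only, and Llama 3.1 8B shapes of 32 layers, a hidden size of 4096, and a key/value projection width of 1024. Different settings passed to `get_peft_config` would change the total.

```python
def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Parameters added by one LoRA adapter: A is (r x d_in) and B is (d_out x r)."""
    return r * (d_in + d_out)


# Assumed model shapes and LoRA settings (not read from the training script).
HIDDEN = 4096    # hidden size, also the q_proj output width
KV_DIM = 1024    # v_proj output width (8 key/value heads of size 128)
NUM_LAYERS = 32
RANK = 16

per_layer = lora_params(HIDDEN, HIDDEN, RANK) + lora_params(HIDDEN, KV_DIM, RANK)
total = per_layer * NUM_LAYERS

print(f"Trainable LoRA parameters: {total:,}")
print(f"Share of an ~8B-parameter model: {total / 8e9:.3%}")
```

Under these assumptions only a small fraction of a percent of the model's weights receive gradients, which is why the frozen base weights can stay quantized to 4 bits.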

### `code/causal-language-modeling-qlora/requirements.txt`

```text theme={null}
accelerate>=1.8.1
bitsandbytes>=0.43.2
datasets>=2.21.0
evaluate>=0.4.3
peft>=0.15.0
transformers>=4.43.3
trl==0.11.4
wandb>=0.18.1
```

