This experiment is temporarily disabled.
This experiment shows how to use FlexAI to run a Training Job that uses Flash Attention, through the flash-attention package, with a couple of commands. As an example, we will train a causal language model (GPT-2) on the wikitext dataset. The process requires only two components: a training script and a dataset. The training script defines the model, sets up and applies hyperparameters, runs the training loop, and applies the evaluation logic; the dataset contains the data used to train the model.
1. Connect to GitHub (if needed)

If you haven’t already connected FlexAI to GitHub, you’ll need to set up a code registry connection:
flexai code-registry connect
This will allow FlexAI to pull repositories directly from GitHub using the -u flag in training commands.
2. Preparing the Dataset

In this experiment, we will use a pre-processed version of the wikitext dataset that has been set up for the GPT-2 model.
If you’d like to reproduce the pre-processing steps yourself to use a different dataset or simply to learn more about the process, you can refer to the Manual Dataset Pre-processing section below.
  1. Download the dataset:
    DATASET_NAME=gpt2-tokenized-wikitext && curl -L -o ${DATASET_NAME}.zip "https://bucket-docs-samples-99b3a05.s3.eu-west-1.amazonaws.com/${DATASET_NAME}.zip" && unzip ${DATASET_NAME}.zip && rm ${DATASET_NAME}.zip
    
  2. Upload the dataset (located in gpt2-tokenized-wikitext/) to FlexAI Storage as a new dataset:
    flexai dataset push gpt2-tokenized-wikitext --file gpt2-tokenized-wikitext
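Before pushing, it can help to confirm that the unzip step actually produced a non-empty directory. The following is a minimal sketch using only the Python standard library; the helper name `summarize_dataset_dir` is hypothetical and not part of FlexAI, and it makes no assumption about the file layout inside the archive.

```python
from pathlib import Path


def summarize_dataset_dir(path):
    """Return (file_count, total_bytes) for every regular file under `path`."""
    root = Path(path)
    if not root.is_dir():
        raise FileNotFoundError(f"{root} is not a directory; did the unzip step succeed?")
    files = [p for p in root.rglob("*") if p.is_file()]
    return len(files), sum(p.stat().st_size for p in files)


# Example (run from the directory where the archive was unzipped):
#   count, size = summarize_dataset_dir("gpt2-tokenized-wikitext")
#   print(f"{count} files, {size / 1e6:.1f} MB")
```

A file count of zero, or a `FileNotFoundError`, suggests the download or unzip step should be repeated before running `flexai dataset push`.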
    
3. Train the Model

Now it’s time to train your LLM on the dataset you pushed in the previous step, gpt2-tokenized-wikitext. This experiment uses the GPT-2 model. However, the training script relies on the HuggingFace Transformers Trainer class, which makes it easy to replace GPT-2 with another model compatible with flash-attention.
To start the Training Job, run the following command:
flexai training run flexai-experiments-flash-attention --repository-url https://github.com/flexaihq/blueprints --dataset gpt2-tokenized-wikitext --requirements-path code/causal-language-modeling/requirements-flash-attn.txt \
 -- code/causal-language-modeling/train.py \
    --do_eval \
    --do_train \
    --dataset_name wikitext \
    --tokenized_dataset_load_dir /input/gpt2-tokenized-wikitext \
    --model_name_or_path openai-community/gpt2 \
    --output_dir /output-checkpoint \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --logging_steps 50 \
    --save_steps 500 \
    --eval_steps 500 \
    --attn_implementation flash_attention_2 \
    --torch_dtype float16 \
    --eval_strategy steps
The first line defines the three main components required to run a Training Job on FlexAI:
  1. The Training Job’s name (flexai-experiments-flash-attention).
  2. The URL of the repository containing the training script (https://github.com/flexaihq/blueprints).
  3. The name of the dataset to be used (gpt2-tokenized-wikitext).
It also passes --requirements-path, which points to the requirements file used for this experiment (code/causal-language-modeling/requirements-flash-attn.txt).
The second line defines the script that will be executed when the Training Job is started (code/causal-language-modeling/train.py).
After the second line come the script’s arguments, which are passed to the script when it is executed to adjust the Training Job hyperparameters or customize its behavior. For instance, --max_train_samples and --max_eval_samples can be used to cap the number of training and evaluation samples.
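The structure of the command (FlexAI options, then `--`, then the script and its arguments) can be sketched as a small Python helper that assembles the same argument list. This is only an illustration: `build_training_command` is a hypothetical helper, it uses only the flags that appear in the command above, and the values given for `max_train_samples` and `max_eval_samples` are made up for the example.

```python
import shlex


def build_training_command(job_name, repository_url, dataset, requirements_path,
                           script, script_args):
    """Assemble the argv list for `flexai training run` as used in this guide."""
    command = [
        "flexai", "training", "run", job_name,
        "--repository-url", repository_url,
        "--dataset", dataset,
        "--requirements-path", requirements_path,
        "--",  # everything after this separator is the script and its arguments
        script,
    ]
    for flag, value in script_args.items():
        command.append(f"--{flag}")
        # Boolean switches such as --do_train are passed without a value.
        if value is not True:
            command.append(str(value))
    return command


command = build_training_command(
    job_name="flexai-experiments-flash-attention",
    repository_url="https://github.com/flexaihq/blueprints",
    dataset="gpt2-tokenized-wikitext",
    requirements_path="code/causal-language-modeling/requirements-flash-attn.txt",
    script="code/causal-language-modeling/train.py",
    script_args={
        "do_train": True,
        "do_eval": True,
        "dataset_name": "wikitext",
        "tokenized_dataset_load_dir": "/input/gpt2-tokenized-wikitext",
        "model_name_or_path": "openai-community/gpt2",
        "output_dir": "/output-checkpoint",
        "max_train_samples": 1000,  # illustrative value, not from the guide
        "max_eval_samples": 200,    # illustrative value, not from the guide
    },
)
print(shlex.join(command))
```

The printed string can be pasted into a shell, or the list can be passed to `subprocess.run(command, check=True)` on a machine where the `flexai` CLI is installed.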
4. Checking up on the Training Job

You can check the status and life cycle events of your Training Job by running:
flexai training inspect flexai-experiments-flash-attention
Additionally, you can view the logs of your Training Job by running:
flexai training logs flexai-experiments-flash-attention
5. Fetching the Trained Model artifacts

Once the Training Job completes successfully, you will be able to download its output artifacts by running:
flexai training fetch flexai-experiments-flash-attention
This will download a zip file containing the trained model artifacts to your current working directory.
You can now have a look at other Experiments within this repository to explore other use cases and techniques.
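Once the archive is on disk, it can be unpacked and inspected with the Python standard library. The sketch below is an assumption-laden illustration: `extract_artifacts` is a hypothetical helper, and the archive file name used in the example comment is a placeholder, since the actual name produced by `flexai training fetch` is not stated here.

```python
import zipfile
from pathlib import Path


def extract_artifacts(zip_path, target_dir):
    """Extract a fetched archive and return the sorted relative file paths."""
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(target)
    return sorted(
        p.relative_to(target).as_posix() for p in target.rglob("*") if p.is_file()
    )


# Example (the archive name below is a placeholder, not the real file name):
#   for name in extract_artifacts("fetched-artifacts.zip", "artifacts"):
#       print(name)
```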

Optional Extra Steps

Manual Dataset Pre-processing

To prepare and save the wikitext dataset for the GPT-2 model, run the following command:
python code/dataset/prepare_save_dataset.py \
    --dataset_name wikitext \
    --tokenized_dataset_save_dir gpt2-tokenized-wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --tokenizer_model_name openai-community/gpt2 \
    --dataset_group_text true
The generated dataset will be created in the directory set as the value of --tokenized_dataset_save_dir, in this case: gpt2-tokenized-wikitext. Keep in mind that you can use other combinations of datasets and models available on HuggingFace.
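The source of prepare_save_dataset.py is not shown on this page, so the exact effect of --dataset_group_text is not confirmed here. A common pattern in causal language modeling pipelines is to concatenate all tokenized examples and split the result into fixed-size blocks. The following is a minimal sketch of that pattern, assuming the flag does something similar; `group_texts` and `block_size` are illustrative names.

```python
def group_texts(tokenized_examples, block_size):
    """Concatenate token lists and split them into equal blocks of `block_size`."""
    concatenated = [tok for example in tokenized_examples for tok in example]
    # Drop the tail that does not fill a whole block.
    usable_length = (len(concatenated) // block_size) * block_size
    blocks = [
        concatenated[start:start + block_size]
        for start in range(0, usable_length, block_size)
    ]
    # For causal language modeling, labels are a copy of the inputs;
    # the model shifts them internally when computing the loss.
    return {"input_ids": blocks, "labels": [list(block) for block in blocks]}


grouped = group_texts([[1, 2, 3], [4, 5], [6, 7, 8, 9]], block_size=4)
print(grouped["input_ids"])
```

Grouping this way avoids padding and gives every training example the same length, at the cost of discarding the few tokens left over at the end.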

Code

code/causal-language-modeling/train.py

# Copyright (c) 2025 FlexAI
# This file is part of the FlexAI Experiments repository.
# SPDX-License-Identifier: MIT

import math
import os
import sys
from dataclasses import dataclass, field
from typing import Optional

import evaluate
import numpy
import torch
import transformers
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    default_data_collator,
)

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(SCRIPT_DIR))

from dataset.prepare_save_dataset import DatasetArguments, load_and_tokenize
from utils.experiment_tracking import set_wandb

transformers.logging.set_verbosity_info()


@dataclass
class ModelArguments:
    model_name_or_path: str = field(default=None)
    torch_dtype: Optional[str] = field(default=None)
    attn_implementation: Optional[str] = field(default=None)


@dataclass
class AdditionalArguments:
    max_train_samples: Optional[int] = field(default=None)
    max_eval_samples: Optional[int] = field(default=None)


def parse_args():
    parser = HfArgumentParser(
        (DatasetArguments, ModelArguments, TrainingArguments, AdditionalArguments)
    )
    return parser.parse_args_into_dataclasses()


def _load_model_and_tokenizer(model_args, print_model=False):
    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    extra_model_args = {}
    if model_args.attn_implementation is not None:
        extra_model_args["attn_implementation"] = model_args.attn_implementation
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        torch_dtype=torch_dtype,
        **extra_model_args,
    )
    if print_model:
        print(model)
    return model, tokenizer


def train(dataset_args, model_args, training_args, additional_args):
    set_wandb(training_args)
    print(f"Training/evaluation parameters {training_args}")
    train_dataset, eval_dataset = load_and_tokenize(
        tokenizer_model_name=model_args.model_name_or_path,
        do_eval=training_args.do_eval,
        **vars(dataset_args),
    )
    max_train_samples = float("inf")
    max_eval_samples = float("inf")
    if not dataset_args.dataset_streaming:
        max_train_samples = len(train_dataset)
        if training_args.do_eval:
            max_eval_samples = len(eval_dataset)
    if additional_args.max_train_samples is not None:
        max_train_samples = min(max_train_samples, additional_args.max_train_samples)
        train_dataset = train_dataset.take(max_train_samples)
    if additional_args.max_eval_samples is not None:
        assert training_args.do_eval, "Cannot set max_eval_samples without do_eval"
        max_eval_samples = min(max_eval_samples, additional_args.max_eval_samples)
        eval_dataset = eval_dataset.take(max_eval_samples)
    model, tokenizer = _load_model_and_tokenizer(model_args, print_model=True)
    metric = evaluate.load("accuracy")

    def preprocess_logits_for_metrics(logits, labels):
        if isinstance(logits, tuple):
            logits = logits[0]
        return logits.argmax(dim=-1)

    def compute_metrics(eval_preds):
        preds, labels = eval_preds
        mask = (labels != tokenizer.pad_token_id) & (labels != -100)
        labels = numpy.concatenate([label[mask[i]][1:] for i, label in enumerate(labels)])
        preds = numpy.concatenate([pred[mask[i]][:-1] for i, pred in enumerate(preds)])
        return metric.compute(predictions=preds, references=labels)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )
    train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
    trainer.save_model()
    metrics = train_result.metrics
    metrics["train_samples"] = max_train_samples
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
    if training_args.do_eval:
        metrics = trainer.evaluate()
        metrics["eval_samples"] = max_eval_samples
        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics["perplexity"] = perplexity
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


if __name__ == "__main__":
    train(*parse_args())
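
The evaluation logic in train.py above (the shift applied in compute_metrics and the perplexity calculation) can be illustrated in isolation. The sketch below re-implements the same idea with plain Python lists instead of numpy arrays; `shifted_token_accuracy` and `perplexity_from_loss` are hypothetical helper names, not functions from the repository. In a causal language model the prediction at position t is for the token at position t + 1, so the last prediction and the first label are dropped before comparing.

```python
import math


def shifted_token_accuracy(preds, labels, ignore_ids=(-100,)):
    """Accuracy where the prediction at position t is scored against the label at t + 1."""
    correct = 0
    total = 0
    for pred_row, label_row in zip(preds, labels):
        # Keep only positions whose label is not an ignored id.
        kept = [(p, l) for p, l in zip(pred_row, label_row) if l not in ignore_ids]
        kept_preds = [p for p, _ in kept][:-1]   # drop the last prediction
        kept_labels = [l for _, l in kept][1:]   # drop the first label
        correct += sum(p == l for p, l in zip(kept_preds, kept_labels))
        total += len(kept_labels)
    return correct / total if total else 0.0


def perplexity_from_loss(eval_loss):
    """exp(loss), falling back to infinity when the exponential overflows."""
    try:
        return math.exp(eval_loss)
    except OverflowError:
        return float("inf")


# One sequence of four real tokens followed by an ignored position.
labels = [[10, 11, 12, 13, -100]]
preds = [[11, 12, 99, 7, 0]]
accuracy = shifted_token_accuracy(preds, labels)
print(f"accuracy={accuracy:.3f}, perplexity={perplexity_from_loss(2.0):.2f}")
```

In train.py the mask also excludes the tokenizer's pad token id; here that corresponds to adding the id to `ignore_ids`.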

code/causal-language-modeling/requirements-flash-attn.txt

accelerate>=1.8.1
datasets>=2.21.0
evaluate>=0.4.3
# flash-attn==2.7.4.post1 disable for now.
scikit_learn>=1.5.2
transformers>=4.43.3
wandb>=0.18.1
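
Because the flash-attn requirement is commented out in the file above, a job started with --attn_implementation flash_attention_2 may fail if the package is not otherwise available in the environment. One possible guard, sketched below under stated assumptions, is to check whether the package can be imported before choosing the attention implementation. `pick_attn_implementation` is a hypothetical helper that is not part of train.py, and whether a given model supports the "sdpa" fallback depends on the installed transformers version.

```python
import importlib.util


def pick_attn_implementation(package="flash_attn"):
    """Return "flash_attention_2" if `package` is importable, else "sdpa"."""
    if importlib.util.find_spec(package) is not None:
        return "flash_attention_2"
    return "sdpa"


print(pick_attn_implementation())
```

The returned string could then be passed as the value of --attn_implementation.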