
# Run Distributed Data Parallel (DDP) Training on FlexAI

> Start a Distributed Data Parallel (DDP) training job on FlexAI with just two flags. Multi-GPU training made simple: no SLURM, no infrastructure configuration, and launch in under 60 seconds.

This experiment shows how to run a distributed Training Job on **FlexAI** with just a couple of commands. As an example, we train the `GPT-2` causal language model on the `wikitext` dataset.

You will see that this straightforward process requires only two components: a training script and a dataset. The training script defines the model, sets up and applies the hyperparameters, runs the training loop, and applies the evaluation logic, while the dataset contains the data used to train the model.

<Steps>
  <Step title="Connect to GitHub (if needed)">
    If you haven't already connected FlexAI to GitHub, you'll need to set up a code registry connection:

    ```bash theme={null}
    flexai code-registry connect
    ```

    This will allow FlexAI to pull repositories directly from GitHub using the `-u` flag in training commands.
  </Step>

  <Step title="Preparing the Dataset">
    In this experiment, we will use a pre-processed version of the `wikitext` dataset that has already been tokenized for the `GPT-2` model.

    > If you'd like to reproduce the pre-processing steps yourself to use a different dataset or simply to learn more about the process, you can refer to the [Manual Dataset Pre-processing](#manual-dataset-pre-processing) section below.

    1. Download the dataset:

       ```bash theme={null}
       DATASET_NAME=gpt2-tokenized-wikitext
       curl -L -o "${DATASET_NAME}.zip" "https://bucket-docs-samples-99b3a05.s3.eu-west-1.amazonaws.com/${DATASET_NAME}.zip" \
         && unzip "${DATASET_NAME}.zip" \
         && rm "${DATASET_NAME}.zip"
       ```

    2. Upload the dataset (located in `gpt2-tokenized-wikitext/`) to FlexAI:

       ```bash theme={null}
       flexai dataset push gpt2-tokenized-wikitext --file gpt2-tokenized-wikitext
       ```
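
    If `curl` or `unzip` isn't available on your machine, the download in step 1 can be sketched with the Python standard library instead. This is a minimal sketch, not a FlexAI tool: `download_and_extract` is a hypothetical helper, and it assumes (as step 2 implies) that the archive unpacks into a top-level `gpt2-tokenized-wikitext/` directory.

    ```python theme={null}
    import pathlib
    import urllib.request
    import zipfile

    # Same bucket as the curl command in step 1.
    BASE_URL = "https://bucket-docs-samples-99b3a05.s3.eu-west-1.amazonaws.com"


    def download_and_extract(dataset_name: str, dest: str = ".", base_url: str = BASE_URL) -> pathlib.Path:
        """Download <base_url>/<dataset_name>.zip, unzip it into dest, then delete the archive."""
        dest_path = pathlib.Path(dest)
        zip_path = dest_path / f"{dataset_name}.zip"
        urllib.request.urlretrieve(f"{base_url}/{dataset_name}.zip", zip_path)
        with zipfile.ZipFile(zip_path) as archive:
            archive.extractall(dest_path)
        zip_path.unlink()  # same as `rm ${DATASET_NAME}.zip`
        # Assumption: the archive has a top-level folder named after the dataset.
        return dest_path / dataset_name
    ```

    Calling `download_and_extract("gpt2-tokenized-wikitext")` leaves the same `gpt2-tokenized-wikitext/` directory that step 2 pushes.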
  </Step>

  <Step title="Train the Model">
    Now it's time to train your LLM on `gpt2-tokenized-wikitext`, the dataset you *pushed* in the previous step. This experiment uses the `GPT-2` model; however, the training script relies on the HuggingFace Transformers `Trainer` class, which makes it easy to swap `GPT-2` for another model from the HuggingFace Model Hub. If you do, re-tokenize the dataset with that model's tokenizer (see [Manual Dataset Pre-processing](#manual-dataset-pre-processing)), since the pushed dataset is tokenized for `GPT-2`.

    To start the Training Job, run the following command:

    ```bash theme={null}
    flexai training run first-ddp-training-job --repository-url https://github.com/flexaihq/blueprints --dataset gpt2-tokenized-wikitext \
      --nodes 2 --accels 8 --requirements-path code/causal-language-modeling/requirements.txt \
      -- code/causal-language-modeling/train.py \
        --do_eval \
        --do_train \
        --dataset_name wikitext \
        --tokenized_dataset_load_dir /input/gpt2-tokenized-wikitext \
        --model_name_or_path openai-community/gpt2 \
        --output_dir /output-checkpoint \
        --per_device_train_batch_size 8 \
        --per_device_eval_batch_size 8 \
        --logging_steps 50 \
        --save_steps 500 \
        --eval_steps 500 \
        --eval_strategy steps
    ```

    The first line defines the three main components required to run a Training Job on FlexAI:

    1. The Training Job's name (`first-ddp-training-job`).
    2. The URL of the repository containing the training script (`https://github.com/flexaihq/blueprints`).
    3. The name of the dataset to be used (`gpt2-tokenized-wikitext`).

    The second line sets the number of nodes for the Training Job (`--nodes`) and the number of accelerators per node (`--accels`). In this case, the Training Job runs on 16 GPUs spread across 2 nodes. It also points to the `requirements.txt` file (`--requirements-path`) that lists the Python packages the training script needs.
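
    In DDP, each accelerator typically runs its own copy of the training script on a different shard of the data, so `--nodes 2 --accels 8` gives 16 processes, and `--per_device_train_batch_size 8` means each optimizer step consumes 8 × 16 = 128 samples (with the `Trainer` default of `gradient_accumulation_steps=1`). The sketch below works through that arithmetic; the function names are illustrative, and the contiguous rank numbering is the usual `torchrun` convention, assumed here rather than confirmed for FlexAI's launcher.

    ```python theme={null}
    def ddp_layout(nodes: int, accels_per_node: int, per_device_batch_size: int,
                   gradient_accumulation_steps: int = 1) -> dict:
        """World size and effective (global) batch size, assuming one process per accelerator."""
        world_size = nodes * accels_per_node
        global_batch_size = world_size * per_device_batch_size * gradient_accumulation_steps
        return {"world_size": world_size, "global_batch_size": global_batch_size}


    def global_rank(node_rank: int, local_rank: int, accels_per_node: int) -> int:
        """Contiguous rank numbering used by torchrun-style launchers (an assumption here)."""
        return node_rank * accels_per_node + local_rank


    if __name__ == "__main__":
        print(ddp_layout(nodes=2, accels_per_node=8, per_device_batch_size=8))
        # {'world_size': 16, 'global_batch_size': 128}
        print(global_rank(node_rank=1, local_rank=3, accels_per_node=8))  # 11
    ```

    Keeping the per-device batch size fixed while adding nodes therefore grows the global batch, which may call for retuning the learning rate.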

    The third line, after the `--` separator, sets the script that runs when the Training Job starts (`code/causal-language-modeling/train.py`).

    The remaining lines are the script's arguments, which are passed to the script when it runs to set the Training Job's hyperparameters or customize its behavior. For instance, `--max_train_samples` and `--max_eval_samples` (not used above) cap the number of training and evaluation samples, which is useful for quick test runs; `--max_eval_samples` also requires `--do_eval`.
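
    Because the script arguments are ordinary command-line flags, it can be convenient to generate the command when iterating on hyperparameters. The sketch below rebuilds the command above from a dict, reusing only the flags shown there; the `build_training_command` helper and the `ddp-smoke-test` job name are hypothetical.

    ```python theme={null}
    import shlex


    def build_training_command(job_name: str, script_args: dict, nodes: int = 2, accels: int = 8) -> list:
        cmd = [
            "flexai", "training", "run", job_name,
            "--repository-url", "https://github.com/flexaihq/blueprints",
            "--dataset", "gpt2-tokenized-wikitext",
            "--nodes", str(nodes),
            "--accels", str(accels),
            "--requirements-path", "code/causal-language-modeling/requirements.txt",
            "--",  # everything after this separator goes to the training script
            "code/causal-language-modeling/train.py",
        ]
        for key, value in script_args.items():
            if value is True:
                cmd.append(f"--{key}")  # bare switch, e.g. --do_train
            elif value is False or value is None:
                continue  # omit unset options
            else:
                cmd.extend([f"--{key}", str(value)])
        return cmd


    if __name__ == "__main__":
        cmd = build_training_command(
            "ddp-smoke-test",  # hypothetical job name
            {
                "do_train": True,
                "do_eval": True,
                "dataset_name": "wikitext",
                "tokenized_dataset_load_dir": "/input/gpt2-tokenized-wikitext",
                "model_name_or_path": "openai-community/gpt2",
                "output_dir": "/output-checkpoint",
                "per_device_train_batch_size": 8,
                "per_device_eval_batch_size": 8,
                "max_train_samples": 1000,
                "max_eval_samples": 200,
            },
        )
        print(shlex.join(cmd))
        # To submit it: import subprocess; subprocess.run(cmd, check=True)
    ```

    Boolean `True` values become bare switches such as `--do_train`, matching how the command above passes them; `False` and `None` values are simply omitted.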
  </Step>

  <Step title="Checking up on the Training Job">
    You can check the status and life cycle events of your Training Job by running:

    ```bash theme={null}
    flexai training inspect first-ddp-training-job
    ```

    Additionally, you can view the logs of your Training Job by running:

    ```bash theme={null}
    flexai training logs first-ddp-training-job
    ```
  </Step>

  <Step title="Fetching the Trained Model artifacts">
    Once the Training Job completes successfully, you will be able to list all the produced checkpoints:

    ```bash theme={null}
    flexai training checkpoints first-ddp-training-job
    ```

    They can be downloaded with:

    ```bash theme={null}
    flexai checkpoint fetch "<CHECKPOINT-ID>"
    ```

    You now have a trained model that you can use for inference or further fine-tuning! See the [Optional Extra Steps](#optional-extra-steps) section below to learn how to run your fine-tuned model locally, how to run the training script directly on FlexAI in an Interactive Training Session, and how to pre-process the dataset manually.

    You can also explore the other FlexAI experiments in the [`flexaihq/blueprints`](https://github.com/flexaihq/blueprints) repository for more advanced use cases and techniques.
  </Step>
</Steps>

## Optional Extra Steps

### Try your fine-tuned model locally

You can run your newly fine-tuned model in a FlexAI Interactive Session or in a local environment (e.g., one created with `pipenv install --python 3.10`), provided your hardware can run inference.

#### 1. Clone this repository

If you haven't already, clone the `flexaihq/blueprints` repository on your host machine:

```bash theme={null}
git clone https://github.com/flexaihq/blueprints.git blueprints --depth 1 --branch main && cd blueprints
```

#### 2. Install the dependencies

If the experiment's dependencies aren't already installed in your environment, install them by running:

```bash theme={null}
pip install -r code/causal-language-modeling/requirements.txt
```
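
To double-check that the installed packages meet the minimum versions in `requirements.txt`, a rough standard-library sketch like the one below can help. It compares only the numeric release parts (not full PEP 440 semantics, which the `packaging` library implements), and `release_tuple`/`check_requirements` are hypothetical helper names.

```python theme={null}
from importlib.metadata import PackageNotFoundError, version

# Minimum versions copied from code/causal-language-modeling/requirements.txt.
MINIMUMS = {
    "accelerate": "1.8.1",
    "datasets": "2.21.0",
    "evaluate": "0.4.3",
    "scikit-learn": "1.5.2",
    "transformers": "4.43.3",
    "wandb": "0.18.1",
}


def release_tuple(v: str) -> tuple:
    """Leading numeric release parts, e.g. '4.43.3.dev0' -> (4, 43, 3); suffixes are ignored."""
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
        if digits != piece:  # stop at a suffix such as '1rc1'
            break
    return tuple(parts)


def check_requirements(minimums: dict) -> dict:
    """Map each distribution name to 'ok', 'too old (<installed>)', or 'missing'."""
    report = {}
    for name, minimum in minimums.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            report[name] = "missing"
            continue
        ok = release_tuple(installed) >= release_tuple(minimum)
        report[name] = "ok" if ok else f"too old ({installed})"
    return report


if __name__ == "__main__":
    for name, status in check_requirements(MINIMUMS).items():
        print(f"{name}: {status}")
```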

#### 3. Extract the model artifacts

First, list the available checkpoints from your training job:

```bash theme={null}
flexai training checkpoints first-ddp-training-job
```

Then fetch the specific checkpoint you want to use (replace `<CHECKPOINT-ID>` with the actual checkpoint ID from the list):

```bash theme={null}
flexai checkpoint fetch "<CHECKPOINT-ID>" --destination ./checkpoint
```

This will download the checkpoint to a local `checkpoint` directory. Make note of this location, as you will use it next.
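
Depending on how the checkpoint is laid out, the model files may sit in a subdirectory of `./checkpoint`. As a minimal sketch, assuming the checkpoint contains a Hugging Face-format model directory (one with a `config.json`, which `Trainer.save_model()` writes alongside the weights), the hypothetical helper below lists candidate directories to pass as `--model_name_or_path`:

```python theme={null}
import pathlib


def find_model_dirs(checkpoint_root: str) -> list:
    """Return every directory under checkpoint_root (itself included) that holds a config.json."""
    root = pathlib.Path(checkpoint_root)
    return sorted(path.parent for path in root.rglob("config.json"))


if __name__ == "__main__":
    for model_dir in find_model_dirs("./checkpoint"):
        print(model_dir)
```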

#### 4. Run the inference script

Run the inference script with the command below, replacing `**PATH_TO_THE_CHECKPOINT_DIRECTORY**` with the path to the checkpoint directory you downloaded:

```bash theme={null}
python code/causal-language-modeling/predict.py \
    --model_name_or_path **PATH_TO_THE_CHECKPOINT_DIRECTORY** \
    --input_str "Once upon a time, " \
    --max_new_tokens 30
```

### Run the training script directly on FlexAI using an Interactive Training Session

An Interactive Training Session lets you connect to a Training Environment runtime on FlexAI and run both your training and inference scripts directly in that environment. It's a convenient way to test your scripts and experiment with different hyperparameters without creating a new Training Job for every configuration change.

You'll find a guide to running an Interactive Training Session in the [FlexAI Documentation](https://docs.flex.ai/cli/guides/interactive-training/). Use the URL of the `flexaihq/blueprints` repository as your `--repository-url` and pass the `gpt2-tokenized-wikitext` dataset you pushed earlier as `--dataset`, unless you want to use the Interactive Training Session's compute resources to pre-process the dataset manually.

### Manual Dataset Pre-processing

To prepare and save the `wikitext` dataset for the `GPT-2` model, run the following command from the root of the `blueprints` repository:

```bash theme={null}
python code/dataset/prepare_save_dataset.py \
    --dataset_name wikitext \
    --tokenized_dataset_save_dir gpt2-tokenized-wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --tokenizer_model_name openai-community/gpt2 \
    --dataset_group_text true
```

The generated dataset is written to the directory passed as `--tokenized_dataset_save_dir`, in this case `gpt2-tokenized-wikitext`.
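
`--dataset_group_text true` presumably concatenates the tokenized texts and splits them into fixed-length blocks, as the Hugging Face `run_clm.py` example does. `prepare_save_dataset.py` isn't reproduced here, so treat the following as a minimal sketch of that common technique rather than the script's actual code; `group_texts` and `block_size` are illustrative names, and real implementations usually group `attention_mask` the same way.

```python theme={null}
def group_texts(token_lists: list, block_size: int) -> dict:
    """Concatenate token id lists and cut them into equal blocks, dropping the remainder."""
    concatenated = [tok for tokens in token_lists for tok in tokens]
    total = (len(concatenated) // block_size) * block_size
    input_ids = [concatenated[i:i + block_size] for i in range(0, total, block_size)]
    # Labels are a copy of the inputs: causal LM models shift them internally.
    return {"input_ids": input_ids, "labels": [block.copy() for block in input_ids]}


if __name__ == "__main__":
    print(group_texts([[1, 2, 3], [4, 5], [6, 7, 8]], block_size=3))
    # {'input_ids': [[1, 2, 3], [4, 5, 6]], 'labels': [[1, 2, 3], [4, 5, 6]]}
```

Grouping avoids wasting compute on padding, at the cost of blocks that can span document boundaries.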

Keep in mind that you can use other combinations of datasets and models available on HuggingFace.

## Code

### `code/causal-language-modeling/train.py`

```python theme={null}
# Copyright (c) 2025 FlexAI
# This file is part of the FlexAI Experiments repository.
# SPDX-License-Identifier: MIT

import math
import os
import sys
from dataclasses import dataclass, field
from typing import Optional

import evaluate
import numpy
import torch
import transformers
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    default_data_collator,
)

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(SCRIPT_DIR))

from dataset.prepare_save_dataset import DatasetArguments, load_and_tokenize
from utils.experiment_tracking import set_wandb

transformers.logging.set_verbosity_info()


@dataclass
class ModelArguments:
    model_name_or_path: str = field(default=None)
    torch_dtype: Optional[str] = field(default=None)
    attn_implementation: Optional[str] = field(default=None)


@dataclass
class AdditionalArguments:
    max_train_samples: Optional[int] = field(default=None)
    max_eval_samples: Optional[int] = field(default=None)


def parse_args():
    parser = HfArgumentParser(
        (DatasetArguments, ModelArguments, TrainingArguments, AdditionalArguments)
    )
    return parser.parse_args_into_dataclasses()


def _load_model_and_tokenizer(model_args, print_model=False):
    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    # GPT-2 ships without a pad token, so reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    extra_model_args = {}
    if model_args.attn_implementation is not None:
        extra_model_args["attn_implementation"] = model_args.attn_implementation
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        torch_dtype=torch_dtype,
        **extra_model_args,
    )
    if print_model:
        print(model)
    return model, tokenizer


def train(dataset_args, model_args, training_args, additional_args):
    set_wandb(training_args)
    print(f"Training/evaluation parameters {training_args}")
    train_dataset, eval_dataset = load_and_tokenize(
        tokenizer_model_name=model_args.model_name_or_path,
        do_eval=training_args.do_eval,
        **vars(dataset_args),
    )
    max_train_samples = float("inf")
    max_eval_samples = float("inf")
    if not dataset_args.dataset_streaming:
        max_train_samples = len(train_dataset)
        if training_args.do_eval:
            max_eval_samples = len(eval_dataset)
    if additional_args.max_train_samples is not None:
        max_train_samples = min(max_train_samples, additional_args.max_train_samples)
        train_dataset = train_dataset.take(max_train_samples)
    if additional_args.max_eval_samples is not None:
        assert training_args.do_eval, "Cannot set max_eval_samples without do_eval"
        max_eval_samples = min(max_eval_samples, additional_args.max_eval_samples)
        eval_dataset = eval_dataset.take(max_eval_samples)
    model, tokenizer = _load_model_and_tokenizer(model_args, print_model=True)
    metric = evaluate.load("accuracy")

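    def preprocess_logits_for_metrics(logits, labels):
        # Keep only the predicted token ids instead of full vocabulary-sized
        # logits, which keeps evaluation memory usage manageable.
        if isinstance(logits, tuple):
            logits = logits[0]
        return logits.argmax(dim=-1)

    def compute_metrics(eval_preds):
        preds, labels = eval_preds
        # Ignore padding (the pad token is EOS here) and positions labelled -100.
        mask = (labels != tokenizer.pad_token_id) & (labels != -100)
        # Causal LM: the prediction at position t targets the token at t + 1,
        # so drop the first label and the last prediction before comparing.
        labels = numpy.concatenate([label[mask[i]][1:] for i, label in enumerate(labels)])
        preds = numpy.concatenate([pred[mask[i]][:-1] for i, pred in enumerate(preds)])
        return metric.compute(predictions=preds, references=labels)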

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )
    train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
    trainer.save_model()
    metrics = train_result.metrics
    metrics["train_samples"] = max_train_samples
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
    if training_args.do_eval:
        metrics = trainer.evaluate()
        metrics["eval_samples"] = max_eval_samples
        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics["perplexity"] = perplexity
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


if __name__ == "__main__":
    train(*parse_args())
```
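
The next-token accuracy and perplexity computed in `train.py` can be reproduced in plain Python to see how the alignment works. This is a simplified sketch: `shifted_token_accuracy` and `perplexity` are illustrative names, and unlike `train.py` it ignores only `-100` labels rather than also masking the pad token.

```python theme={null}
import math


def shifted_token_accuracy(preds, labels, ignore_ids=frozenset({-100})):
    """Compare the prediction at position t with the label at position t + 1."""
    correct = total = 0
    for pred_row, label_row in zip(preds, labels):
        keep = [i for i, lab in enumerate(label_row) if lab not in ignore_ids]
        kept_labels = [label_row[i] for i in keep][1:]  # drop the first label
        kept_preds = [pred_row[i] for i in keep][:-1]  # drop the last prediction
        correct += sum(p == lab for p, lab in zip(kept_preds, kept_labels))
        total += len(kept_labels)
    return correct / total if total else 0.0


def perplexity(eval_loss: float) -> float:
    """exp(mean cross-entropy), with the same overflow guard as train.py."""
    try:
        return math.exp(eval_loss)
    except OverflowError:
        return float("inf")


if __name__ == "__main__":
    preds = [[2, 3, 9, 0], [6, 0, 0, 0]]
    labels = [[1, 2, 3, 4], [5, 6, -100, -100]]
    print(shifted_token_accuracy(preds, labels))  # 0.75
    print(perplexity(0.0))  # 1.0
```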

### `code/causal-language-modeling/predict.py`

```python theme={null}
# Copyright (c) 2025 FlexAI
# This file is part of the FlexAI Experiments repository.
# SPDX-License-Identifier: MIT

import os
import sys
from dataclasses import dataclass, field

from transformers import HfArgumentParser

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(SCRIPT_DIR)

from train import ModelArguments, _load_model_and_tokenizer


@dataclass
class PredictArguments:
    input_str: str = field(metadata={"help": "Input string to generate predictions for."})
    max_new_tokens: int = field(default=100, metadata={"help": "Maximum number of tokens to generate."})


def parse_args():
    parser = HfArgumentParser((ModelArguments, PredictArguments))
    return parser.parse_args_into_dataclasses()


def predict(model_args, predict_args):
    model, tokenizer = _load_model_and_tokenizer(model_args, print_model=False)
    input_str = predict_args.input_str
    model_inputs = tokenizer(input_str, return_tensors="pt")
    output = model.generate(**model_inputs, max_new_tokens=predict_args.max_new_tokens)
    print("\nGenerated text:")
    print(tokenizer.decode(output[0], skip_special_tokens=True))


if __name__ == "__main__":
    predict(*parse_args())
```

### `code/causal-language-modeling/requirements.txt`

```text theme={null}
accelerate>=1.8.1
datasets>=2.21.0
evaluate>=0.4.3
scikit_learn>=1.5.2
transformers>=4.43.3
wandb>=0.18.1
```

<div className="blueprint-cta">
  <h3>🚀 Run this on FlexAI</h3>
  <p>Managed checkpoints mean you never lose a run to preemption. Jobs launch in under 60 seconds, with no infrastructure setup and built-in observability.</p>
  <a href="https://console.flex.ai" className="cta-primary">Get started →</a>
  <a href="https://flex.ai/contact" className="cta-secondary">Talk to us</a>
</div>
