
# Track Training Experiments with Weights & Biases on FlexAI

> Integrate Weights & Biases experiment tracking with FlexAI training jobs. Monitor loss curves, compare runs, and share results with no extra infrastructure.

## Experiment Tracking

*Experiment tracking* involves systematically recording and managing details of machine learning experiments, such as code, data, configurations, parameters, metrics, and results.
It ensures reproducibility, comparability, and accountability across experiments, aiding in efficient model development and deployment.
Weights & Biases (*wandb*) is one approach to achieving this.

Follow the steps below to log experiments to your *wandb* account.

## Setting Up the Weights and Biases Secret

To enable seamless integration with *wandb* in your experiments, follow these steps to create the *wandb* secret:

1. **Retrieve Your API Key**

   Visit your [Weights & Biases Settings page](https://app.wandb.ai/settings) to find your API key. Copy the key for use in the next step.

2. **Create the Secret**

   Use the [`flexai secret create` command](https://docs.flex.ai/cli/commands/secret/) to store your *wandb* API key as a secret. Replace `<WANDB_API_KEY_SECRET_NAME>` with your desired name for the secret:

   ```bash theme={null}
   flexai secret create <WANDB_API_KEY_SECRET_NAME>
   ```

   Then paste your *wandb* API key value when prompted.

3. **Note on Project Name**

   The project name used in your *wandb* setup does not need to be stored as a FlexAI secret; it is passed as a regular environment variable (see below). The project also doesn't need to exist in *wandb* beforehand: it is created automatically when you log your first experiment.

## Log to Weights and Biases

You can now log experiments to your *wandb* account by adding the following flags to any `flexai training run` command:

```bash theme={null}
--secret WANDB_API_KEY=<WANDB_API_KEY_SECRET_NAME> --env WANDB_PROJECT=<YOUR_PROJECT_NAME>
```

You can optionally name the run by passing `--run_name <YOUR_RUN_NAME>` to your training script (a Hugging Face `TrainingArguments` option).

For more ways to customize and configure your *wandb* environment, check out the [Weights & Biases Environment Variables Guide](https://docs.wandb.ai/guides/track/environment-variables/).
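Inside the job, these flags are meant to surface as ordinary environment variables, which is where *wandb* looks for `WANDB_API_KEY` and `WANDB_PROJECT`. If runs don't appear in your dashboard, a fail-fast check at the top of your training script can save a wasted job. The sketch below assumes the secret is exposed under the variable name on the left of `=`, as the flag syntax suggests; the helper names are hypothetical and not part of FlexAI or *wandb*.

```python
import os
from typing import List, Mapping

# Variables the --secret and --env flags above are expected to set in the job.
REQUIRED_WANDB_VARS = ("WANDB_API_KEY", "WANDB_PROJECT")


def missing_wandb_vars(env: Mapping[str, str] = os.environ) -> List[str]:
    """Return the required wandb variables that are unset or empty."""
    return [name for name in REQUIRED_WANDB_VARS if not env.get(name)]


def require_wandb_env(env: Mapping[str, str] = os.environ) -> str:
    """Raise early if either variable is missing; otherwise return the project name.

    Hypothetical helper: call it before loading the model or building the Trainer.
    """
    missing = missing_wandb_vars(env)
    if missing:
        raise RuntimeError(
            "Missing wandb environment variables: " + ", ".join(missing)
        )
    # Only the project name is returned; never print or log the API key itself.
    return env["WANDB_PROJECT"]
```

For example, calling `require_wandb_env()` at the start of `train()` in the script shown later on this page would stop the job before any model weights are downloaded if either flag was left out.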

## Setting Up the Experiment

<Steps>
  <Step title="Connect to GitHub (if needed)">
    If you haven't already connected FlexAI to GitHub, you'll need to set up a code registry connection:

    ```bash theme={null}
    flexai code-registry connect
    ```

    This allows FlexAI to pull repositories directly from GitHub using the `-u` / `--repository-url` flag in training commands.
  </Step>

  <Step title="Preparing the Dataset">
    In this experiment, we will use a pre-processed version of the `wikitext` dataset that has already been tokenized for the `GPT-2` model.

    1. Download the dataset:

       ```bash theme={null}
       DATASET_NAME=gpt2-tokenized-wikitext && curl -L -o ${DATASET_NAME}.zip "https://bucket-docs-samples-99b3a05.s3.eu-west-1.amazonaws.com/${DATASET_NAME}.zip" && unzip ${DATASET_NAME}.zip && rm ${DATASET_NAME}.zip
       ```

    2. Upload the dataset (located in `gpt2-tokenized-wikitext/`) to FlexAI:

       ```bash theme={null}
       flexai dataset push gpt2-tokenized-wikitext --file gpt2-tokenized-wikitext
       ```
  </Step>
</Steps>
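If `curl` or `unzip` isn't available on your machine, the download-and-extract step can also be done with the Python standard library. This is a minimal sketch of the same operation; the function name `download_and_extract` is illustrative, and only the bucket URL comes from the command above.

```python
import os
import tempfile
import urllib.request
import zipfile

BASE_URL = "https://bucket-docs-samples-99b3a05.s3.eu-west-1.amazonaws.com"


def download_and_extract(url: str, dest_dir: str = ".") -> None:
    """Download a .zip archive from `url` and extract it into `dest_dir`."""
    # Download to a temporary file (urlretrieve follows redirects, like `curl -L`).
    fd, tmp_path = tempfile.mkstemp(suffix=".zip")
    os.close(fd)
    try:
        urllib.request.urlretrieve(url, tmp_path)
        with zipfile.ZipFile(tmp_path) as archive:
            archive.extractall(dest_dir)  # same as `unzip`
    finally:
        os.remove(tmp_path)  # same as `rm ${DATASET_NAME}.zip`
```

Running `download_and_extract(f"{BASE_URL}/gpt2-tokenized-wikitext.zip")` should leave a `gpt2-tokenized-wikitext/` directory in the current folder, ready for `flexai dataset push`.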

## Running the Training Job with Experiment Tracking

Now that all the pieces are in place (the *wandb* secret, the GitHub connection, and the dataset), you can run the training job with experiment tracking enabled:

```bash theme={null}
flexai training run gpt2training-tracker \
  --repository-url https://github.com/flexaihq/blueprints \
  --dataset gpt2-tokenized-wikitext \
  --secret WANDB_API_KEY=<WANDB_API_KEY_SECRET_NAME> \
  --env WANDB_PROJECT=<YOUR_PROJECT_NAME> \
  --requirements-path code/causal-language-modeling/requirements.txt \
  -- code/causal-language-modeling/train.py \
    --do_eval \
    --do_train \
    --dataset_name wikitext \
    --tokenized_dataset_load_dir /input/gpt2-tokenized-wikitext \
    --model_name_or_path openai-community/gpt2 \
    --output_dir /output-checkpoint \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --logging_steps 50 \
    --save_steps 500 \
    --eval_steps 500 \
    --eval_strategy steps \
    --run_name <YOUR_RUN_NAME>
```
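Comparing runs in *wandb* is most informative when the runs differ in one deliberate way. As a sketch, the same invocation can be assembled in Python and looped over, giving each variant its own job name and `--run_name` so they are easy to tell apart in the dashboard. It uses only the flags shown above; the helper name, the job-naming scheme, the placeholder secret and project names, and the choice to sweep `--per_device_train_batch_size` are illustrative assumptions.

```python
import shlex
from typing import List

SCRIPT = "code/causal-language-modeling/train.py"


def build_run_command(job_name: str, secret_name: str, project: str,
                      run_name: str, batch_size: int = 8) -> List[str]:
    """Return the argv list for one `flexai training run` invocation (hypothetical helper)."""
    flexai_flags = [
        "flexai", "training", "run", job_name,
        "--repository-url", "https://github.com/flexaihq/blueprints",
        "--dataset", "gpt2-tokenized-wikitext",
        "--secret", f"WANDB_API_KEY={secret_name}",
        "--env", f"WANDB_PROJECT={project}",
        "--requirements-path", "code/causal-language-modeling/requirements.txt",
    ]
    # Everything after "--" is forwarded to the training script.
    script_args = [
        "--", SCRIPT,
        "--do_eval", "--do_train",
        "--dataset_name", "wikitext",
        "--tokenized_dataset_load_dir", "/input/gpt2-tokenized-wikitext",
        "--model_name_or_path", "openai-community/gpt2",
        "--output_dir", "/output-checkpoint",
        "--per_device_train_batch_size", str(batch_size),
        "--per_device_eval_batch_size", "8",
        "--logging_steps", "50",
        "--save_steps", "500",
        "--eval_steps", "500",
        "--eval_strategy", "steps",
        "--run_name", run_name,
    ]
    return flexai_flags + script_args


if __name__ == "__main__":
    # Placeholder secret/project names; replace with your own.
    for bs in (4, 8, 16):
        cmd = build_run_command(f"gpt2training-tracker-bs{bs}", "my-wandb-key",
                                "gpt2-wikitext", f"gpt2-bs{bs}", batch_size=bs)
        print(shlex.join(cmd))
```

Each printed line can be pasted into a shell, or the list itself passed to `subprocess.run(cmd, check=True)` to launch the job directly.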

Open your *wandb* dashboard and select your project to follow the training job's progress and analyze its results in near real time.

## Code

### `code/causal-language-modeling/train.py`

```python theme={null}
# Copyright (c) 2025 FlexAI
# This file is part of the FlexAI Experiments repository.
# SPDX-License-Identifier: MIT

import math
import os
import sys
from dataclasses import dataclass, field
from typing import Optional

import evaluate
import numpy
import torch
import transformers
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    default_data_collator,
)

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(SCRIPT_DIR))

from dataset.prepare_save_dataset import DatasetArguments, load_and_tokenize
from utils.experiment_tracking import set_wandb

transformers.logging.set_verbosity_info()


@dataclass
class ModelArguments:
    model_name_or_path: str = field(default=None)
    torch_dtype: Optional[str] = field(default=None)
    attn_implementation: Optional[str] = field(default=None)


@dataclass
class AdditionalArguments:
    max_train_samples: Optional[int] = field(default=None)
    max_eval_samples: Optional[int] = field(default=None)


def parse_args():
    parser = HfArgumentParser(
        (DatasetArguments, ModelArguments, TrainingArguments, AdditionalArguments)
    )
    return parser.parse_args_into_dataclasses()


def _load_model_and_tokenizer(model_args, print_model=False):
    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    extra_model_args = {}
    if model_args.attn_implementation is not None:
        extra_model_args["attn_implementation"] = model_args.attn_implementation
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        torch_dtype=torch_dtype,
        **extra_model_args,
    )
    if print_model:
        print(model)
    return model, tokenizer


def train(dataset_args, model_args, training_args, additional_args):
    set_wandb(training_args)
    print(f"Training/evaluation parameters {training_args}")
    train_dataset, eval_dataset = load_and_tokenize(
        tokenizer_model_name=model_args.model_name_or_path,
        do_eval=training_args.do_eval,
        **vars(dataset_args),
    )
    max_train_samples = float("inf")
    max_eval_samples = float("inf")
    if not dataset_args.dataset_streaming:
        max_train_samples = len(train_dataset)
        if training_args.do_eval:
            max_eval_samples = len(eval_dataset)
    if additional_args.max_train_samples is not None:
        max_train_samples = min(max_train_samples, additional_args.max_train_samples)
        train_dataset = train_dataset.take(max_train_samples)
    if additional_args.max_eval_samples is not None:
        assert training_args.do_eval, "Cannot set max_eval_samples without do_eval"
        max_eval_samples = min(max_eval_samples, additional_args.max_eval_samples)
        eval_dataset = eval_dataset.take(max_eval_samples)
    model, tokenizer = _load_model_and_tokenizer(model_args, print_model=True)
    metric = evaluate.load("accuracy")

    def preprocess_logits_for_metrics(logits, labels):
        if isinstance(logits, tuple):
            logits = logits[0]
        return logits.argmax(dim=-1)

    def compute_metrics(eval_preds):
        preds, labels = eval_preds
        mask = (labels != tokenizer.pad_token_id) & (labels != -100)
        labels = numpy.concatenate([label[mask[i]][1:] for i, label in enumerate(labels)])
        preds = numpy.concatenate([pred[mask[i]][:-1] for i, pred in enumerate(preds)])
        return metric.compute(predictions=preds, references=labels)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )
    train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
    trainer.save_model()
    metrics = train_result.metrics
    metrics["train_samples"] = max_train_samples
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
    if training_args.do_eval:
        metrics = trainer.evaluate()
        metrics["eval_samples"] = max_eval_samples
        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics["perplexity"] = perplexity
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


if __name__ == "__main__":
    train(*parse_args())
```
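Two calculations in `train.py` are easy to misread. In `compute_metrics`, the prediction at position *t* is the model's guess for token *t + 1*, so after masking out padding and `-100` positions the script compares `pred[:-1]` with `label[1:]`. Perplexity is then reported as `exp(eval_loss)`, because the evaluation loss is the mean token-level cross-entropy. The pure-Python sketch below reproduces both on toy values without NumPy or `evaluate`; it mirrors the script's logic but is not the script itself, and the empty-input guard returning `0.0` is an added assumption.

```python
import math

IGNORE_INDEX = -100  # label value that Hugging Face losses skip


def shifted_accuracy(preds, labels, pad_token_id):
    """Next-token accuracy following the logic of train.py's compute_metrics.

    preds[i][t] is the argmax prediction for token t + 1 of sequence i, so the
    kept predictions (minus the last) are compared with the kept labels (minus the first).
    """
    correct = total = 0
    for pred_seq, label_seq in zip(preds, labels):
        # The mask is built from the labels and applied to both sequences.
        keep = [lab != pad_token_id and lab != IGNORE_INDEX for lab in label_seq]
        kept_preds = [p for p, k in zip(pred_seq, keep) if k][:-1]
        kept_labels = [lab for lab, k in zip(label_seq, keep) if k][1:]
        correct += sum(p == lab for p, lab in zip(kept_preds, kept_labels))
        total += len(kept_labels)
    return correct / total if total else 0.0


def perplexity(eval_loss):
    """exp(mean cross-entropy), with the same overflow guard as train.py."""
    try:
        return math.exp(eval_loss)
    except OverflowError:
        return float("inf")
```

For example, with labels `[5, 6, 7, 0]` (pad id `0`) and predictions `[6, 1, 9, 3]`, the compared pairs are `(6, 6)` and `(1, 7)`, giving an accuracy of 0.5; an `eval_loss` of `log(20)` corresponds to a perplexity of 20.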

### `code/causal-language-modeling/requirements.txt`

```text theme={null}
accelerate>=1.8.1
datasets>=2.21.0
evaluate>=0.4.3
scikit_learn>=1.5.2
transformers>=4.43.3
wandb>=0.18.1
```

