Experiment Tracking

Experiment tracking involves systematically recording and managing details of machine learning experiments, such as code, data, configurations, parameters, metrics, and results. It ensures reproducibility, comparability, and accountability across experiments, aiding in efficient model development and deployment. Weights & Biases (wandb) is one approach to achieving this. Follow the next instructions to log experiments to your wandb account.

Setting Up the Weights and Biases Secret

To enable seamless integration with wandb in your experiments, follow these steps to create the wandb secret:
  1. Retrieve Your API Key: Visit your Weights & Biases Settings page to find your API key. Copy the key for use in the next step.
  2. Create the Secret: Use the flexai secret create command to store your wandb API key as a secret. Replace <WANDB_API_KEY_SECRET_NAME> with your desired name for the secret:
    flexai secret create <WANDB_API_KEY_SECRET_NAME>
    
    Then paste your wandb API key value when prompted.
  3. Note on Project Name: Keep in mind that the project name used in your wandb setup does not need to be a FlexAI Secret. The project also does not need to be pre-created in wandb: it is created automatically the first time an experiment is logged to it, as the short sketch below illustrates.
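
For illustration, here is a minimal standalone Python sketch of logging directly to wandb. It assumes you are already logged in locally with your API key; the project name is an arbitrary example and is created on first use:

import wandb

# "my-first-project" is created automatically the first time a run is logged to it.
run = wandb.init(project="my-first-project", name="sanity-check")
for step in range(10):
    wandb.log({"loss": 1.0 / (step + 1)}, step=step)
run.finish()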

Log to Weights and Biases

You can now log experiments to your wandb account by adding the following flags to any flexai training run command:
--secret WANDB_API_KEY=<WANDB_API_KEY_SECRET_NAME> --env WANDB_PROJECT=<YOUR_PROJECT_NAME>
You can optionally set your run name using the --run_name <YOUR_RUN_NAME> Hugging Face Transformers argument. For more ways to customize and configure your wandb environment, check out the Weights & Biases Environment Variables Guide.
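
Under the hood, the Hugging Face Trainer's wandb integration picks the project up from the WANDB_PROJECT environment variable and the run name from the run_name training argument. The following is only a rough sketch of that behaviour for intuition; the real logic lives in transformers' WandbCallback:

import os
import wandb

# Approximation of what the Trainer's wandb callback does when training starts:
# the project comes from WANDB_PROJECT (set with --env on the FlexAI command),
# the run name from the --run_name training argument.
project = os.environ.get("WANDB_PROJECT", "huggingface")
wandb.init(project=project, name="<YOUR_RUN_NAME>")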

Setting Up the Experiment

1. Connect to GitHub (if needed)

If you haven’t already connected FlexAI to GitHub, you’ll need to set up a code registry connection:
flexai code-registry connect
This will allow FlexAI to pull repositories directly from GitHub using the -u flag in training commands.
2. Preparing the Dataset

In this experiment, we will use a pre-processed version of the wikitext dataset that has been tokenized for the GPT-2 model (a short sketch for checking it locally follows these steps).
  1. Download the dataset:
    DATASET_NAME=gpt2-tokenized-wikitext && curl -L -o ${DATASET_NAME}.zip "https://bucket-docs-samples-99b3a05.s3.eu-west-1.amazonaws.com/${DATASET_NAME}.zip" && unzip ${DATASET_NAME}.zip && rm ${DATASET_NAME}.zip
    
  2. Upload the dataset (located in gpt2-tokenized-wikitext/) to FlexAI:
    flexai dataset push gpt2-tokenized-wikitext --file gpt2-tokenized-wikitext
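
Optionally, before the push in step 2, you can sanity-check the download locally. The sketch below assumes the unzipped folder is a Hugging Face dataset saved with save_to_disk and containing a train split, which is how the training script later loads it from /input/gpt2-tokenized-wikitext:

from datasets import load_from_disk

# Assumption: gpt2-tokenized-wikitext/ is a DatasetDict produced by save_to_disk.
ds = load_from_disk("gpt2-tokenized-wikitext")
print(ds)                     # splits and column names
print(ds["train"][0].keys())  # expect tokenized fields such as input_ids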
    

Running the Training Job with Experiment Tracking

Now that all the pieces are in place (the wandb Secret, the code registry connection, and the Dataset), you can run the training job with experiment tracking enabled.
flexai training run gpt2training-tracker \
  --repository-url https://github.com/flexaihq/blueprints \
  --dataset gpt2-tokenized-wikitext \
  --secret WANDB_API_KEY=<WANDB_API_KEY_SECRET_NAME> \
  --env WANDB_PROJECT=<YOUR_PROJECT_NAME> \
  --requirements-path code/causal-language-modeling/requirements.txt \
  -- code/causal-language-modeling/train.py \
    --do_eval \
    --do_train \
    --dataset_name wikitext \
    --tokenized_dataset_load_dir /input/gpt2-tokenized-wikitext \
    --model_name_or_path openai-community/gpt2 \
    --output_dir /output-checkpoint \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --logging_steps 50 \
    --save_steps 500 \
    --eval_steps 500 \
    --eval_strategy steps \
    --run_name <YOUR_RUN_NAME>
You can now visit your wandb dashboard and look for your project’s name to follow the progress of the Training Job and analyze its results in near real-time.
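
If you prefer to analyze results programmatically rather than in the dashboard, the wandb public API can list the runs in your project. The metric key below ("train/loss") is only an example; the exact keys depend on what the Trainer logs for your run:

import wandb

# Replace <YOUR_ENTITY> and <YOUR_PROJECT_NAME> with your wandb entity and project.
api = wandb.Api()
for run in api.runs("<YOUR_ENTITY>/<YOUR_PROJECT_NAME>"):
    print(run.name, run.state, run.summary.get("train/loss"))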

Code

code/causal-language-modeling/train.py

# Copyright (c) 2025 FlexAI
# This file is part of the FlexAI Experiments repository.
# SPDX-License-Identifier: MIT

import math
import os
import sys
from dataclasses import dataclass, field
from typing import Optional

import evaluate
import numpy
import torch
import transformers
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    default_data_collator,
)

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(SCRIPT_DIR))

from dataset.prepare_save_dataset import DatasetArguments, load_and_tokenize
from utils.experiment_tracking import set_wandb

transformers.logging.set_verbosity_info()


@dataclass
class ModelArguments:
    model_name_or_path: str = field(default=None)
    torch_dtype: Optional[str] = field(default=None)
    attn_implementation: Optional[str] = field(default=None)


@dataclass
class AdditionalArguments:
    max_train_samples: Optional[int] = field(default=None)
    max_eval_samples: Optional[int] = field(default=None)


def parse_args():
    parser = HfArgumentParser(
        (DatasetArguments, ModelArguments, TrainingArguments, AdditionalArguments)
    )
    return parser.parse_args_into_dataclasses()


def _load_model_and_tokenizer(model_args, print_model=False):
    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
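    # GPT-2 has no dedicated padding token, so reuse the end-of-sequence token for padding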
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    extra_model_args = {}
    if model_args.attn_implementation is not None:
        extra_model_args["attn_implementation"] = model_args.attn_implementation
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        torch_dtype=torch_dtype,
        **extra_model_args,
    )
    if print_model:
        print(model)
    return model, tokenizer


def train(dataset_args, model_args, training_args, additional_args):
    set_wandb(training_args)
    print(f"Training/evaluation parameters {training_args}")
    train_dataset, eval_dataset = load_and_tokenize(
        tokenizer_model_name=model_args.model_name_or_path,
        do_eval=training_args.do_eval,
        **vars(dataset_args),
    )
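    # Streaming datasets have no known length, so default the sample caps to infinity
    # and only narrow them with take() when explicit limits are provided.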
    max_train_samples = float("inf")
    max_eval_samples = float("inf")
    if not dataset_args.dataset_streaming:
        max_train_samples = len(train_dataset)
        if training_args.do_eval:
            max_eval_samples = len(eval_dataset)
    if additional_args.max_train_samples is not None:
        max_train_samples = min(max_train_samples, additional_args.max_train_samples)
        train_dataset = train_dataset.take(max_train_samples)
    if additional_args.max_eval_samples is not None:
        assert training_args.do_eval, "Cannot set max_eval_samples without do_eval"
        max_eval_samples = min(max_eval_samples, additional_args.max_eval_samples)
        eval_dataset = eval_dataset.take(max_eval_samples)
    model, tokenizer = _load_model_and_tokenizer(model_args, print_model=True)
    metric = evaluate.load("accuracy")

    def preprocess_logits_for_metrics(logits, labels):
        if isinstance(logits, tuple):
            logits = logits[0]
        return logits.argmax(dim=-1)

    def compute_metrics(eval_preds):
        preds, labels = eval_preds
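        # Drop padded/ignored positions, then shift so each prediction is scored
        # against the following token (next-token accuracy for a causal LM)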
        mask = (labels != tokenizer.pad_token_id) & (labels != -100)
        labels = numpy.concatenate([label[mask[i]][1:] for i, label in enumerate(labels)])
        preds = numpy.concatenate([pred[mask[i]][:-1] for i, pred in enumerate(preds)])
        return metric.compute(predictions=preds, references=labels)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )
    train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
    trainer.save_model()
    metrics = train_result.metrics
    metrics["train_samples"] = max_train_samples
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
    if training_args.do_eval:
        metrics = trainer.evaluate()
        metrics["eval_samples"] = max_eval_samples
        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics["perplexity"] = perplexity
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


if __name__ == "__main__":
    train(*parse_args())
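
The helper utils/experiment_tracking.py is not reproduced here. Purely as an illustrative assumption, a minimal set_wandb could simply toggle the Trainer's wandb reporting based on whether a project is configured:

import os

def set_wandb(training_args):
    # Hypothetical sketch (not the actual FlexAI helper): report to wandb only
    # when WANDB_PROJECT is set, so runs without the env var still work.
    if os.environ.get("WANDB_PROJECT"):
        training_args.report_to = ["wandb"]
    else:
        training_args.report_to = []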

code/causal-language-modeling/requirements.txt

accelerate>=1.8.1
datasets>=2.21.0
evaluate>=0.4.3
scikit_learn>=1.5.2
transformers>=4.43.3
wandb>=0.18.1