In some cases you may want to train on a dataset that is too large to download or upload to FlexAI in a reasonable amount of time. Instead of spending that time on an up-front data transfer, you can stream the dataset, so that examples are fetched while the Training Job runs. This experiment demonstrates how to stream a large dataset during a Training Job on FlexAI using the streaming mode of the Hugging Face Datasets library.

Connect to GitHub (if needed)

If you haven’t already connected FlexAI to GitHub, you’ll need to set up a code registry connection:
flexai code-registry connect
This allows FlexAI to pull repositories directly from GitHub when you pass a repository URL (--repository-url, or -u for short) to training commands.

Running the Training Job streaming a dataset

Here is an example that uses the code/causal-language-modeling/train.py script to stream the FineWeb dataset, which is over 90 TB in size:
flexai training run gpt2training-stream --repository-url https://github.com/flexaihq/blueprints --requirements-path code/causal-language-modeling/requirements.txt \
    -- code/causal-language-modeling/train.py \
    --dataset_streaming true \
    --do_train \
    --eval_strategy no \
    --dataset_name HuggingFaceFW/fineweb \
    --dataset_config_name CC-MAIN-2024-10 \
    --dataset_group_text true \
    --dataloader_num_workers 8 \
    --max_steps 2500 \
    --model_name_or_path openai-community/gpt2 \
    --output_dir /output-checkpoint \
    --per_device_train_batch_size 8 \
    --logging_steps 50 \
    --save_steps 1000
The first line defines the main components of the Training Job in FlexAI:
  1. The Training Job's name (gpt2training-stream).
  2. The URL of the repository containing the training script (https://github.com/flexaihq/blueprints).
  3. The path to the requirements file to install before the script runs (code/causal-language-modeling/requirements.txt).
The command does not attach a FlexAI dataset: the training data is streamed from the Hugging Face Hub while the job runs.
The second line, after the -- separator, specifies the script that is executed when the Training Job starts (code/causal-language-modeling/train.py). The first argument passed to the script, --dataset_streaming true, tells it to load the dataset with the Datasets library's streaming mode enabled. The remaining lines are arguments passed to the training script to set the Training Job's hyperparameters or customize its behavior. For instance, --max_train_samples and --max_eval_samples cap the number of training and evaluation samples. Because a streamed dataset has no known length, the number of training steps cannot be derived from a number of epochs, which is why the command sets --max_steps explicitly.
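On the script side, parse_args hands these flags to HfArgumentParser, which converts the string true into a Python boolean for boolean dataclass fields. The sketch below imitates that conversion using only the standard library's argparse, assuming dataset_streaming is declared as a bool field in DatasetArguments. The str_to_bool and parse_streaming_args helpers are hypothetical and cover just three of the flags.

```python
import argparse


def str_to_bool(value: str) -> bool:
    # Accepts the same spellings HfArgumentParser accepts for boolean fields.
    lowered = value.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def parse_streaming_args(argv):
    parser = argparse.ArgumentParser()
    # nargs="?" with const=True lets a bare `--dataset_streaming` also mean True.
    parser.add_argument(
        "--dataset_streaming", type=str_to_bool, nargs="?", const=True, default=False
    )
    parser.add_argument("--dataset_name", type=str, default=None)
    parser.add_argument("--max_steps", type=int, default=-1)
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = parse_streaming_args(
        ["--dataset_streaming", "true", "--dataset_name", "HuggingFaceFW/fineweb", "--max_steps", "2500"]
    )
    print(args.dataset_streaming, args.dataset_name, args.max_steps)  # True HuggingFaceFW/fineweb 2500
```

With this convention, --dataset_streaming true, --dataset_streaming yes and a bare --dataset_streaming all enable streaming, while omitting the flag leaves it disabled.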

The code

In the code/causal-language-modeling/train.py script, the train function calls the load_and_tokenize helper function to load and tokenize the dataset using the user-provided arguments (the model and tokenizer are loaded later, by _load_model_and_tokenizer):
def train(dataset_args, model_args, training_args, additional_args):     # <--- 1. This is the function that will be called by the `flexai training run` command
    set_wandb(training_args)
    print(f"Training/evaluation parameters {training_args}")

    # Get dataset
    train_dataset, eval_dataset = load_and_tokenize(                     # <--- 2. Here the script calls the `load_and_tokenize` helper function
        tokenizer_model_name=model_args.model_name_or_path,
        do_eval=training_args.do_eval,
        **vars(dataset_args),                                            # <--- 3. Dataset arguments parsed from the script's command-line flags
    )
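The **vars(dataset_args) expression turns every field of the parsed DatasetArguments dataclass into a keyword argument of load_and_tokenize. The self-contained sketch below shows that mechanism with a hypothetical, reduced DatasetArgumentsSketch dataclass and a stub load_and_tokenize_stub function; the real classes define more fields, and the stub only returns what it receives instead of loading any data.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DatasetArgumentsSketch:
    # Hypothetical subset of the fields the real DatasetArguments may define.
    dataset_name: Optional[str] = None
    dataset_config_name: Optional[str] = None
    dataset_streaming: bool = False
    dataset_group_text: bool = False


def load_and_tokenize_stub(tokenizer_model_name: str, do_eval: bool, **dataset_kwargs):
    # Stand-in for load_and_tokenize: echoes its arguments instead of loading data.
    return {"tokenizer_model_name": tokenizer_model_name, "do_eval": do_eval, **dataset_kwargs}


dataset_args = DatasetArgumentsSketch(
    dataset_name="HuggingFaceFW/fineweb",
    dataset_config_name="CC-MAIN-2024-10",
    dataset_streaming=True,
    dataset_group_text=True,
)

# vars() exposes the dataclass instance's fields as a dict, which ** unpacks into keyword arguments.
received = load_and_tokenize_stub(
    tokenizer_model_name="openai-community/gpt2",
    do_eval=False,
    **vars(dataset_args),
)
print(received["dataset_streaming"])  # True
```

One consequence of this design is that every field of the dataclass is forwarded, so the receiving function must accept each field name (explicitly or through **kwargs); otherwise Python raises a TypeError for the unexpected keyword argument.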
The load_and_tokenize helper function in code/dataset/prepare_save_dataset.py is responsible for calling the Hugging Face Datasets library. It enables streaming simply by passing the value of dataset_streaming to the streaming argument of load_dataset:
def load_and_tokenize(
    dataset_name: str,
    # ...
    dataset_streaming: bool,
) -> Dict[str, Dataset]:
    # ...
    loaded_datasets = load_dataset(                                       # <--- 1. HuggingFace's Datasets library `load_dataset` function is called
        dataset_name,
        dataset_config_name,
        streaming=dataset_streaming,                                      # <--- 2. `True` when the script is run with --dataset_streaming true
    )
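Conceptually, streaming=True makes load_dataset return an iterable that fetches examples only as the training loop asks for them, instead of downloading every file up front. The standard-library sketch below imitates that behavior with a generator over hypothetical "remote" shards and counts how many shards are actually fetched. It is an analogy for illustration only, not the Datasets library's implementation; ShardFetcher and stream_examples are made-up names.

```python
from itertools import islice
from typing import Dict, Iterator, List

# Hypothetical "remote" shards; in a real run these would be files on the Hugging Face Hub.
REMOTE_SHARDS: List[List[str]] = [
    [f"shard{s}-doc{d}" for d in range(1000)] for s in range(50)
]


class ShardFetcher:
    """Counts how many shards were fetched, to show that streaming is lazy."""

    def __init__(self, shards: List[List[str]]):
        self.shards = shards
        self.fetched = 0

    def fetch(self, index: int) -> List[str]:
        self.fetched += 1  # In a real setting this would be a network download.
        return self.shards[index]


def stream_examples(fetcher: ShardFetcher) -> Iterator[Dict[str, str]]:
    # A generator: no shard is fetched until the consumer asks for the next example.
    for index in range(len(fetcher.shards)):
        for text in fetcher.fetch(index):
            yield {"text": text}


fetcher = ShardFetcher(REMOTE_SHARDS)
stream = stream_examples(fetcher)
print(fetcher.fetched)  # 0: creating the stream fetches nothing

first_examples = list(islice(stream, 1500))  # consume 1,500 examples
print(fetcher.fetched)  # 2: only the first two shards were needed
```

The cost of the transfer is therefore spread over the training run and limited to the data the job actually consumes.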
That is all that is needed to stream a dataset during a Training Job on FlexAI! Instead of downloading or uploading the full dataset before training starts, the job fetches examples from the Hugging Face Hub as the training loop consumes them, so even a dataset of this size can be used without a lengthy transfer step.

Code

code/causal-language-modeling/train.py

# Copyright (c) 2025 FlexAI
# This file is part of the FlexAI Experiments repository.
# SPDX-License-Identifier: MIT

import math
import os
import sys
from dataclasses import dataclass, field
from typing import Optional

import evaluate
import numpy
import torch
import transformers
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    default_data_collator,
)

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(SCRIPT_DIR))

from dataset.prepare_save_dataset import DatasetArguments, load_and_tokenize
from utils.experiment_tracking import set_wandb

transformers.logging.set_verbosity_info()


@dataclass
class ModelArguments:
    model_name_or_path: str = field(default=None)
    torch_dtype: Optional[str] = field(default=None)
    attn_implementation: Optional[str] = field(default=None)


@dataclass
class AdditionalArguments:
    max_train_samples: Optional[int] = field(default=None)
    max_eval_samples: Optional[int] = field(default=None)


def parse_args():
    parser = HfArgumentParser(
        (DatasetArguments, ModelArguments, TrainingArguments, AdditionalArguments)
    )
    return parser.parse_args_into_dataclasses()


def _load_model_and_tokenizer(model_args, print_model=False):
    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    extra_model_args = {}
    if model_args.attn_implementation is not None:
        extra_model_args["attn_implementation"] = model_args.attn_implementation
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        torch_dtype=torch_dtype,
        **extra_model_args,
    )
    if print_model:
        print(model)
    return model, tokenizer


def train(dataset_args, model_args, training_args, additional_args):
    set_wandb(training_args)
    print(f"Training/evaluation parameters {training_args}")
    train_dataset, eval_dataset = load_and_tokenize(
        tokenizer_model_name=model_args.model_name_or_path,
        do_eval=training_args.do_eval,
        **vars(dataset_args),
    )
    max_train_samples = float("inf")
    max_eval_samples = float("inf")
    if not dataset_args.dataset_streaming:
        max_train_samples = len(train_dataset)
        if training_args.do_eval:
            max_eval_samples = len(eval_dataset)
    if additional_args.max_train_samples is not None:
        max_train_samples = min(max_train_samples, additional_args.max_train_samples)
        train_dataset = train_dataset.take(max_train_samples)
    if additional_args.max_eval_samples is not None:
        assert training_args.do_eval, "Cannot set max_eval_samples without do_eval"
        max_eval_samples = min(max_eval_samples, additional_args.max_eval_samples)
        eval_dataset = eval_dataset.take(max_eval_samples)
    model, tokenizer = _load_model_and_tokenizer(model_args, print_model=True)
    metric = evaluate.load("accuracy")

    def preprocess_logits_for_metrics(logits, labels):
        if isinstance(logits, tuple):
            logits = logits[0]
        return logits.argmax(dim=-1)

    def compute_metrics(eval_preds):
        preds, labels = eval_preds
        mask = (labels != tokenizer.pad_token_id) & (labels != -100)
        labels = numpy.concatenate([label[mask[i]][1:] for i, label in enumerate(labels)])
        preds = numpy.concatenate([pred[mask[i]][:-1] for i, pred in enumerate(preds)])
        return metric.compute(predictions=preds, references=labels)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )
    train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
    trainer.save_model()
    metrics = train_result.metrics
    metrics["train_samples"] = max_train_samples
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
    if training_args.do_eval:
        metrics = trainer.evaluate()
        metrics["eval_samples"] = max_eval_samples
        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics["perplexity"] = perplexity
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


if __name__ == "__main__":
    train(*parse_args())
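In train, the max_train_samples and max_eval_samples bookkeeping starts from float("inf") because a streamed dataset has no length: len() is only called when dataset_streaming is false, and .take(n) caps the dataset in either mode. The standard-library sketch below mirrors that branch with a hypothetical StreamSketch class standing in for a streamed dataset, and also computes how many training sequences the example command consumes, assuming a single device and gradient_accumulation_steps left at its default of 1.

```python
from itertools import islice
from typing import Iterable, Iterator, Optional, Tuple, Union


class StreamSketch:
    """Hypothetical stand-in for a streamed dataset: iterable and takeable, but with no __len__."""

    def __init__(self, source: Iterable[int], limit: Optional[int] = None):
        self.source = source
        self.limit = limit

    def __iter__(self) -> Iterator[int]:
        # islice with a None stop yields every element of the source.
        return islice(self.source, self.limit)

    def take(self, n: int) -> "StreamSketch":
        # Lazily keep only the first n examples; nothing is read here.
        limit = n if self.limit is None else min(n, self.limit)
        return StreamSketch(self.source, limit)


def resolve_max_samples(
    dataset, streaming: bool, cap: Optional[int]
) -> Tuple[object, Union[int, float]]:
    # Mirrors train(): len() is only used when the dataset is not streamed.
    max_samples: Union[int, float] = float("inf") if streaming else len(dataset)
    if cap is not None:
        max_samples = min(max_samples, cap)
        dataset = dataset.take(max_samples)
    return dataset, max_samples


def sequences_consumed(max_steps: int, per_device_batch: int, num_devices: int = 1, grad_accum: int = 1) -> int:
    # Each optimizer step draws per_device_batch * grad_accum sequences on every device.
    return max_steps * per_device_batch * grad_accum * num_devices


stream = StreamSketch(range(10**9))  # defined lazily; nothing is materialized
try:
    len(stream)
except TypeError:
    print("len() is not available on a stream")

capped, max_train_samples = resolve_max_samples(stream, streaming=True, cap=1000)
print(max_train_samples, sum(1 for _ in capped))  # 1000 1000

_, uncapped = resolve_max_samples(stream, streaming=True, cap=None)
print(uncapped)  # inf

print(sequences_consumed(max_steps=2500, per_device_batch=8))  # 20000
```

This also explains why, in streaming mode without --max_train_samples, the script records train_samples as inf in its metrics: the true number of samples is not known in advance.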

code/causal-language-modeling/requirements.txt

accelerate>=1.8.1
datasets>=2.21.0
evaluate>=0.4.3
scikit_learn>=1.5.2
transformers>=4.43.3
wandb>=0.18.1