# Copyright (c) 2025 FlexAI
# This file is part of the FlexAI Experiments repository.
# SPDX-License-Identifier: MIT
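"""Fine-tune a causal language model with the Hugging Face Trainer.

Command-line arguments are parsed into four dataclasses: DatasetArguments
(defined in dataset/prepare_save_dataset.py), ModelArguments,
TrainingArguments, and AdditionalArguments. An illustrative invocation
(dataset-specific flags omitted; they come from DatasetArguments):

    python <this script> \
        --model_name_or_path gpt2 \
        --torch_dtype bfloat16 \
        --output_dir ./output \
        --do_eval \
        --max_train_samples 10000
"""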
import math
import os
import sys
from dataclasses import dataclass, field
from typing import Optional
import evaluate
import numpy
import torch
import transformers
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    default_data_collator,
)
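
# Make the repository root importable so the sibling `dataset` and `utils`
# packages resolve regardless of the working directory.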
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(SCRIPT_DIR))
from dataset.prepare_save_dataset import DatasetArguments, load_and_tokenize
from utils.experiment_tracking import set_wandb
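
# Surface transformers' INFO-level logs (model init, checkpoint saves, etc.).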
transformers.logging.set_verbosity_info()
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default=None)
    torch_dtype: Optional[str] = field(default=None)
    attn_implementation: Optional[str] = field(default=None)
@dataclass
class AdditionalArguments:
    max_train_samples: Optional[int] = field(default=None)
    max_eval_samples: Optional[int] = field(default=None)

def parse_args():
    parser = HfArgumentParser(
        (DatasetArguments, ModelArguments, TrainingArguments, AdditionalArguments)
    )
    return parser.parse_args_into_dataclasses()

def _load_model_and_tokenizer(model_args, print_model=False):
    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
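    # Decoder-only checkpoints (e.g., GPT-2 style) often ship without a pad
    # token; reuse EOS so padding is well-defined during batching.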
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
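    # Resolve the requested dtype: "auto"/None are passed straight through to
    # from_pretrained(); anything else must name a torch dtype (e.g., "bfloat16").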
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
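    # Only forward attn_implementation when explicitly set, so the checkpoint's
    # default attention backend is kept otherwise. Typical values are "eager",
    # "sdpa", or "flash_attention_2" (the latter requires flash-attn installed).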
    extra_model_args = {}
    if model_args.attn_implementation is not None:
        extra_model_args["attn_implementation"] = model_args.attn_implementation
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        torch_dtype=torch_dtype,
        **extra_model_args,
    )
    if print_model:
        print(model)
    return model, tokenizer

def train(dataset_args, model_args, training_args, additional_args):
    set_wandb(training_args)
    print(f"Training/evaluation parameters {training_args}")
    train_dataset, eval_dataset = load_and_tokenize(
        tokenizer_model_name=model_args.model_name_or_path,
        do_eval=training_args.do_eval,
        **vars(dataset_args),
    )
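    # Streaming datasets expose no __len__, so start from an unbounded sample
    # count and only tighten it when an explicit cap is provided below.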
    max_train_samples = float("inf")
    max_eval_samples = float("inf")
    if not dataset_args.dataset_streaming:
        max_train_samples = len(train_dataset)
        if training_args.do_eval:
            max_eval_samples = len(eval_dataset)
    if additional_args.max_train_samples is not None:
        max_train_samples = min(max_train_samples, additional_args.max_train_samples)
        train_dataset = train_dataset.take(max_train_samples)
    if additional_args.max_eval_samples is not None:
        assert training_args.do_eval, "Cannot set max_eval_samples without do_eval"
        max_eval_samples = min(max_eval_samples, additional_args.max_eval_samples)
        eval_dataset = eval_dataset.take(max_eval_samples)
    model, tokenizer = _load_model_and_tokenizer(model_args, print_model=True)
    metric = evaluate.load("accuracy")
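
    # The two helpers below are nested so they can close over `tokenizer` and
    # `metric`. Reducing logits to argmax token ids up front keeps the Trainer
    # from accumulating full vocabulary-sized logit tensors during evaluation.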
    def preprocess_logits_for_metrics(logits, labels):
        if isinstance(logits, tuple):
            logits = logits[0]
        return logits.argmax(dim=-1)
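
    # Token-level accuracy: mask out padding and ignored (-100) positions, then
    # shift so the prediction at position i is scored against the label at i+1.
    # Because pad_token_id == eos_token_id here, genuine EOS tokens are also
    # excluded from the score.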
    def compute_metrics(eval_preds):
        preds, labels = eval_preds
        mask = (labels != tokenizer.pad_token_id) & (labels != -100)
        labels = numpy.concatenate([label[mask[i]][1:] for i, label in enumerate(labels)])
        preds = numpy.concatenate([pred[mask[i]][:-1] for i, pred in enumerate(preds)])
        return metric.compute(predictions=preds, references=labels)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )
    train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
    trainer.save_model()
    metrics = train_result.metrics
    # May be float("inf") when streaming without an explicit sample cap.
    metrics["train_samples"] = max_train_samples
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
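    # Final held-out evaluation; perplexity is exp(eval_loss), guarded against
    # overflow for untrained or diverged models.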
    if training_args.do_eval:
        metrics = trainer.evaluate()
        metrics["eval_samples"] = max_eval_samples
        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics["perplexity"] = perplexity
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

if __name__ == "__main__":
    train(*parse_args())