
How to use Managed Checkpoints

FlexAI Managed Checkpoints are a powerful feature that automatically saves and manages checkpoints during a Training Job’s execution.

This feature is designed to simplify saving and retrieving a model's state during training. It lets you resume training from a specific point to evaluate the model's performance, restart a Training Job from a Checkpoint taken before an interruption, or roll back to a previous state if needed.

FlexAI Managed Checkpoints automatically captures a checkpoint every time the Training Job's code calls the torch.save() function, which PyTorch-based training scripts commonly use to save a model's state.

To use FlexAI Managed Checkpoints, torch.save() must write checkpoints to a dedicated directory whose path is stored in the FLEXAI_OUTPUT_CHECKPOINT_DIR environment variable of the Training Runtime.

Whenever a new Checkpoint is stored in the directory that the FLEXAI_OUTPUT_CHECKPOINT_DIR environment variable points to, FlexAI Managed Checkpoints automatically captures it and assigns it a unique ID along with the timestamp of its creation.
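As an illustration, a save step that writes into the directory FlexAI watches might look like the minimal sketch below, assuming a standard PyTorch model and optimizer. The helper names, the file naming scheme (checkpoint-<step>.pt), the dictionary keys and the local "checkpoints" fallback are hypothetical choices, not FlexAI requirements; only the FLEXAI_OUTPUT_CHECKPOINT_DIR variable and the torch.save() call come from this guide.

```python
import os
from pathlib import Path


def output_checkpoint_dir() -> Path:
    # FLEXAI_OUTPUT_CHECKPOINT_DIR is set by the FlexAI Training Runtime.
    # The "checkpoints" fallback is a hypothetical default for local runs.
    return Path(os.environ.get("FLEXAI_OUTPUT_CHECKPOINT_DIR", "checkpoints"))


def checkpoint_path(step: int) -> Path:
    # Hypothetical naming scheme: one file per training step.
    return output_checkpoint_dir() / f"checkpoint-{step}.pt"


def save_checkpoint(model, optimizer, step: int) -> Path:
    # PyTorch is imported here so the path helpers above work without it.
    import torch

    path = checkpoint_path(step)
    path.parent.mkdir(parents=True, exist_ok=True)
    # Files written to this directory are what FlexAI captures as Checkpoints.
    # The dictionary keys below are hypothetical; any picklable layout works.
    torch.save(
        {
            "step": step,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        path,
    )
    return path
```

Inside the training loop, calling save_checkpoint(model, optimizer, step) every N steps (N being whatever interval suits the job) produces one Checkpoint per call.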

A Training Job’s Checkpoints can be listed using the flexai training checkpoints command.

Checkpoints can be managed using the flexai checkpoint family of commands, which allow you to fetch them to a host machine, export them to a remote storage location, inspect them to get further details, and delete them.

When a Training Job is run using the -C / --checkpoint flag, FlexAI Managed Checkpoints will automatically load the specified Checkpoint into the path the FLEXAI_INPUT_CHECKPOINT_DIR environment variable points to. This allows your training code to resume execution from the state captured in the Checkpoint.
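On the loading side, training code can check FLEXAI_INPUT_CHECKPOINT_DIR at startup and restore state only when a Checkpoint was provided. The minimal sketch below assumes the loaded Checkpoint contains a .pt file saved as a dictionary with the hypothetical keys step, model_state_dict and optimizer_state_dict; the function names and the "most recently modified file wins" rule are illustrative assumptions, not documented FlexAI behavior.

```python
import os
from pathlib import Path
from typing import Optional


def find_input_checkpoint(pattern: str = "*.pt") -> Optional[Path]:
    # FLEXAI_INPUT_CHECKPOINT_DIR points at the Checkpoint loaded via -C / --checkpoint.
    # If it is unset, missing or empty, the job starts from scratch.
    raw = os.environ.get("FLEXAI_INPUT_CHECKPOINT_DIR")
    if not raw:
        return None
    directory = Path(raw)
    if not directory.is_dir():
        return None
    # Assumption: if several files match, use the most recently modified one.
    candidates = sorted(directory.rglob(pattern), key=lambda p: p.stat().st_mtime)
    return candidates[-1] if candidates else None


def resume(model, optimizer) -> int:
    """Restore state if a Checkpoint was provided; return the step to resume from."""
    path = find_input_checkpoint()
    if path is None:
        return 0
    # PyTorch is imported lazily so find_input_checkpoint works without it.
    import torch

    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model_state_dict"])
    optimizer.load_state_dict(state["optimizer_state_dict"])
    return state["step"] + 1
```

Because resume() returns 0 when no Checkpoint is present, the same script works both for fresh runs and for runs started with -C / --checkpoint.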

Currently, the FlexAI runtime supports Hugging Face Transformers checkpoints. These include the trainer_state.json and config.json files, which contain metadata about the training process and the model configuration. The following Checkpoint details are derived from them:

  • STEP, TRAIN LOSS & EVAL LOSS: Extracted from trainer_state.json’s log_history field (last entry).
  • MODEL: Determined from config.json’s architectures field.
  • VERSION: Retrieved from config.json’s transformers_version field.
  • INFERENCE READY: Set to true if the architectures field is present in config.json.
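To make the mapping concrete, the minimal sketch below applies the rules listed above to a local checkpoint directory. It illustrates the documented behavior and is not FlexAI's implementation. The function name checkpoint_metadata, the choice of the first entry in architectures for MODEL, and the loss key names ("loss" and "eval_loss", which follow the Transformers Trainer's logging conventions) are assumptions.

```python
import json
from pathlib import Path
from typing import Any, Dict


def checkpoint_metadata(checkpoint_dir: str) -> Dict[str, Any]:
    """Sketch of the documented metadata rules for a Transformers checkpoint."""
    root = Path(checkpoint_dir)
    state = json.loads((root / "trainer_state.json").read_text())
    config = json.loads((root / "config.json").read_text())

    # STEP, TRAIN LOSS and EVAL LOSS come from the last log_history entry.
    # Either loss key may be absent from that entry, in which case it is None.
    history = state.get("log_history") or [{}]
    last = history[-1]

    architectures = config.get("architectures")
    return {
        "STEP": last.get("step"),
        "TRAIN LOSS": last.get("loss"),
        "EVAL LOSS": last.get("eval_loss"),
        # Assumption: MODEL shows the first listed architecture.
        "MODEL": architectures[0] if architectures else None,
        "VERSION": config.get("transformers_version"),
        # INFERENCE READY is true whenever config.json has an architectures field.
        "INFERENCE READY": "architectures" in config,
    }
```

Note that only the last log_history entry is consulted, so a checkpoint whose final log line is a training step may report no EVAL LOSS even if evaluation ran earlier.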