
Checkpointing on HPC Clusters

Training jobs on shared clusters are stopped when they hit the wall-time limit (RCC: 36h, DSI: 12h), and they can also be cut short by preemption or node failures. Checkpointing saves your training state so that when a job restarts, it picks up where it left off rather than starting over.

How SLURM preemption works

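SLURM can be configured to send your job a warning signal before it kills the job at the wall-time limit. Adding this line to your batch script tells SLURM to send SIGUSR2 about 60 seconds before the time limit (because of how SLURM schedules events, the signal may arrive up to 60 seconds earlier than requested):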

#SBATCH --signal=USR2@60

That 60-second window is your opportunity to save a checkpoint and requeue the job before SLURM terminates it.
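You can check that a handler receives SIGUSR2 without submitting a job. The minimal sketch below (names are illustrative) registers a handler and then has the process signal itself with signal.raise_signal() (Python 3.8+), standing in for SLURM. From a shell, kill -USR2 <pid> delivers the same signal from outside.

```python
import signal

received = []  # signal numbers seen by the handler


def on_usr2(signum, frame):
    # Just record the signal; real code would set a preemption flag here.
    received.append(signum)


signal.signal(signal.SIGUSR2, on_usr2)  # must be called from the main thread

# Stand-in for SLURM: send SIGUSR2 to this very process.
signal.raise_signal(signal.SIGUSR2)

print(received == [signal.SIGUSR2])  # True: the handler ran and the process survived
```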

The pattern

The pattern, similar to how some training frameworks handle SLURM signals, has three components:

1. A preemption flag

A mutable container (a dict) that lives at module scope so a signal handler can modify it:

preemption_flag = dict(flag=False)

A plain bool would not work without a global declaration: assigning to a name inside a function creates a new local variable rather than rebinding the module-level one. A dict sidesteps this, because preemption_flag["flag"] = True mutates the existing object instead of rebinding the name.
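The scoping difference is easy to see with ordinary functions, no signals involved. In this sketch the function and variable names are hypothetical; it contrasts plain assignment, dict mutation, and the global declaration that would be the other way to make a bool work.

```python
flag_bool = False
flag_dict = dict(flag=False)


def handler_with_bool():
    flag_bool = True  # creates a NEW local name; the module-level one is untouched
    return flag_bool


def handler_with_dict():
    flag_dict["flag"] = True  # mutates the existing module-level dict in place


def handler_with_global():
    global flag_bool  # explicitly rebind the module-level name instead
    flag_bool = True


handler_with_bool()
handler_with_dict()
bool_after_local = flag_bool             # still False
dict_after_mutation = flag_dict["flag"]  # True

handler_with_global()
print(bool_after_local, dict_after_mutation, flag_bool)  # False True True
```

Both the dict and global work; the dict keeps the handler free of scoping declarations.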

2. A signal handler registered before the training loop

import signal

def set_preemption_flag(signum, frame):
    print(f"Signal {signum} received — will checkpoint and requeue.")
    preemption_flag["flag"] = True

signal.signal(signal.SIGUSR2, set_preemption_flag)

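Register the handler before the training loop starts, and from the main thread. Until a handler is registered, SIGUSR2 keeps its default action, which terminates the process. When SLURM sends SIGUSR2, Python's low-level handler only records that the signal arrived; set_preemption_flag then runs in the main thread at the next bytecode boundary (after any long-running C call, such as a large tensor operation, returns), and execution resumes normally. The flag is set without crashing the process.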
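The main-thread rule matters if your training code runs inside a worker thread (for example, under a hypothetical launcher that starts training in a thread): CPython raises ValueError when signal.signal() is called anywhere else. A quick check, with illustrative names:

```python
import signal
import threading

errors = []  # error messages captured from the worker thread


def noop_handler(signum, frame):
    pass


def register_from_worker():
    try:
        signal.signal(signal.SIGUSR2, noop_handler)
    except ValueError as exc:
        # CPython only allows signal.signal() in the main thread.
        errors.append(str(exc))


worker = threading.Thread(target=register_from_worker)
worker.start()
worker.join()
print(errors)  # one message; exact wording varies by Python version
```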

3. A check at the end of each training step

if preemption_flag["flag"]:
    save_checkpoint(step)
    requeue()

Checking at the step boundary (not mid-step) means you always save a clean, consistent state.
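The step-boundary check can be traced in a toy loop with no model. In this sketch, save_checkpoint and requeue are hypothetical stand-ins that only log events, and signal.raise_signal() simulates SLURM's warning arriving partway through step 3. The handler sets the flag immediately, but the rest of step 3 still finishes before the checkpoint is taken.

```python
import signal

preemption_flag = dict(flag=False)
events = []  # ordered log of what happened


def set_preemption_flag(signum, frame):
    preemption_flag["flag"] = True


def save_checkpoint(step):  # hypothetical stand-in: only logs the step
    events.append(("save", step))


def requeue():  # hypothetical stand-in: no SLURM call
    events.append(("requeue",))


signal.signal(signal.SIGUSR2, set_preemption_flag)

for step in range(10):
    if step == 3:
        # Simulate the warning arriving in the middle of step 3.
        signal.raise_signal(signal.SIGUSR2)
    events.append(("step_done", step))  # the rest of the step still completes
    if preemption_flag["flag"]:          # checked only at the step boundary
        save_checkpoint(step)
        requeue()
        break  # the real requeue() would exit the process instead

print(events)
```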

Small example

The following self-contained script trains a small linear model and implements the full pattern:

checkpointing_example.py
# checkpointing_example.py
import os
import subprocess
import sys
import signal
import torch
import torch.nn as nn
from pathlib import Path

CHECKPOINT_PATH = Path("checkpoint.pt")
TOTAL_STEPS = 200
SAVE_EVERY = 10  # also save every N steps as insurance

# -- 1. Preemption flag --------------------------------------------------------
PREEMPTION_FLAG = dict(flag=False)  # module variable  # 1

# -- 2. Signal handler (registered in main() before the training loop) ---------
def set_preemption_flag(signum, frame):
    print(f"Signal {signum} received, will checkpoint and requeue.", flush=True)
    PREEMPTION_FLAG["flag"] = True  # mutates the module-level dict; no `global` needed

# Checkpoint and requeue helpers
def save_checkpoint(step, model, optimizer):
    # Write to a temporary file first, then atomically swap it in, so a kill
    # in the middle of a save cannot leave a truncated checkpoint.pt behind.
    tmp_path = CHECKPOINT_PATH.with_name(CHECKPOINT_PATH.name + ".tmp")
    torch.save(
        {"model": model.state_dict(), "optimizer": optimizer.state_dict(), "step": step},
        tmp_path,
    )
    os.replace(tmp_path, CHECKPOINT_PATH)
    print(f"Checkpoint saved at step {step}", flush=True)

def requeue():
    job_id = os.environ.get("SLURM_JOB_ID")
    if job_id:  # outside SLURM (e.g. a local test run) we just exit
        print(f"Requeueing job {job_id}", flush=True)
        # Put the job back in the queue with the same job ID and resources.
        subprocess.run(["scontrol", "requeue", job_id], check=False)
    sys.exit(0)

def main():
    """Main execution flow."""

    # -- 2. Register the signal handler before the training loop --------------
    signal.signal(signal.SIGUSR2, set_preemption_flag)  # 2

    # Model and optimizer
    model = nn.Linear(16, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    start_step = 0

    # Resume from the checkpoint if one exists
    if CHECKPOINT_PATH.exists():  # 3
        ckpt = torch.load(CHECKPOINT_PATH)
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        start_step = ckpt["step"] + 1
        print(f"Resumed from step {start_step}")

    # Training loop
    for step in range(start_step, TOTAL_STEPS):
        x = torch.randn(32, 16)
        y = torch.randn(32, 1)
        loss = nn.functional.mse_loss(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % SAVE_EVERY == 0:  # 4
            save_checkpoint(step, model, optimizer)

        # -- 3. A check at the end of each training step -----------------------
        if PREEMPTION_FLAG["flag"]:  # 5
            save_checkpoint(step, model, optimizer)
            requeue()  # 6

    # Final save. Use TOTAL_STEPS - 1 rather than `step`: if the run resumed from
    # an already-finished checkpoint, the loop body never ran and `step` is unbound.
    save_checkpoint(TOTAL_STEPS - 1, model, optimizer)
    print("Training complete.")

if __name__ == "__main__":
    main()

How the example works

(The numbers below match the # 1 to # 6 comments in the script above.)

  1. Flag dict - PREEMPTION_FLAG is a dict at module scope so the signal handler can mutate it. The handler itself is intentionally minimal: it sets the flag and returns immediately, keeping the time spent inside the handler short.

  2. Signal registration - signal.signal(SIGUSR2, set_preemption_flag) is called once, in the main thread, before the loop. Until it runs, SIGUSR2 has its default action, which terminates the process, so registering up front ensures that a warning arriving during training sets the flag instead of killing the job. The Python-level handler runs between bytecode instructions, so it fires promptly without corrupting the step in progress.

  3. Resuming at startup - the resume check runs on every start: if checkpoint.pt exists, training resumes from it; otherwise it starts at step 0. Running the script in a clean directory therefore behaves normally, and every restart continues from the step after the last save.

  4. Periodic saves - SAVE_EVERY = 10 guards against failures that don't send a signal (e.g., a node going down). Tune it by weighing how long a save takes against how much work you can afford to redo.

  5. Preemption check at step boundary - checking the flag after optimizer.step() ensures the checkpoint captures a consistent state: parameters and optimizer state updated together for the same step.

  6. scontrol requeue - this tells SLURM to put the job back in the queue with the same job ID and resource request (the job must be eligible for requeueing; see sbatch --requeue). When the job starts again, it runs the resume logic at startup and continues from where it left off. sys.exit(0) then cleanly ends the current run; outside SLURM, where SLURM_JOB_ID is unset, requeue() simply exits.
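The resume, checkpoint, and requeue cycle from items 3, 5, and 6 can be exercised without a GPU or SLURM. This is a pure-Python simulation under stated assumptions: a JSON file stands in for torch.save, a counter stands in for the model, and the hypothetical preempt_at argument stands in for the preemption flag. It also writes the checkpoint atomically (temporary file plus os.replace), a common safeguard against a kill landing mid-write.

```python
import json
import os
import tempfile
from pathlib import Path


def save_checkpoint(path, state):
    """Write state atomically: temp file in the same directory, then os.replace()."""
    fd, tmp = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
    with os.fdopen(fd, "w") as f:
        json.dump(state, f)
    os.replace(tmp, path)  # atomic rename on POSIX filesystems


def run_job(path, total_steps, preempt_at=None):
    """One simulated job attempt; returns the list of steps it executed."""
    state = {"step": -1, "weight": 0.0}
    if path.exists():                    # resume check on every start
        state = json.loads(path.read_text())
    executed = []
    for step in range(state["step"] + 1, total_steps):
        state["weight"] += 1.0           # stand-in for optimizer.step()
        state["step"] = step
        executed.append(step)
        if step == preempt_at:           # stand-in for the preemption flag
            save_checkpoint(path, state)
            return executed              # stand-in for requeue() + sys.exit(0)
    save_checkpoint(path, state)         # final save
    return executed


with tempfile.TemporaryDirectory() as d:
    ckpt = Path(d) / "checkpoint.json"
    first = run_job(ckpt, total_steps=6, preempt_at=2)  # "preempted" after step 2
    second = run_job(ckpt, total_steps=6)                # requeued attempt resumes
    third = run_job(ckpt, total_steps=6)                 # rerun after completion
    final = json.loads(ckpt.read_text())

print(first, second, third, final)
```

The second attempt picks up at step 3 with the saved weight, and the third call shows that rerunning a finished job is harmless: no steps execute and the checkpoint is unchanged.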

Batch script configuration

checkpointing_example.sbatch
#!/bin/bash
#SBATCH --job-name=checkpointing_example
##SBATCH --account=<PI_ACCOUNT>    # <-- change to an allowed account on your cluster - RCC CLUSTER ONLY (uncomment if needed)
#SBATCH --partition=<PARTITION>    # <-- change to an allowed partition on your cluster
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=4
#SBATCH --mem=4G
#SBATCH --time=00:10:00
#SBATCH --output=/path/to/logs/%x_%j.out
#SBATCH --error=/path/to/logs/%x_%j.err
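#SBATCH --signal=USR2@60        # send SIGUSR2 ~60s before the wall-time limit
#SBATCH --requeue               # allow scontrol requeue to put this job back in the queue
#SBATCH --open-mode=append      # keep earlier output when the requeued job restarts

# Launch with srun: without the B: prefix, --signal is delivered to job steps,
# not to processes started directly by the batch shell.
srun python checkpointing_example.py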
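Because a requeued job keeps the same job ID, output from successive attempts can be hard to tell apart. SLURM exports SLURM_RESTART_COUNT to jobs that have been restarted or requeued (it is absent on the first run). A small helper like the hypothetical describe_attempt below, not part of the example above, could be printed at startup; check your cluster's sbatch documentation to confirm the variable is set there.

```python
import os


def describe_attempt(env=None):
    """Summarize which attempt of a (possibly requeued) SLURM job this is."""
    env = os.environ if env is None else env
    job_id = env.get("SLURM_JOB_ID", "<not under SLURM>")
    # SLURM_RESTART_COUNT is only present after a restart/requeue, so default to 0.
    restarts = int(env.get("SLURM_RESTART_COUNT", "0"))
    return f"job {job_id}, attempt {restarts + 1}"


print(describe_attempt({"SLURM_JOB_ID": "12345"}))
print(describe_attempt({"SLURM_JOB_ID": "12345", "SLURM_RESTART_COUNT": "2"}))
```

In the real script you would call describe_attempt() with no argument so it reads os.environ.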