Training jobs on shared clusters can be interrupted at any time, most commonly when they hit the wall-time limit (RCC: 36h, DSI: 12h). Checkpointing saves your training state so that when a job restarts, it picks up where it left off rather than starting over.
How SLURM preemption works¶
When SLURM needs to kill your job, it can be configured to send a warning signal first. Adding this line to your batch script tells SLURM to send SIGUSR2 to your process 60 seconds before the kill:
#SBATCH --signal=USR2@60
That 60-second window is your opportunity to save a checkpoint and requeue the job before SLURM terminates it.
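You can also test the preemption path without waiting for the time limit by sending the same signal to a running job yourself. A sketch (replace <JOB_ID> with your job's ID; depending on how the job launches its processes you may also need scancel's --batch flag so the signal reaches the batch step):
scancel --signal=USR2 <JOB_ID>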
The pattern¶
The standard approach, used by many training frameworks, has three components:
1. A preemption flag
A mutable container (a dict) that lives at module scope so a signal handler can modify it:
preemption_flag = dict(flag=False)
A plain bool variable won't work here: assigning to a name inside a function creates a new local variable (unless it is declared global) rather than rebinding the module-level one. A dict is mutable, so preemption_flag["flag"] = True modifies the existing object in place.
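A minimal standalone illustration of the difference (not part of the pattern itself):
flag_bool = False
flag_dict = dict(flag=False)

def try_set_bool():
    flag_bool = True          # rebinds a new local name; the module-level bool is untouched

def try_set_dict():
    flag_dict["flag"] = True  # mutates the existing module-level dict in place

try_set_bool()
try_set_dict()
print(flag_bool)          # False -- the assignment had no effect outside the function
print(flag_dict["flag"])  # True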
2. A signal handler registered before the training loop
import signal

def set_preemption_flag(signum, frame):
    print(f"Signal {signum} received — will checkpoint and requeue.")
    preemption_flag["flag"] = True

signal.signal(signal.SIGUSR2, set_preemption_flag)
This must be called before the training loop starts. When SLURM sends SIGUSR2, Python runs set_preemption_flag at the next bytecode boundary and then resumes, so the flag is set without crashing the process.
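To sanity-check the handler locally, without SLURM, you can deliver the signal to the current process yourself. A sketch (works on Linux/macOS):
import os
import signal
import time

os.kill(os.getpid(), signal.SIGUSR2)  # deliver SIGUSR2 to this process
time.sleep(0.1)                       # give the interpreter a moment to run the handler
print(preemption_flag["flag"])        # True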
3. A check at the end of each training step
if preemption_flag["flag"]:
    save_checkpoint(step)
    requeue()
Checking at the step boundary (not mid-step) means you always save a clean, consistent state.
Small example¶
The following self-contained script trains a small linear model and implements the full pattern:
# checkpointing_example.py
import os
import sys
import signal
import torch
import torch.nn as nn
from pathlib import Path
CHECKPOINT_PATH = Path("checkpoint.pt")
SAVE_EVERY = 10 # also save every N steps as insurance
# -- 1. Preemption flag --------------------------------------------------------
PREEMPTION_FLAG = dict(flag=False) # module variable # 1
# -- 2. A signal handler registered before the training loop -------------------
def set_preemption_flag(signum, frame):
    print(f"Signal {signum} received — will checkpoint and requeue.")
    PREEMPTION_FLAG["flag"] = True  # mutates the module-level dict

# Checkpoint and requeue helpers
def save_checkpoint(step, model, optimizer):
    torch.save(
        {"model": model.state_dict(), "optimizer": optimizer.state_dict(), "step": step},
        CHECKPOINT_PATH,
    )
    print(f"Checkpoint saved at step {step}")

def requeue():
    job_id = os.environ.get("SLURM_JOB_ID")
    if job_id:
        print(f"Requeueing job {job_id}")
        os.system(f"scontrol requeue {job_id}")
    sys.exit(0)

def main():
    """Main execution flow."""
    # -- 2. A signal handler registered before the training loop ---------------
    signal.signal(signal.SIGUSR2, set_preemption_flag)  # 2

    # Model and optimizer
    model = nn.Linear(16, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    start_step = 0

    # Resume checkpoint if one exists
    if CHECKPOINT_PATH.exists():  # 3
        ckpt = torch.load(CHECKPOINT_PATH)
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        start_step = ckpt["step"] + 1
        print(f"Resumed from step {start_step}")

    # Training loop
    for step in range(start_step, 200):
        x = torch.randn(32, 16)
        y = torch.randn(32, 1)
        loss = nn.functional.mse_loss(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % SAVE_EVERY == 0:  # 4
            save_checkpoint(step, model, optimizer)

        # -- 3. A check at the end of each training step -----------------------
        if PREEMPTION_FLAG["flag"]:  # 5
            save_checkpoint(step, model, optimizer)
            requeue()  # 6

    save_checkpoint(step, model, optimizer)
    print("Training complete.")

if __name__ == "__main__":
    main()
How the example works¶
(The numbered # comments in the script above mark each component described below.)
1. Flag dict - PREEMPTION_FLAG is a dict at module scope so the signal handler can mutate it. The signal handler itself is intentionally minimal: it just sets the flag and returns immediately, keeping signal-handler execution time short.
2. Signal registration - signal.signal(signal.SIGUSR2, set_preemption_flag) is called once before the loop. Python's signal handling is cooperative: the handler runs between bytecode instructions, so registering before the loop ensures no signal is missed.
3. Resuming at startup - the resume check runs every time the script starts: if a checkpoint file exists, it is loaded. This makes the script idempotent: running it fresh in a clean directory behaves normally, and every subsequent restart continues from the last saved step.
4. Periodic saves - SAVE_EVERY = 10 guards against unexpected hardware failures that don't send a signal (e.g., a node going down). Tune this based on how long a save takes vs. how much work you can afford to redo. (A hardening sketch for the save itself follows this list.)
5. Preemption check at step boundary - checking the flag after optimizer.step() ensures the checkpoint captures a consistent state (parameters and optimizer state updated together for the same step).
6. scontrol requeue - tells SLURM to re-add the job to the queue with the same job ID and resource request. When the job starts again, it hits the checkpoint-resume logic at startup and continues from where it left off. sys.exit(0) then cleanly terminates the current run.
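One caveat worth knowing: if the process is killed in the middle of torch.save, the checkpoint file can be left truncated. A common hardening step (not used in the example above, shown here only as a sketch) is to write to a temporary file and then atomically rename it over the real path:
import os
import torch

def save_checkpoint_atomic(step, model, optimizer, path="checkpoint.pt"):
    tmp_path = f"{path}.tmp"
    torch.save(
        {"model": model.state_dict(), "optimizer": optimizer.state_dict(), "step": step},
        tmp_path,
    )
    os.replace(tmp_path, path)  # atomic rename on the same filesystem; readers never see a partial file
    print(f"Checkpoint saved at step {step}")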
Batch script configuration¶
#!/bin/bash
#SBATCH --job-name=checkpointing_example
##SBATCH --account=<PI_ACCOUNT> # <-- change to an allowed account on your cluster - RCC CLUSTER ONLY (uncomment if needed)
#SBATCH --partition=<PARTITION> # <-- change to an allowed partition on your cluster
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=4
#SBATCH --mem=4G
#SBATCH --time=00:10:00
#SBATCH --output=/path/to/logs/%x_%j.out
#SBATCH --error=/path/to/logs/%x_%j.err
#SBATCH --signal=USR2@60 # send SIGUSR2 60s before wall-time kill
python checkpointing_example.py
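Depending on how the cluster is configured, two optional directives may also be worth adding to the script above (a sketch, not required by the example): --requeue marks the job as eligible for requeueing, and --open-mode=append keeps one continuous log across restarts instead of overwriting it.
#SBATCH --requeue                # allow the job to be requeued (scontrol requeue)
#SBATCH --open-mode=append       # append to the same .out/.err files after each restart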