In this blog post, Best Practices for Loading and Saving PyTorch Weights in Production, we map out the practical ways to persist and restore your models without surprises. Whether you build models or manage teams shipping them, understanding how PyTorch saves weights is essential to reproducibility, speed, and safety.
At a high level, PyTorch models are just Python objects with tensors inside. You rarely want to serialize the whole object; you want the tensors that matter. This post explains the technology behind saving those tensors, shows safe patterns for training and inference, and highlights pitfalls that cause painful production bugs.
How PyTorch stores weights
PyTorch models expose a state_dict(), a simple mapping of parameter names to tensors (and buffers like running means). This dict is what you should save and load. Under the hood, torch.save and torch.load use Python pickle to serialize and deserialize. That’s powerful but means:
- Never load checkpoints from untrusted sources. Pickle can execute arbitrary code.
- Prefer saving state_dicts or lightweight checkpoints, not entire model objects.
- For extra safety in newer PyTorch, use weights_only=True when loading compatible files.
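To make this concrete, here is a minimal sketch that prints the contents of a state_dict for a toy module; any nn.Module exposes the same interface:
import torch
import torch.nn as nn

# A toy module; real models work identically
layer = nn.Linear(4, 2)

# state_dict() maps parameter/buffer names to tensors
for name, tensor in layer.state_dict().items():
    print(name, tuple(tensor.shape))
# weight (2, 4)
# bias (2,)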
The basic patterns you should use
1) Save only what you need
Save a compact checkpoint containing model weights and training state. This keeps files portable and easy to resume.
import torch
# Example: model, optimizer, and (optional) AMP scaler
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    # Save if you use mixed precision
    "scaler": scaler.state_dict() if "scaler" in globals() else None,
    "epoch": epoch,
    "metrics": {"val_loss": val_loss, "val_acc": val_acc},
    "pytorch_version": torch.__version__,
}
torch.save(checkpoint, "last.pth")
# For best model snapshot
torch.save(model.state_dict(), "best.pth")
Why two files? last.pth helps you resume training; best.pth freezes the best-performing weights for deployment.
2) Load safely on any device
Always map the loaded tensors to the device you plan to use, and use safe loading where available.
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
path = "last.pth"
# Safe-ish load (PyTorch >= 2.x supports weights_only). Falls back if older.
try:
    ckpt = torch.load(path, map_location=device, weights_only=True)
except TypeError:
    # Older PyTorch without the weights_only argument
    ckpt = torch.load(path, map_location=device)
model.load_state_dict(ckpt["model"]) # strict by default
optimizer.load_state_dict(ckpt["optimizer"]) # optional if only inferring
start_epoch = int(ckpt.get("epoch", -1)) + 1
print(f"Resuming at epoch {start_epoch}")
If you are loading best.pth (weights only), do:
state = torch.load("best.pth", map_location=device)
model.load_state_dict(state)
Inference-only loading
For production inference you don’t need the optimizer or scaler. Keep it light and deterministic:
model.load_state_dict(torch.load("best.pth", map_location=device))
model.to(device)
model.eval()
with torch.no_grad():
    outputs = model(inputs.to(device))
Saving the best vs the latest
Track validation metrics and capture two artifacts:
- Latest checkpoint for resume (last.pth).
- Best weights for deployment (best.pth).
if val_loss < best_val_loss:
    best_val_loss = val_loss
    torch.save(model.state_dict(), "best.pth")
torch.save({
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "epoch": epoch,
}, "last.pth")
Transfer learning and partial loads
When your architecture changes (e.g., you swap a classification head), load what matches and ignore the rest:
state = torch.load("pretrained_backbone.pth", map_location="cpu")
# Load non-strictly
incompat = model.load_state_dict(state, strict=False)
print("Missing:", incompat.missing_keys)
print("Unexpected:", incompat.unexpected_keys)
strict=False is perfect for transfer learning: shared layers load weights; new layers start fresh.
Distributed and DataParallel quirks
If you saved weights from nn.DataParallel or some DistributedDataParallel setups, parameter names may have a "module." prefix. Strip it when loading into a non-wrapped model:
from collections import OrderedDict
raw_state = torch.load("ddp_model.pth", map_location="cpu")
new_state = OrderedDict()
for k, v in raw_state.items():
    new_key = k.replace("module.", "", 1) if k.startswith("module.") else k
    new_state[new_key] = v
model.load_state_dict(new_state, strict=True)
Tip: when saving under DDP, save model.module.state_dict() from rank 0 to avoid prefixes and duplicates.
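A minimal sketch of that rank-0 save, assuming torch.distributed is already initialized and ddp_model wraps your network (the filename is illustrative):
import torch
import torch.distributed as dist

# Assumes dist.init_process_group(...) has already run and
# ddp_model = nn.parallel.DistributedDataParallel(model)
if dist.get_rank() == 0:
    # .module unwraps DDP, so keys carry no "module." prefix
    torch.save(ddp_model.module.state_dict(), "ddp_clean.pth")
dist.barrier()  # keep other ranks in sync with the save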
Mixed precision and schedulers
If you train with AMP, also save the GradScaler. If you use a learning-rate scheduler, save its state too.
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scaler": scaler.state_dict(),  # if AMP
    "scheduler": scheduler.state_dict(),  # if used
    "epoch": epoch,
}
torch.save(checkpoint, "last.pth")
# Loading
ckpt = torch.load("last.pth", map_location=device)
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
if ckpt.get("scaler") is not None:
scaler.load_state_dict(ckpt["scaler"])
if ckpt.get("scheduler") is not None:
scheduler.load_state_dict(ckpt["scheduler"])
Device portability tips
- map_location lets you load GPU-saved checkpoints on CPU-only hosts.
- For fully portable files, you can save CPU tensors explicitly:
cpu_state = {k: v.cpu() for k, v in model.state_dict().items()}
torch.save(cpu_state, "model_cpu.pth")
- Move to the right device after loading: model.to(device).
Security and compatibility notes
- Security: Only load from trusted sources. If using PyTorch 2.x and you saved only weights, prefer torch.load(..., weights_only=True).
- Don’t save full objects: Avoid torch.save(model). It ties you to your Python code structure and increases risk when loading.
- Versioning: Record your PyTorch and CUDA versions in the checkpoint metadata; it helps when reproducing environments.
- Alternatives: For zero-pickle formats, consider the safetensors ecosystem if your stack supports it; a sketch follows below.
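A minimal safetensors sketch, assuming the safetensors package is installed and the model has no tied/shared weights; save_file expects a flat name-to-tensor dict, so store the weights alone rather than a full training checkpoint:
import torch
from safetensors.torch import save_file, load_file

# Save: a flat name -> tensor mapping, no pickle involved
state = {k: v.contiguous() for k, v in model.state_dict().items()}
save_file(state, "model.safetensors")

# Load: returns plain tensors; safe even for untrusted files
state = load_file("model.safetensors", device="cpu")
model.load_state_dict(state)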
Common mistakes and how to avoid them
- “Size mismatch” on load: Architecture changed. Use strict=False and update heads, or align layer names/shapes.
- Slower or unstable resumed training: You forgot optimizer/scheduler states. Save and load them with the model.
- CUDA errors on CPU hosts: You omitted map_location. Always set it when loading on a different device.
- Odd parameter names with module.: Strip the prefix when switching away from DDP/DataParallel.
- Non-deterministic results: Fix random seeds and keep library versions consistent when validating reproducibility (a minimal seeding sketch follows this list).
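Here is one such seeding sketch; the helper name is ours, and torch.use_deterministic_algorithms is an opt-in that can warn or error on ops without deterministic implementations, so treat that line as an assumption about your model:
import random
import numpy as np
import torch

def seed_everything(seed: int = 42) -> None:
    # Seed every RNG the training pipeline touches
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)           # also seeds CUDA generators
    torch.cuda.manual_seed_all(seed)  # explicit, for multi-GPU
    # Optional strictness: surface nondeterministic ops instead of ignoring them
    torch.use_deterministic_algorithms(True, warn_only=True)

seed_everything(42)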
Minimal end-to-end skeleton
import torch, torch.nn as nn, torch.optim as optim
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10)
        )

    def forward(self, x):
        return self.net(x)
model = Net().to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
# Train loop (toy)
for epoch in range(10):
    model.train()
    # Placeholder forward pass so backward() has a graph to walk
    loss = model(torch.randn(8, 32, device=device)).mean()
    optimizer.zero_grad(); loss.backward(); optimizer.step()
    # Validate and save
    val_loss = 0.1  # placeholder
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
    }, "last.pth")
# Inference
state = torch.load("last.pth", map_location=device)
model.load_state_dict(state["model"])
model.eval()
with torch.no_grad():
    preds = model(torch.randn(1, 32, device=device))
Copy-and-paste checklist
- Save model.state_dict() and a compact training state as needed.
- Keep a separate best.pth for deployment.
- Use map_location on load; call model.to(device) after.
- Resume training by loading optimizer, scheduler, scaler, and epoch.
- Transfer learning: strict=False, review missing/unexpected keys.
- DDP: save from rank 0; strip module. when needed.
- Security: avoid torch.save(model); prefer weights, consider weights_only=True.
- Record versions and key metrics in the checkpoint.
Handled well, loading and saving PyTorch weights is boring—in the best way. Use these patterns to make your experiments repeatable, your deployments reliable, and your handovers painless.