PyTorch模型保存方式
PyTorch提供两种主流模型保存方式和一种训练断点保存与恢复的方法。
1. 仅保存模型参数(推荐)
# 保存
torch.save(model.state_dict(), "model_params.pth") # 加载
new_model = TheModelClass()
new_model.load_state_dict(torch.load("model_params.pth"))
new_model.eval()
核心优势:
-
文件体积小(仅参数数据)
-
避免PyTorch版本兼容问题
-
支持跨模型结构迁移(需设置
strict=False
)
2. 保存完整模型对象
# 保存
torch.save(model, "full_model.pth") # 加载 loaded_model = torch.load("full_model.pth")
loaded_model.eval()
适用场景:
-
快速原型验证
-
模型结构包含动态逻辑(如自定义前向传播)
3. 训练断点保存与恢复
# 保存检查点
checkpoint = {'epoch': current_epoch,'model_state': model.state_dict(),'optimizer_state': optimizer.state_dict(),'loss': loss_value
}
torch.save(checkpoint, "checkpoint.tar")# 恢复训练
model = TheModelClass()
optimizer = torch.optim.Adam(model.parameters())
checkpoint = torch.load("checkpoint.tar")
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optimizer_state'])
model.train() # 保持训练模式
关键细节:
-
推荐使用
.tar
后缀区分普通参数文件 -
自动恢复学习率调度器等训练状态