PyTorch系列教程:高效保存和加载PyTorch模型
PyTorch是一个强大的深度学习库,广泛用于构建和训练神经网络。PyTorch的优点之一是它提供了简单有效的函数来保存和加载模型。这对于恢复训练、共享模型或部署模型进行推理等任务特别有用。
保存PyTorch模型
torch.save() 函数用于序列化模型并将其保存到磁盘。这个过程很简单,但是对torch.save()的特性有一个很好的理解将有助于你有效地管理你保存的模型。
保存PyTorch模型的通用语法包括两部分:模型的状态字典和推荐的文件格式,通常以.pt或.pth作为扩展名。
import torch
import torch.nn as nn
import torch.optim as optim
# Example model
eclass SampleModel(nn.Module):
def __init__(self):
super(SampleModel, self).__init__()
self.layer = nn.Linear(10, 2)
def forward(self, x):
return self.layer(x)
model = SampleModel()
# Define a random optimizer
optimizer = optim.SGD(model.parameters(), lr=0.01)
# Specify the path for saving the model
PATH = "./model.pth"
# Save the model's state_dict (recommended approach)
torch.save(model.state_dict(), PATH)
这里,state_dict是一个Python字典对象,它存储模型的所有参数和持久缓冲区(例如,运行任何批处理规范层的平均值)。
加载PyTorch模型
torch.load() 函数用于加载已保存的PyTorch模型。要记住的关键一点是,将state_dict 加载到模型中需要使用与保存时相同的体系结构实例化模型类。
# Initialize the model again
model = SampleModel()
# Load the weights from the saved file into the model's state_dict
model.load_state_dict(torch.load(PATH))
# Remember to set the model to evaluation mode after loading
model.eval()
在此代码片段中,将权重加载到模型中之后,调用model.eval()将模型设置为求值模式至关重要。如果模型包含诸如dropout层或批处理规范化层之类的层,这一点尤其重要,因为它们在评估和训练期间的行为不同。
保存完整模型
如果你不仅想保存state_dict,还想保存整个模型,包括体系结构,可以使用torch.save()。但是,一定要小心,因为保存整个模型会创建一个很大的文件大小,并引入依赖关系
# This saves the entire model including the architecture
torch.save(model, "./full_model.pth")
# To load the full model:
model = torch.load("./full_model.pth")
model.eval()
如果你计划重构代码并仍然访问模型权重,则通常不建议使用这种方法,因为加载将使用保存时的代码。与state_dict方法相比,它提供的灵活性较差。
最后总结
在PyTorch中保存和加载模型可以很简单,但需要了解何时使用状态字典而不是整个模型。当你希望干净地存储参数值并保持代码效率时,请使用state_dicts进行保存和加载。只有在必要的情况下保留体系结构时才选择保存整个模型,并认识到所涉及的权衡。无论你是在训练阶段之间进行转换,还是在进行推理部署,PyTorch 的序列化和反序列化都能为你提供所需的必备工具。