Pytorch笔记一之 cpu模型保存、加载与推理
Pytorch笔记一之 cpu模型保存、加载与推理
1.保存模型
首先,在加载模型之前,我们需要了解如何保存模型。PyTorch 提供了两种保存模型的方法:保存整个模型和仅保存模型的状态字典(state dict)。推荐使用第二种方式,因为它更灵活且体积较小。
import torch
import torch.nn as nn# 定义一个简单的神经网络
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc = nn.Linear(10, 2)def forward(self, x):return self.fc(x)# 实例化模型并进行训练
model = SimpleNN()
# 模型训练过程(省略)# 保存模型的状态字典
torch.save(model.state_dict(), 'simple_nn.pth')
2. 加载模型
一旦你保存了模型,接下来就可以加载它。在加载过程中,确保模型的架构与训练时一致。以下是加载模型的步骤:
- 1.创建一个模型实例
- 2.调用 load_state_dict() 方法加载状态字典
代码示例如下:
# 重新定义模型架构
model = SimpleNN()# 加载模型状态字典
model.load_state_dict(torch.load('simple_nn.pth', map_location=torch.device('cpu')))
3. 在 CPU 上进行推理
完成模型加载后,接下来就可以使用模型进行推理。以下是一个简单的示例:
# 模拟输入数据
input_data = torch.randn(1, 10)# 在 CPU 上进行推理
with torch.no_grad(): # 禁用梯度计算,节省内存output = model(input_data)print(output)