Day 37 训练
Day 37 训练
- 深度学习实战:PyTorch 模型训练、保存与早停策略
- 初始模型训练
- 过拟合的判断
- 模型的保存与加载
- 仅保存模型参数
- 保存模型 + 权重
- 保存训练状态(断点续训)
- 早停法(Early Stopping)
- 总结
深度学习实战:PyTorch 模型训练、保存与早停策略
在深度学习的征程中,模型的训练、保存与加载是至关重要的环节。本文将通过一个鸢尾花分类任务的实例,带您深入了解如何在 PyTorch 中实现这些关键步骤,同时探讨过拟合的判断与早停法的应用。
初始模型训练
我们从一个简单的多层感知机(MLP)模型开始,使用 PyTorch 对鸢尾花数据集进行分类训练。以下是训练过程的关键代码片段:
# 设置GPU设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")# 加载鸢尾花数据集并划分训练集与测试集
iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 数据归一化处理
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)# 定义MLP模型结构
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.fc1 = nn.Linear(4, 10)self.relu = nn.ReLU()self.fc2 = nn.Linear(10, 3)def forward(self, x):out = self.fc1(x)out = self.relu(out)out = self.fc2(out)return out# 实例化模型、定义损失函数与优化器
model = MLP().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练模型
num_epochs = 20000
losses = []
epochs = []start_time = time.time()with tqdm(total=num_epochs, desc="训练进度", unit="epoch") as pbar:for epoch in range(num_epochs):outputs = model(X_train)loss = criterion(outputs, y_train)optimizer.zero_grad()loss.backward()optimizer.step()if (epoch + 1) % 200 == 0:losses.append(loss.item())epochs.append(epoch + 1)pbar.set_postfix({'Loss': f'{loss.item():.4f}'})if (epoch + 1) % 1000 == 0:pbar.update(1000)time_all = time.time() - start_time
print(f'Training time: {time_all:.2f} seconds')
在训练过程中,我们使用了 tqdm 库来显示训练进度条,使训练过程更加直观。训练完成后,我们得到了一个训练时间为 6.15 秒,测试集准确率为 96.67% 的模型。
过拟合的判断
在训练过程中,我们可能会遇到过拟合的问题,即模型在训练集上表现良好,但在测试集上表现不佳。为了判断是否出现过拟合,我们可以在训练过程中同时记录训练集和测试集的损失值。
# 在训练循环中新增测试集损失计算
if (epoch + 1) % 200 == 0:model.eval()with torch.no_grad():test_outputs = model(X_test)test_loss = criterion(test_outputs, y_test)model.train()train_losses.append(train_loss.item())test_losses.append(test_loss.item())epochs.append(epoch + 1)pbar.set_postfix({'Train Loss': f'{train_loss.item():.4f}', 'Test Loss': f'{test_loss.item():.4f}'})
通过绘制训练集和测试集损失曲线,我们可以直观地观察两者的变化趋势。如果训练集损失持续下降,而测试集损失在某一时刻开始上升,就表明模型出现了过拟合。
模型的保存与加载
在实际应用中,我们通常需要保存训练好的模型,以便后续使用或继续训练。PyTorch 提供了多种保存模型的方式:
仅保存模型参数
这种方式只保存模型的权重参数,不保存模型结构代码。加载时需要提前定义与训练时一致的模型类。
# 保存模型参数
torch.save(model.state_dict(), "model_weights.pth")# 加载参数
model = MLP() # 初始化与训练时相同的模型结构
model.load_state_dict(torch.load("model_weights.pth"))
保存模型 + 权重
这种方式保存模型结构及参数,加载时无需提前定义模型类,但文件体积较大,且依赖训练时的代码环境。
# 保存整个模型
torch.save(model, "full_model.pth")# 加载模型
model = torch.load("full_model.pth")
保存训练状态(断点续训)
为了支持断点续训,我们可以保存模型参数、优化器状态、训练轮次、损失值等完整训练状态。
# 保存训练状态
checkpoint = {"model_state_dict": model.state_dict(),"optimizer_state_dict": optimizer.state_dict(),"epoch": epoch,"loss": best_loss,
}
torch.save(checkpoint, "checkpoint.pth")# 加载并续训
model = MLP()
optimizer = torch.optim.Adam(model.parameters())
checkpoint = torch.load("checkpoint.pth")model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1
早停法(Early Stopping)
早停法是一种有效的防止过拟合的策略。其核心思想是:在训练过程中监控验证集的指标,当指标不再改善时,提前终止训练,避免模型对训练集过度拟合。
以下是应用早停法的代码示例:
# 新增早停相关参数
best_test_loss = float('inf')
best_epoch = 0
patience = 50
counter = 0
early_stopped = False# 在训练循环中新增早停逻辑
if (epoch + 1) % 200 == 0:# ...(其他代码保持不变)if test_loss.item() < best_test_loss:best_test_loss = test_loss.item()best_epoch = epoch + 1counter = 0torch.save(model.state_dict(), 'best_model.pth')else:counter += 1if counter >= patience:print(f"早停触发!在第{epoch+1}轮,测试集损失已有{patience}轮未改善。")print(f"最佳测试集损失出现在第{best_epoch}轮,损失值为{best_test_loss:.4f}")early_stopped = Truebreak
在训练过程中,如果测试集损失连续 patience
轮未改善,就会触发早停机制,终止训练。这样可以有效避免模型过拟合,提高模型在测试集上的泛化能力。
总结
通过本文的实例,我们详细介绍了如何在 PyTorch 中进行模型训练、保存与加载,以及如何应用早停法来防止过拟合。这些技术在实际的深度学习项目中具有广泛的应用价值。掌握这些技能,将有助于您更好地构建和优化深度学习模型。
浙大疏锦行