当前位置: 首页 > news >正文

Day 37 训练

Day 37 训练

  • 深度学习实战:PyTorch 模型训练、保存与早停策略
    • 初始模型训练
    • 过拟合的判断
    • 模型的保存与加载
      • 仅保存模型参数
      • 保存模型 + 权重
      • 保存训练状态(断点续训)
    • 早停法(Early Stopping)
    • 总结


深度学习实战:PyTorch 模型训练、保存与早停策略

在深度学习的征程中,模型的训练、保存与加载是至关重要的环节。本文将通过一个鸢尾花分类任务的实例,带您深入了解如何在 PyTorch 中实现这些关键步骤,同时探讨过拟合的判断与早停法的应用。

初始模型训练

我们从一个简单的多层感知机(MLP)模型开始,使用 PyTorch 对鸢尾花数据集进行分类训练。以下是训练过程的关键代码片段:

# 设置GPU设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")# 加载鸢尾花数据集并划分训练集与测试集
iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 数据归一化处理
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)# 定义MLP模型结构
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.fc1 = nn.Linear(4, 10)self.relu = nn.ReLU()self.fc2 = nn.Linear(10, 3)def forward(self, x):out = self.fc1(x)out = self.relu(out)out = self.fc2(out)return out# 实例化模型、定义损失函数与优化器
model = MLP().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练模型
num_epochs = 20000
losses = []
epochs = []start_time = time.time()with tqdm(total=num_epochs, desc="训练进度", unit="epoch") as pbar:for epoch in range(num_epochs):outputs = model(X_train)loss = criterion(outputs, y_train)optimizer.zero_grad()loss.backward()optimizer.step()if (epoch + 1) % 200 == 0:losses.append(loss.item())epochs.append(epoch + 1)pbar.set_postfix({'Loss': f'{loss.item():.4f}'})if (epoch + 1) % 1000 == 0:pbar.update(1000)time_all = time.time() - start_time
print(f'Training time: {time_all:.2f} seconds')

在训练过程中,我们使用了 tqdm 库来显示训练进度条,使训练过程更加直观。训练完成后,我们得到了一个训练时间为 6.15 秒,测试集准确率为 96.67% 的模型。

过拟合的判断

在训练过程中,我们可能会遇到过拟合的问题,即模型在训练集上表现良好,但在测试集上表现不佳。为了判断是否出现过拟合,我们可以在训练过程中同时记录训练集和测试集的损失值。

# 在训练循环中新增测试集损失计算
if (epoch + 1) % 200 == 0:model.eval()with torch.no_grad():test_outputs = model(X_test)test_loss = criterion(test_outputs, y_test)model.train()train_losses.append(train_loss.item())test_losses.append(test_loss.item())epochs.append(epoch + 1)pbar.set_postfix({'Train Loss': f'{train_loss.item():.4f}', 'Test Loss': f'{test_loss.item():.4f}'})

通过绘制训练集和测试集损失曲线,我们可以直观地观察两者的变化趋势。如果训练集损失持续下降,而测试集损失在某一时刻开始上升,就表明模型出现了过拟合。

模型的保存与加载

在实际应用中,我们通常需要保存训练好的模型,以便后续使用或继续训练。PyTorch 提供了多种保存模型的方式:

仅保存模型参数

这种方式只保存模型的权重参数,不保存模型结构代码。加载时需要提前定义与训练时一致的模型类。

# 保存模型参数
torch.save(model.state_dict(), "model_weights.pth")# 加载参数
model = MLP()  # 初始化与训练时相同的模型结构
model.load_state_dict(torch.load("model_weights.pth"))

保存模型 + 权重

这种方式保存模型结构及参数,加载时无需提前定义模型类,但文件体积较大,且依赖训练时的代码环境。

# 保存整个模型
torch.save(model, "full_model.pth")# 加载模型
model = torch.load("full_model.pth")

保存训练状态(断点续训)

为了支持断点续训,我们可以保存模型参数、优化器状态、训练轮次、损失值等完整训练状态。

# 保存训练状态
checkpoint = {"model_state_dict": model.state_dict(),"optimizer_state_dict": optimizer.state_dict(),"epoch": epoch,"loss": best_loss,
}
torch.save(checkpoint, "checkpoint.pth")# 加载并续训
model = MLP()
optimizer = torch.optim.Adam(model.parameters())
checkpoint = torch.load("checkpoint.pth")model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1

早停法(Early Stopping)

早停法是一种有效的防止过拟合的策略。其核心思想是:在训练过程中监控验证集的指标,当指标不再改善时,提前终止训练,避免模型对训练集过度拟合。

以下是应用早停法的代码示例:

# 新增早停相关参数
best_test_loss = float('inf')
best_epoch = 0
patience = 50
counter = 0
early_stopped = False# 在训练循环中新增早停逻辑
if (epoch + 1) % 200 == 0:# ...(其他代码保持不变)if test_loss.item() < best_test_loss:best_test_loss = test_loss.item()best_epoch = epoch + 1counter = 0torch.save(model.state_dict(), 'best_model.pth')else:counter += 1if counter >= patience:print(f"早停触发!在第{epoch+1}轮,测试集损失已有{patience}轮未改善。")print(f"最佳测试集损失出现在第{best_epoch}轮,损失值为{best_test_loss:.4f}")early_stopped = Truebreak

在训练过程中,如果测试集损失连续 patience 轮未改善,就会触发早停机制,终止训练。这样可以有效避免模型过拟合,提高模型在测试集上的泛化能力。

总结

通过本文的实例,我们详细介绍了如何在 PyTorch 中进行模型训练、保存与加载,以及如何应用早停法来防止过拟合。这些技术在实际的深度学习项目中具有广泛的应用价值。掌握这些技能,将有助于您更好地构建和优化深度学习模型。
浙大疏锦行

相关文章:

  • 01 Ubuntu20.04下编译QEMU8.2.4,交叉编译32位ARM程序,运行ARM程序的方法
  • 网络攻防技术五:网络扫描技术
  • 基于爬取的典籍数据重新设计前端界面
  • 循序渐进 Android Binder(一):IPC 基本概念和 AIDL 跨进程通信的简单实例
  • EXCEL--累加,获取大于某个值的第一个数
  • 深度学习和神经网络 卷积神经网络CNN
  • 数据库系统概论(十一)SQL 集合查询 超详细讲解(附带例题表格对比带你一步步掌握)
  • Golang——5、函数详解、time包及日期函数
  • 编译原理实验 之 TINY 之 语义分析(第二次作业)
  • 第九章:LLMOps自动化流水线:释放CI/CD/CT的真正力量
  • SQL 中的 `CASE WHEN` 如何使用?
  • AI书签管理工具开发全记录(九):用户端页面集成与展示
  • 排序算法——详解
  • 4.大语言模型预备数学知识
  • 【iOS(swift)笔记-13】App版本不升级时本地数据库sqlite更新逻辑一
  • 企业展示型网站模板HTML5网站模板下载指南
  • PostgreSQL 在生物信息学中的应用
  • Java并发编程实战 Day 4:线程间通信机制
  • 网络节点排查
  • cpper 转 Golang
  • 怎么做网站的301/谷歌搜索引擎怎么才能用
  • 温岭高端网站设计哪家好/网站seo诊断分析
  • 福建省住房建设厅网站6/郑州网站建设哪家好
  • 域名和主机搭建好了怎么做网站/网站建立
  • 外包一个企业网站多少钱/宁德市医院东侨院区
  • 含山县城市建设有限公司网站/企业门户网站模板