当前位置: 首页 > news >正文

一学就会的深度学习基础指令及操作步骤(3)模型训练验证

文章目录

    • 模型训练验证
      • 损失函数和优化器
      • 模型优化
      • 训练函数
      • 验证函数
      • 模型保存

模型训练验证

损失函数和优化器

loss_function = nn.CrossEntropyLoss() # 损失函数
optimizer = Adam(model.parameters())  # 优化器,优化参数

模型优化

获得模型所有的可训练参数(比如每一层的权重、偏置),设置优化器类型,自动调整学习步长(自适应学习率),后续训练更新参数。

# 雇佣Adam教练,让他管理模型参数
optimizer = Adam(model.parameters(), lr=0.001)  # lr是初始学习率
# 1. optimizer.zero_grad()    # 清空上一轮的成绩单
# 2. loss.backward()          # 计算每个参数要改进的方向(梯度)
# 3. optimizer.step()         # 参数调整

训练函数

def train():
    loss = 0
    accuracy = 0

    model.train()
    for x, y in train_loader:  # 获得每个batch数据
        x, y = x.to(device), y.to(device)
        output = model(x)        # 得到预测label
        optimizer.zero_grad()    # 梯度清零
        batch_loss = loss_function(output, y)  # 计算batch误差
        batch_loss.backward()    # 计算误差梯度
        optimizer.step()         # 调整模型参数

        loss += batch_loss.item()
        accuracy += get_batch_accuracy(output, y, train_N)
    print('Train - Loss: {:.4f} Accuracy: {:.4f}'.format(loss, accuracy))

验证函数

def validate():
    loss = 0
    accuracy = 0

    model.eval() # 评估模式,关闭随机性等增加稳定性
    with torch.no_grad(): # 禁用梯度,提高效率
        for x, y in valid_loader:
            x, y = x.to(device), y.to(device)
            output = model(x)  
            # 不用进行梯度计算、参数调整
            loss += loss_function(output, y).item()
            accuracy += get_batch_accuracy(output, y, valid_N)
    print('Valid - Loss: {:.4f} Accuracy: {:.4f}'.format(loss, accuracy))

模型保存

.pth 文件是PyTorch模型的“存档文件”,保存了所有必要信息。加载后,模型即可直接运行,无需重新训练!

# 保存整个模型(结构 + 参数)
torch.save(model, 'model.pth')

.pth 文件可以用https://netron.app/查看

相关文章:

  • FPGA|Verilog-自己写的SPI驱动
  • 【测试框架篇】单元测试框架pytest(4):assert断言详解
  • SpringBoot(1)——创建SpringBoot项目的方式
  • 【Vue3】详细探究 watch ref 数组不生效的问题
  • LeetCode 2380 二进制字符串重新安排顺序需要的时间
  • 无人机楼宇间物资运输技术详解
  • 【算法 C/C++】二维前缀和
  • 【密码学——基础理论与应用】李子臣编著 第三章 分组密码 课后习题
  • mysql的MGR
  • 在mac中设置环境变量
  • 校验pytorch是否支持显卡GPU 不支持卸载并安装支持版本
  • 报表控件stimulsoft操作:使用 Angular 应用程序的报告查看器组件
  • ngx_openssl_create_conf
  • Zookeeper实践指南
  • BI 工具响应慢?可能是 OLAP 层拖了后腿
  • 【报错】微信小程序预览报错”60001“
  • unity使用mesh 画图(1)
  • Spring 事务和事务传播机制
  • 接口测试笔记
  • C语言(23)
  • 日月谭天 | 赖清德倒行逆施“三宗罪”,让岛内民众怒不可遏
  • 习近平:坚持科学决策民主决策依法决策,高质量完成“十五五”规划编制工作
  • “80后”北大硕士罗婕履新甘肃宁县县委常委、组织部部长
  • 因救心梗同学缺席职教高考的姜昭鹏顺利完成补考
  • 朱雀二号改进型遥二运载火箭发射成功
  • “80后”萍乡市安源区区长邱伟,拟任县(区)委书记