当前位置: 首页 > news >正文

torch 中 model.eval() 和 model.train() 的作用

在 PyTorch 中,model.train() 的作用是将模型切换到训练模式(training mode),主要影响模型中某些特定层(如 Dropout 和 BatchNorm)的行为,使其在训练时启用随机性和动态统计量。以下是详细说明:

核心功能

  1. 启用随机性层:
    • Dropout:在训练模式下,按设定的概率随机丢弃神经元(防止过拟合)。
    • BatchNorm:使用当前 batch 的均值和方差进行归一化,并更新移动平均统计量(用于后续的评估模式)。
  2. 确保训练时的动态行为:训练模式下,模型的输出依赖于当前输入数据的随机性(如 Dropout)和动态统计量(如 BatchNorm),这对模型学习特征至关重要。

model.eval()的作用是将模型切换到评估模式(evaluation mode),主要影响模型中某些特定层(如Dropout和BatchNorm)的行为,使其在推理(测试)时表现一致且稳定。(实际上不使用dropout,model.eval()对这些不开dropout的大模型实际上没有影响。):

核心功能

  1. 关闭随机性层
    • Dropout:在训练时随机丢弃神经元以防止过拟合,但在评估模式下会保留所有神经元
    • BatchNorm:在训练时使用当前batch的均值和方差进行归一化,并更新移动平均统计量;在评估模式下,则使用训练阶段累积的全局均值和方差,而非当前batch的数据。
  2. 确保输出稳定性:评估模式下,模型的输出仅依赖训练好的参数,避免因随机性(如Dropout)或统计量波动(如BatchNorm)导致测试结果不稳定。

model.eval()为什么需要配合torch.no_grad()

model.eval()仅改变模型层的行为,而torch.no_grad()会禁用梯度计算,减少内存占用并加速推理。
通常在测试时同时使用两者:

推理阶段代码示例

model.eval()
with torch.no_grad():
    outputs = model(inputs)

Torch中模型训练评估模式演示代码

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# 定义模型(添加激活函数)
MyModel = nn.Sequential(
    nn.Linear(10, 20),
    nn.BatchNorm1d(20),# 训练时使用当前batch的统计量
    nn.ReLU(),
    nn.Dropout(0.5)     # 训练时随机丢弃50%的神经元
)

# 示例数据(假设是分类任务)
X_train = torch.randn(1000, 10)
y_train = torch.randint(0, 20, (1000,))
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=64)

X_test = torch.randn(200, 10)
y_test = torch.randint(0, 20, (200,))
test_dataset = TensorDataset(X_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=64)

# 定义损失函数和优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(MyModel.parameters(), lr=0.001)

model = MyModel

# 训练阶段
model.train()
for epoch in range(10):
    for data, targets in train_loader:
        optimizer.zero_grad()  # 重置梯度
        outputs = model(data)
        loss = loss_fn(outputs, targets)
        loss.backward()  #计算损失函数关于模型参数的梯度
        optimizer.step() #更新网络的权重和偏置等参数。


def calculate_accuracy(outputs, targets):
    preds = outputs.argmax(dim=1)
    correct = (preds == targets).sum().item()
    return correct / len(targets)

# 评估阶段
model.eval()
with torch.no_grad(): # 禁用梯度计算
    total_accuracy = 0
    for data, targets in test_loader:
        outputs = model(data)
        accuracy = calculate_accuracy(outputs, targets)
        total_accuracy += accuracy
    print(f"Test Accuracy: {total_accuracy / len(test_loader):.4f}")

相关文章:

  • 肿瘤检测新突破:用随机森林分类器助力医学诊断
  • 【JAVA架构师成长之路】【Redis】第18集:Redis实现分布式高并发加减计数器
  • 小程序事件系统 —— 33 事件传参 - data-*自定义数据
  • AI视频生成工具清单(附网址与免费说明)
  • 支持向量机的深度解析:从理论到C++实现
  • 栈概念和结构
  • Linux安装RabbitMQ
  • 01-简单几步!在Windows上用llama.cpp运行DeepSeek-R1模型
  • 计算机毕业设计SpringBoot+Vue.js制造装备物联及生产管理ERP系统(源码+文档+PPT+讲解)
  • FFmpeg-chapter7和chapter8-使用 FFmpeg 解码视频(原理篇和实站篇)
  • 【2024_CUMCM】图论模型
  • SwanLab简明教程:从萌新到高手
  • NO.30十六届蓝桥杯备战|C++输入输出|单组测试用例|多组测试用例|isalpha|逗号表达式(C++)
  • C语言-语法
  • vocabulary is from your listening,other speaking and your thought.
  • 如何借助人工智能AI模型开发一个类似OpenAI Operator的智能体实现电脑自动化操作?
  • 什么是美颜SDK?从几何变换到深度学习驱动的美颜算法详解
  • 免费AI图片生成工具推荐
  • 005-获取内存占用率
  • C运算符 对比a++、++a、b--、 --b
  • 做传销网站违法吗/北仑seo排名优化技术
  • 阿里免费做网站/友情链接交换平台
  • 做网站看什么书好/深圳网站设计专家乐云seo
  • 南昌网站开发技术/网页模板素材
  • 网站程序模板下载/自己如何制作一个网页
  • 澳门网站做推广违法吗/学网络运营需要多少钱