PyTorch 中 model.eval() 的使用与作用详解

文章目录
- 一、🚀`model.train()` 与 `model.eval()` 是什么?
- 二、为什么需要 `model.eval()`
- 三、`model.eval()` 与 `torch.no_grad()` 的区别
- 四、完整示例:对比 `train()` 和 `eval()`
- 五、与 `model.train()` 的区别总结
- 六、完整实战代码(训练 + 验证)
- 七、常见错误与避坑指南
- 八、小结
- 九、📚 参考资料
一、🚀model.train() 与 model.eval() 是什么?
使用 PyTorch 进行深度学习训练时,我们经常会看到如下的代码片段:
model.train()
# 训练阶段...model.eval()
# 验证或测试阶段...
很多初学者第一次看到时都会问:
“为什么要在测试前加一句
model.eval()?
不加行不行?到底起了什么作用?”
eval,英文意即为评估

在 PyTorch 中,每个神经网络模型都是一个 nn.Module 的子类。
而 nn.Module 中有两个非常重要的模式:
| 模式 | 含义 | 常用于 |
|---|---|---|
model.train() | 开启训练模式(默认) | 模型训练阶段 |
model.eval() | 开启评估模式 | 验证、测试阶段 |
👉 它们的区别不在于是否计算梯度,
而在于模型内部某些层(如 Dropout、BatchNorm)的行为发生变化。
二、为什么需要 model.eval()
神经网络中有些层在“训练”和“推理”阶段需要不同的行为,例如:
Dropout 层
- 在训练时,会随机“丢弃”一部分神经元(防止过拟合);
- 在测试时,则应该关闭 Dropout,让所有神经元都参与计算。
如果你不调用 model.eval(),
那在测试阶段 Dropout 仍然会随机丢弃神经元,导致结果不稳定、性能下降。
Batch Normalization 层(BN层)
- 在训练时,BatchNorm 会根据当前 mini-batch 的均值和方差进行标准化;
- 在测试时,应该使用在训练中统计到的“全局均值和方差”来规范化。
如果不切换到 eval 模式,
BN 层会继续更新统计信息,导致推理结果偏差甚至错误。
✅ 结论:
model.eval()的核心作用是让模型中某些层(Dropout、BatchNorm)进入“推理模式”。
三、model.eval() 与 torch.no_grad() 的区别
这两个经常一起出现,很多人容易混淆:
| 功能 | 是否影响 Dropout/BN | 是否停止计算梯度 | 使用场景 |
|---|---|---|---|
model.eval() | ✅ 是 | ❌ 否 | 切换模型状态(推理模式) |
torch.no_grad() | ❌ 否 | ✅ 是 | 禁止梯度计算,加快推理速度、节省显存 |
因此,推理时我们通常会这样写👇:
model.eval() # 切换为推理模式
with torch.no_grad(): # 不计算梯度outputs = model(inputs)
四、完整示例:对比 train() 和 eval()
让我们用一个小例子直观看看区别 👇
import torch
import torch.nn as nn# 一个简单的网络,包含 Dropout
class SimpleNet(nn.Module):def __init__(self):super().__init__()self.fc = nn.Linear(4, 4)self.dropout = nn.Dropout(p=0.5)def forward(self, x):return self.dropout(self.fc(x))# 创建模型和输入
x = torch.ones(4)
model = SimpleNet()# 训练模式
model.train()
print("Train Mode Output:")
for _ in range(3):print(model(x))# 推理模式
model.eval()
print("\nEval Mode Output:")
for _ in range(3):print(model(x))
输出对比
Train Mode Output:
tensor([-0.0000, -1.4387, 0.7793, 0.0000], grad_fn=<MulBackward0>)
tensor([-0.0000, -1.4387, 0.0000, 0.0000], grad_fn=<MulBackward0>)
tensor([-0.0000, -1.4387, 0.7793, 0.0000], grad_fn=<MulBackward0>)Eval Mode Output:
tensor([-0.2442, -0.7194, 0.3897, 0.9389], grad_fn=<ViewBackward0>)
tensor([-0.2442, -0.7194, 0.3897, 0.9389], grad_fn=<ViewBackward0>)
tensor([-0.2442, -0.7194, 0.3897, 0.9389], grad_fn=<ViewBackward0>)
✅ 说明:
- 训练模式下 Dropout 随机屏蔽神经元,因此每次输出不同;
- 推理模式下 Dropout 被关闭,输出稳定。
五、与 model.train() 的区别总结
| 比较项 | model.train() | model.eval() |
|---|---|---|
| 模型状态 | 训练模式 | 推理模式 |
| Dropout | 启用随机丢弃 | 关闭 |
| BatchNorm | 使用批次统计 | 使用全局统计 |
| 是否影响梯度 | ❌ 不影响 | ❌ 不影响 |
| 常用场景 | 模型训练阶段 | 验证、推理阶段 |
六、完整实战代码(训练 + 验证)
import torch
import torch.nn as nn
import torch.optim as optim# 定义简单模型
class Net(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(4, 10)self.relu = nn.ReLU()self.dropout = nn.Dropout(0.5)self.fc2 = nn.Linear(10, 3)def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.dropout(x)return self.fc2(x)model = Net()
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()for epoch in range(3):# ===== 训练阶段 =====model.train()optimizer.zero_grad()x = torch.randn(5, 4)y = torch.randint(0, 3, (5,))out = model(x)loss = criterion(out, y)loss.backward()optimizer.step()# ===== 验证阶段 =====model.eval()with torch.no_grad():val_x = torch.randn(5, 4)val_out = model(val_x)val_pred = val_out.argmax(dim=1)print(f"Epoch {epoch}: loss={loss.item():.4f}, val_pred={val_pred.tolist()}")
输出如下:
Epoch 0: loss=1.0044, val_pred=[1, 1, 1, 2, 2]
Epoch 1: loss=0.9953, val_pred=[2, 1, 2, 2, 2]
Epoch 2: loss=1.2143, val_pred=[2, 2, 1, 2, 1]
✅ 训练时:
- Dropout 启用;
- BatchNorm 统计更新。
✅ 验证时:
- Dropout 关闭;
- BatchNorm 使用训练统计参数。
七、常见错误与避坑指南
| 错误用法 | 后果 |
|---|---|
在测试时忘记 model.eval() | Dropout、BN 层仍随机,导致结果波动、不稳定 |
在推理时忘记 torch.no_grad() | 会记录梯度,浪费显存、速度变慢 |
在训练时调用了 model.eval() | 模型学不动,BN 不更新统计信息 |
忘记在训练开始前加 model.train() | 模型仍在推理模式,训练效果不佳 |
八、小结
| 项目 | 说明 |
|---|---|
| 函数名 | model.eval() |
| 所属模块 | torch.nn.Module |
| 作用 | 切换模型到评估(推理)模式 |
| 影响层 | Dropout、BatchNorm |
与 no_grad 区别 | eval() 控制模式,no_grad 控制梯度 |
| 使用场景 | 验证、测试、推理阶段 |
| 常用组合 | model.eval() + with torch.no_grad(): |
九、📚 参考资料
- PyTorch 官方文档:torch.nn.Module.eval()
- 快速学pytorch之评估模式:model.eval()-CSDN博客
- 预测时一定要记得model.eval()!
- pytorch中的model. train()和model. eval()到底做了什么?
