PyTorch模型 train() 和 eval() 模式详解
在 PyTorch 中,train()
和 eval()
是模型两种关键的工作模式,正确使用它们对模型训练和推理至关重要。
一、两种模式的核心区别
模式 | 影响模块 | 主要用途 | Dropout | BatchNorm | 梯度计算 |
---|---|---|---|---|---|
train() |
所有训练相关层 | 模型训练 | 激活 | 使用batch统计 | 保留梯度 |
eval() |
所有训练相关层 | 模型验证/测试 | 关闭 | 使用运行统计 | 关闭梯度计算 |
二、典型使用场景
1. 训练阶段
model.train() # 设置为训练模式
for data, target in train_loader:
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
2. 验证/测试阶段
model.eval() # 设置为评估模式
with torch.no_grad(): # 关闭梯度计算
for data, target in val_loader:
output = model(data)
loss = criterion(output, target