农产品信息网站建设方案百度风云榜
在 PyTorch 中,train()
和 eval()
是模型两种关键的工作模式,正确使用它们对模型训练和推理至关重要。
一、两种模式的核心区别
模式 | 影响模块 | 主要用途 | Dropout | BatchNorm | 梯度计算 |
---|---|---|---|---|---|
train() | 所有训练相关层 | 模型训练 | 激活 | 使用batch统计 | 保留梯度 |
eval() | 所有训练相关层 | 模型验证/测试 | 关闭 | 使用运行统计 | 关闭梯度计算 |
二、典型使用场景
1. 训练阶段
model.train() # 设置为训练模式for data, target in train_loader:optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()
2. 验证/测试阶段
model.eval() # 设置为评估模式with torch.no_grad(): # 关闭梯度计算for data, target in val_loader:output = model(data)loss = criterion(output, target