MNIST Data Parallel (DP)
Data Parallel
Reposted from my personal blog: https://shar-pen.github.io/2025/05/04/torch-distributed-series/2.MNIST_DP/
Data parallelism vs. model parallelism
- Data parallelism: the model is copied to each device, and the data is split/chunked along the batch dimension.
  - The module is replicated on each device, and each replica handles a portion of the input.
  - During the backwards pass, gradients from each replica are summed into the original module.
- Model parallelism: the data is copied to each device, and the model is split/chunked (obviously for the case where the model does not fit on a single GPU).
Put plainly, DP splits the usual batch across the GPUs and runs the forward and backward passes under the control of the main GPU.
Official documentation: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
Parameters:
- module: the torch.nn.Module to wrap with DP, i.e., your model.
- device_ids=None: the GPUs that take part in training, e.g. device_ids=gpus.
- output_device=None: the GPU on which the outputs (and gradients) are gathered, e.g. output_device=gpus[0].
- dim=0: the dimension along which the data is split, normally the first (batch) dimension; with dim=0, a [30, xxx] batch becomes [10, ...], [10, ...], [10, ...] on 3 GPUs (see the sketch after this list).
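To make the dim=0 split concrete, here is a small CPU-only illustration with torch.chunk; it is not DataParallel itself, just the same kind of split that scatter performs across the devices.

```python
import torch

batch = torch.randn(30, 8)        # a batch of 30 samples
shards = batch.chunk(3, dim=0)    # split along dim=0, one shard per GPU
print([s.shape for s in shards])  # [torch.Size([10, 8]), torch.Size([10, 8]), torch.Size([10, 8])]
```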
The parallelized module must have its parameters and buffers on device_ids[0] before running (forward/backward) this DataParallel module. In other words, the model's parameters must first be placed on the first GPU in the given list.

If you only call nn.DataParallel without model.to(device), or if device is not the first GPU in device_ids, you get RuntimeError: module must have its parameters and buffers on device cuda:4 (device_ids[0]) but found one of them on device: cpu. In other words: the model was not moved to the correct GPU first. A minimal sketch of the two orderings follows.
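A minimal sketch of the wrong and the right ordering, using a toy nn.Linear and assuming GPUs 4 and 5 (as in the error message above) exist:

```python
import torch
import torch.nn as nn

# Wrong: wrap first while the parameters are still on the CPU.
bad = nn.DataParallel(nn.Linear(8, 2), device_ids=[4, 5])
# bad(torch.randn(4, 8).to("cuda:4"))  # -> RuntimeError: module must have its
#                                      #    parameters and buffers on device cuda:4 (device_ids[0]) ...

# Right: move the module to device_ids[0] first, then wrap it.
good = nn.DataParallel(nn.Linear(8, 2).to("cuda:4"), device_ids=[4, 5])
out = good(torch.randn(4, 8).to("cuda:4"))  # forward works as expected
```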
```python
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25)
        )
        self.classifier = nn.Sequential(
            nn.Linear(9216, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 10)
        )

    def forward(self, input):
        x = self.features(input)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        output = F.log_softmax(x, dim=1)
        # Print the shapes seen inside forward, mainly to watch the batch size
        print(f"[Inside]: input shape: {input.size()}, label shape: {output.size()}")
        return output
```
Single-GPU forward
```python
device = torch.device(f"cuda:4" if torch.cuda.is_available() else "cpu")

model = ConvNet()
model = model.to(device)

for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    print(f"[Outside]: input shape: {data.size()}, label shape: {target.size()}")
    output = model(data)
    break
```
DP forward
```python
CUDA_DEVICE_IDS = [4, 5]
device = torch.device(f"cuda:{CUDA_DEVICE_IDS[0]}" if torch.cuda.is_available() else "cpu")

model = ConvNet()
model = model.to(device)
model = nn.DataParallel(model, device_ids=CUDA_DEVICE_IDS)

for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    output = model(data)
    print(f"[Outside]: input shape: {data.size()}, label shape: {target.size()}, output shape: {output.size()}")
    break
```
Forward shape comparison (single GPU vs. DP on 2 GPUs)
Single GPU

```
[Outside]: input shape: torch.Size([512, 1, 28, 28]), label shape: torch.Size([512])
[Inside]: input shape: torch.Size([512, 1, 28, 28]), label shape: torch.Size([512, 10])
```

DP, 2 GPUs

```
[Inside]: input shape: torch.Size([256, 1, 28, 28]), label shape: torch.Size([256, 10])
[Inside]: input shape: torch.Size([256, 1, 28, 28]), label shape: torch.Size([256, 10])
[Outside]: input shape: torch.Size([512, 1, 28, 28]), label shape: torch.Size([512]), output shape: torch.Size([512, 10])
```
With DP, the batch size seen in the training loop is unchanged, but inside the model's forward the batch is divided by len(device_ids): the original batch of 512 becomes 256, and each of the two GPUs runs its own model replica on 1 / len(device_ids) of the data.
Single GPU

```
+-----------------------------------------+------------------------+----------------------+
| 4 NVIDIA GeForce RTX 4090 On | 00000000:81:00.0 Off | Off |
| 46% 33C P8 14W / 450W | 719MiB / 24564MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 5 NVIDIA GeForce RTX 4090 On | 00000000:A1:00.0 Off | Off |
| 43% 32C P8 20W / 450W | 4MiB / 24564MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
```

DP, 2 GPUs

```
+-----------------------------------------+------------------------+----------------------+
| 4 NVIDIA GeForce RTX 4090 On | 00000000:81:00.0 Off | Off |
| 45% 34C P2 55W / 450W | 717MiB / 24564MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 5 NVIDIA GeForce RTX 4090 On | 00000000:A1:00.0 Off | Off |
| 43% 33C P2 57W / 450W | 719MiB / 24564MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
```
Each GPU holds one copy of the model parameters and batch / len(device_ids) of the data.
DP training
In practice, the only addition is model = nn.DataParallel(model, device_ids=CUDA_DEVICE_IDS); the subsequent forward and backpropagation code does not need to change.
```python
CUDA_DEVICE_IDS = [4, 5]
DEVICE = torch.device(f"cuda:{CUDA_DEVICE_IDS[0]}" if torch.cuda.is_available() else "cpu")

model = ConvNet()
model = model.to(DEVICE)
model = nn.DataParallel(model, device_ids=CUDA_DEVICE_IDS)
optimizer = optim.Adam(model.parameters(), lr=LR)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=LR_DECAY_STEP_NUM, gamma=LR_DECAY_FACTOR)

start_time = time()  # Record the start time
for epoch in range(EPOCHS):
    epoch_start_time = time()  # Record the start time of the current epoch
    print(f'Epoch {epoch}/{EPOCHS}')
    print(f'Learning Rate: {scheduler.get_last_lr()[0]}')
    train(model, DEVICE, train_loader, optimizer)
    test(model, DEVICE, test_loader)
    scheduler.step()
    epoch_end_time = time()  # Record the end time of the current epoch
    print(f"Time for epoch {epoch}: {epoch_end_time - epoch_start_time:.2f} seconds")
end_time = time()  # Record the end time

print(f"Total training time: {end_time - start_time:.2f} seconds")
```
A look inside the forward function of nn.DataParallel: a few lines show the overall flow.
```python
inputs, module_kwargs = self.scatter(inputs, kwargs, self.device_ids)   # scatter the input data
replicas = self.replicate(self.module, self.device_ids[: len(inputs)])  # replicate the model
outputs = self.parallel_apply(replicas, inputs, module_kwargs)          # compute in parallel
return self.gather(outputs, self.output_device)                         # gather the results
```
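For illustration, here is a minimal sketch of the same four steps using the functional helpers exported by torch.nn.parallel, assuming (as above) that GPUs 4 and 5 are available and that ConvNet is defined as earlier:

```python
import torch
from torch.nn.parallel import replicate, scatter, parallel_apply, gather

device_ids = [4, 5]
net = ConvNet().to(f"cuda:{device_ids[0]}")            # replica source must sit on device_ids[0]
data = torch.randn(512, 1, 28, 28, device=f"cuda:{device_ids[0]}")

inputs = scatter(data, device_ids)                     # split the batch along dim 0
replicas = replicate(net, device_ids[: len(inputs)])   # copy the model onto each GPU
outputs = parallel_apply(replicas, inputs)             # run each replica on its own shard
output = gather(outputs, device_ids[0])                # collect the outputs on the main GPU
print(output.shape)                                    # torch.Size([512, 10])
```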
The model outputs come from multiple worker GPUs but are gathered on the main GPU, because the default behavior of torch.nn.DataParallel is to gather the outputs of all worker GPUs back onto the main GPU (device_ids[0]).

Why gather? Because the labels (targets) usually sit on the main GPU, so to compute the loss the outputs must also be gathered onto the main GPU, where they can be matched against the labels. Note: computing the loss requires both tensors to be on the same device.

loss.backward() triggers autograd, which, following the gather structure, automatically scatters grad_output back to the worker GPUs; each worker GPU runs backward on its own outputs. In the end, the gradients from all GPUs still have to be combined for a single consistent update: aggregate the gradients and then send them back down, i.e. an all-reduce.
A classic programming framework for implementing DP is the "parameter server": the compute GPUs are called Workers and **the gradient-aggregation GPU is called the Server.** In practice, to minimize communication, one Worker is usually chosen to double as the Server; for example, all gradients can be sent to GPU 0 for aggregation. DP's communication bottleneck is the Server's communication overhead: the Server cannot receive all the data at once, so a Worker that has finished computing but cannot yet transmit its gradients simply sits idle.
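As a rough, CPU-only sketch of the parameter-server pattern described above (purely conceptual, not DataParallel's actual implementation): the Server sums the gradients pushed by all Workers and sends the result back, which is logically the same outcome as an all-reduce.

```python
import torch

def server_aggregate(worker_grads):
    # Server: reduce (sum) the gradients pushed by every worker ...
    reduced = torch.stack(worker_grads).sum(dim=0)
    # ... then send the same summed gradient back to each worker.
    return [reduced.clone() for _ in worker_grads]

# Each worker computed a gradient for the same parameter on its own data shard.
worker_grads = [torch.randn(4) for _ in range(3)]
synced = server_aggregate(worker_grads)
assert all(torch.equal(g, synced[0]) for g in synced)  # every worker now holds the same gradient
```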
!!! Note: DP is no longer recommended; PyTorch advises using DistributedDataParallel instead.
Full code
```python
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision import datasets, transforms
from time import time
import argparse


class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25)
        )
        self.classifier = nn.Sequential(
            nn.Linear(9216, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        output = F.log_softmax(x, dim=1)
        return output


def arg_parser():
    parser = argparse.ArgumentParser(description="MNIST Training Script")
    parser.add_argument("--epochs", type=int, default=5, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=512, help="Batch size for training")
    parser.add_argument("--lr", type=float, default=0.0005, help="Learning rate")
    parser.add_argument("--lr_decay_step_num", type=int, default=1, help="Step size for learning rate decay")
    parser.add_argument("--lr_decay_factor", type=float, default=0.5, help="Factor by which learning rate is decayed")
    parser.add_argument("--cuda_ids", type=int, nargs='+', default=[0, 1], help="List of CUDA device IDs to use")
    return parser.parse_args()


def prepare_data(batch_size):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_data = datasets.MNIST(
        root='./mnist',
        train=True,  # True selects the training split, False the test split
        transform=transform,
        # download=True  # set True to download on the first run, then switch back to False
    )
    train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
    test_data = datasets.MNIST(
        root='./mnist',
        train=False,  # True selects the training split, False the test split
        transform=transform,
    )
    test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=batch_size, shuffle=True)
    return train_loader, test_loader


def train(model, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if (batch_idx + 1) % 30 == 0:
            print('Train: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up the batch loss
            pred = output.max(1, keepdim=True)[1]  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def train_mnist_classification():
    args = arg_parser()
    print(args)

    EPOCHS = args.epochs
    BATCH_SIZE = args.batch_size
    LR = args.lr
    LR_DECAY_STEP_NUM = args.lr_decay_step_num
    LR_DECAY_FACTOR = args.lr_decay_factor
    CUDA_DEVICE_IDS = args.cuda_ids
    DEVICE = torch.device(f"cuda:{CUDA_DEVICE_IDS[0]}")

    train_loader, test_loader = prepare_data(BATCH_SIZE)

    model = ConvNet().to(DEVICE)
    model = nn.DataParallel(model, device_ids=CUDA_DEVICE_IDS)
    optimizer = optim.Adam(model.parameters(), lr=LR)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=LR_DECAY_STEP_NUM, gamma=LR_DECAY_FACTOR)

    start_time = time()  # Record the start time
    for epoch in range(EPOCHS):
        epoch_start_time = time()  # Record the start time of the current epoch
        print(f'Epoch {epoch}/{EPOCHS}')
        print(f'Learning Rate: {scheduler.get_last_lr()[0]}')
        train(model, DEVICE, train_loader, optimizer)
        test(model, DEVICE, test_loader)
        scheduler.step()
        epoch_end_time = time()  # Record the end time of the current epoch
        print(f"Time for epoch {epoch}: {epoch_end_time - epoch_start_time:.2f} seconds")
    end_time = time()  # Record the end time
    print(f"Total training time: {end_time - start_time:.2f} seconds")


if __name__ == "__main__":
    train_mnist_classification()
```