
PyTorch Multi-GPU Training in Practice: From a From-Scratch Implementation to ResNet-18

This article walks through multi-GPU training in PyTorch, covering both a manual from-scratch implementation and a concise implementation based on ResNet-18. The code is complete and can be run directly.


1. Environment Setup and Imports

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
from torchvision import models

2. Distributing Parameters to Multiple GPUs

Clone the model parameters to a given device and enable gradient tracking:

def get_params(params, device):
    new_params = [p.clone().to(device) for p in params]
    for p in new_params:
        p.requires_grad = True
    return new_params
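
As a quick check (a minimal sketch; toy_params and the single-device setup are illustrative, and d2l.try_gpu(0) falls back to the CPU if no GPU is present), the cloned tensors should end up on the target device with gradients enabled:

# Hypothetical sanity check: clone two toy tensors to the first GPU
toy_params = [torch.randn(6, 1, 5, 5), torch.zeros(6)]
gpu0_params = get_params(toy_params, d2l.try_gpu(0))
print(gpu0_params[1].device, gpu0_params[1].requires_grad)  # e.g. "cuda:0 True"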

3. Gradient Synchronization (AllReduce)

Sum the gradients from all GPUs and broadcast the result back to each of them:

def allreduce(data):
    # Accumulate the tensors from every GPU onto the first GPU
    for i in range(1, len(data)):
        data[0][:] += data[i].to(data[0].device)
    # Broadcast the result back to every GPU (in place, so that .grad tensors are actually updated)
    for i in range(1, len(data)):
        data[i][:] = data[0].to(data[i].device)
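
A small sanity check (a sketch assuming two visible GPUs; on a single-GPU or CPU-only machine d2l.try_gpu falls back and the example loses its point):

data = [torch.ones((1, 2), device=d2l.try_gpu(i)) * (i + 1) for i in range(2)]
print('before allreduce:', data[0], data[1])
allreduce(data)
print('after allreduce: ', data[0], data[1])  # both copies should now hold [[3., 3.]]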

4. Splitting the Data

Distribute a minibatch evenly across the GPUs:

def split_batch(x, y, devices):
    assert x.shape[0] == y.shape[0]  # features and labels must contain the same number of samples
    return (nn.parallel.scatter(x, devices),
            nn.parallel.scatter(y, devices))
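
For example (a sketch that assumes two GPUs), a batch of 256 Fashion-MNIST-sized images is split into two shards of 128:

x = torch.randn(256, 1, 28, 28)
y = torch.randint(0, 10, (256,))
x_shards, y_shards = split_batch(x, y, [d2l.try_gpu(0), d2l.try_gpu(1)])
print([shard.shape for shard in x_shards])  # [torch.Size([128, 1, 28, 28]), torch.Size([128, 1, 28, 28])]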

5. Training a Single Minibatch

The core logic of multi-GPU training:

loss = nn.CrossEntropyLoss(reduction='none')

def train_batch(x, y, device_params, devices, lr):
    x_shards, y_shards = split_batch(x, y, devices)  # split the data across GPUs

    # Compute the loss on each GPU (lenet is the functional model defined in Section 6)
    ls = [loss(lenet(x_shard, params), y_shard).sum()
          for x_shard, y_shard, params in zip(x_shards, y_shards, device_params)]

    # Backpropagate independently on each GPU
    for l in ls:
        l.backward()

    # Synchronize the gradients across GPUs
    with torch.no_grad():
        for i in range(len(device_params[0])):
            allreduce([params[i].grad for params in device_params])

    # Update the parameters on every GPU (the synchronized gradients are identical),
    # scaling by the size of the full minibatch
    for params in device_params:
        d2l.sgd(params, lr, x.shape[0])

6. The Full Training Loop

Because the from-scratch trainer passes the parameters around explicitly, the example network (a small LeNet-style CNN for Fashion-MNIST) is defined as a plain function that takes its parameters as an argument:

# Initialize the model parameters of the example network
scale = 0.01
W1 = torch.randn(size=(6, 1, 5, 5)) * scale
b1 = torch.zeros(6)
W2 = torch.randn(size=(16, 6, 5, 5)) * scale
b2 = torch.zeros(16)
W3 = torch.randn(size=(16 * 4 * 4, 120)) * scale
b3 = torch.zeros(120)
W4 = torch.randn(size=(120, 84)) * scale
b4 = torch.zeros(84)
W5 = torch.randn(size=(84, 10)) * scale
b5 = torch.zeros(10)
params = [W1, b1, W2, b2, W3, b3, W4, b4, W5, b5]

def lenet(X, params):
    h = F.max_pool2d(F.relu(F.conv2d(X, params[0], params[1])), kernel_size=2, stride=2)
    h = F.max_pool2d(F.relu(F.conv2d(h, params[2], params[3])), kernel_size=2, stride=2)
    h = h.reshape(h.shape[0], -1)            # flatten to (batch, 16 * 4 * 4)
    h = F.relu(h @ params[4] + params[5])
    h = F.relu(h @ params[6] + params[7])
    return h @ params[8] + params[9]

def train(num_gpus, batch_size, lr):
    train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
    devices = [d2l.try_gpu(i) for i in range(num_gpus)]

    # Copy the model parameters to every GPU
    device_params = [get_params(params, d) for d in devices]

    # Training loop
    for epoch in range(10):
        for X, y in train_iter:
            train_batch(X, y, device_params, devices, lr)
            torch.cuda.synchronize()  # wait for all GPUs to finish the step
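
If you also want to track accuracy during training, one option (a sketch relying on d2l.evaluate_accuracy_gpu; evaluating the first GPU's parameter copy is enough because all copies stay synchronized) is to add the following at the end of each epoch inside train():

        # after the inner minibatch loop, still inside `for epoch in range(10):`
        test_acc = d2l.evaluate_accuracy_gpu(
            lambda x: lenet(x, device_params[0]), test_iter, devices[0])
        print(f'epoch {epoch + 1}, test acc {test_acc:.3f}')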

7. Concise Implementation: A Modified ResNet-18

def resnet18(num_classes, in_channels=1):
    def resnet_block(in_channels, out_channels, num_residuals, first_block=False):
        blk = []
        for i in range(num_residuals):
            if i == 0 and not first_block:
                blk.append(d2l.Residual(in_channels, out_channels,
                                        use_1x1conv=True, strides=2))
            else:
                blk.append(d2l.Residual(out_channels, out_channels))
        return nn.Sequential(*blk)
    
    # Full network: stem, four residual stages, then the classification head
    net = nn.Sequential(
        nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3),
        nn.BatchNorm2d(64), nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
    
    net.add_module("resnet_block1", resnet_block(64, 64, 2, first_block=True))
    net.add_module("resnet_block2", resnet_block(64, 128, 2))
    net.add_module("resnet_block3", resnet_block(128, 256, 2))
    net.add_module("resnet_block4", resnet_block(256, 512, 2))
    
    net.add_module("global_avg_pool", nn.AdaptiveAvgPool2d((1,1)))
    net.add_module("flatten", nn.Flatten())
    net.add_module("fc", nn.Linear(512, num_classes))
    
    return net
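
A quick shape check on a single device (a sketch using a random batch of 28x28 grayscale images, i.e. the Fashion-MNIST format; check_net is just a throwaway name):

check_net = resnet18(10)
X = torch.rand(2, 1, 28, 28)
print(check_net(X).shape)  # expected: torch.Size([2, 10])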

# Wrap the model with DataParallel (the model must live on the first device in device_ids)
net = nn.DataParallel(resnet18(10).cuda(), device_ids=[0, 1])
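
With the DataParallel wrapper in place, a minimal concise training loop might look like the following (a sketch; the train_concise helper is illustrative and assumes two GPUs plus the same Fashion-MNIST loader used above):

def train_concise(net, num_epochs=10, batch_size=256, lr=0.1):
    train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
    device = torch.device('cuda:0')  # DataParallel gathers outputs on the first device
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    net.train()
    for epoch in range(num_epochs):
        for X, y in train_iter:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            l = criterion(net(X), y)  # DataParallel splits X across the GPUs automatically
            l.backward()              # gradients are reduced back onto cuda:0
            optimizer.step()

# e.g. train_concise(net)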

8. Running the Examples

if __name__ == "__main__":
    # From-scratch implementation
    train(num_gpus=2, batch_size=256, lr=0.1)

    # Concise implementation
    model = resnet18(10).cuda()
    model = nn.DataParallel(model, device_ids=[0, 1])

Key Takeaways

  1. Data parallelism: the data and the model parameters are distributed across GPUs; each GPU computes gradients independently and the gradients are then synchronized

  2. Gradient synchronization: an AllReduce operation keeps the parameters consistent across all GPUs

  3. Device management: nn.parallel.scatter splits the minibatch across devices automatically

  4. Concise implementation: prefer nn.DataParallel or, better, nn.parallel.DistributedDataParallel (see the sketch below)
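
For completeness, here is what a minimal single-machine DistributedDataParallel setup might look like (a sketch only; ddp_main is an illustrative name, and the script is assumed to be launched with torchrun --nproc_per_node=2 script.py):

import os
import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

def ddp_main():
    dist.init_process_group(backend="nccl")       # torchrun provides RANK/WORLD_SIZE/MASTER_ADDR
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = resnet18(10).cuda(local_rank)         # the resnet18 defined in Section 7
    model = DDP(model, device_ids=[local_rank])   # gradients are all-reduced automatically

    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    criterion = nn.CrossEntropyLoss()
    train_iter, _ = d2l.load_data_fashion_mnist(256)  # in practice, add a DistributedSampler

    for X, y in train_iter:
        X, y = X.cuda(local_rank), y.cuda(local_rank)
        optimizer.zero_grad()
        criterion(model(X), y).backward()
        optimizer.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    ddp_main()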

The complete code has been verified to run in a multi-GPU environment; PyTorch 1.8+ is recommended. If you run into any problems, feel free to discuss them in the comments!


I hope this article helps you get up to speed with multi-GPU training in PyTorch!
