当前位置: 首页 > news >正文

PyTorch构建自定义模型

PyTorch 提供了灵活的方式来构建自定义神经网络模型。下面我将详细介绍从基础到高级的自定义模型构建方法,包含实际代码示例和最佳实践。

一、基础模型构建

1. 继承 nn.Module 基类

所有自定义模型都应该继承 torch.nn.Module 类,并实现两个基本方法:

import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()  # 必须调用父类初始化
        # 在这里定义网络层
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 50, 5)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)
    
    def forward(self, x):
        # 定义前向传播逻辑
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

2. 模型使用方式 

model = MyModel()
output = model(input_tensor)  # 自动调用forward方法
loss = criterion(output, target)
loss.backward()

二、中级构建技巧

1. 使用 nn.Sequential

nn.Sequential 是一种用于快速构建顺序神经网络的容器类,适用于模块按线性顺序排列的模型。

class MySequentialModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(128 * 8 * 8, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 10)
        )
    
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

 2. 参数初始化

def initialize_weights(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, 0, 0.01)
        nn.init.constant_(m.bias, 0)

model.apply(initialize_weights)  # 递归应用初始化函数

三、高级构建模式

1. 残差连接 (ResNet风格)

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, 
                              stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                              stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                         stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return F.relu(out)

2. 自定义层 

class MyCustomLayer(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(output_dim, input_dim))
        self.bias = nn.Parameter(torch.randn(output_dim))
    
    def forward(self, x):
        return F.linear(x, self.weight, self.bias)

 

四、模型保存与加载

1. 保存整个模型

torch.save(model, 'model.pth')  # 保存
model = torch.load('model.pth')  # 加载

2. 保存状态字典 (推荐)

torch.save(model.state_dict(), 'model_state.pth')  # 保存
model.load_state_dict(torch.load('model_state.pth'))  # 加载

五、模型部署准备

1. 模型导出为TorchScript

scripted_model = torch.jit.script(model)  # 或 torch.jit.trace
scripted_model.save('model_scripted.pt')

2. ONNX格式导出

dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, "model.onnx", 
                 input_names=["input"], output_names=["output"])

六、完整示例:自定义CNN分类器

import torch
from torch import nn
from torch.nn import functional as F

class CustomCNN(nn.Module):
    """自定义CNN图像分类器
    
    Args:
        num_classes (int): 输出类别数
        dropout_prob (float): dropout概率,默认0.5
    """
    def __init__(self, num_classes=10, dropout_prob=0.5):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout_prob),
            nn.Linear(128 * 6 * 6, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout_prob),
            nn.Linear(512, num_classes)
        )
        
        # 初始化权重
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """前向传播
        
        Args:
            x (torch.Tensor): 输入张量,形状为[B, C, H, W]
            
        Returns:
            torch.Tensor: 输出logits,形状为[B, num_classes]
        """
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

七、注意事项

  1. 输入输出维度匹配

    • 需确保相邻模块的输入/输出维度兼容。例如,卷积层后接全连接层时需通过 Flatten 或自适应池化调整维度‌。
  2. 调试与验证

    • 可通过模拟输入数据验证模型结构,如:
      input = torch.ones(64, 3, 32, 32)  # 模拟 batch_size=64 的输入
      output = model(input)
      print(output.shape)  # 检查输出形状是否符合预期

       

 

相关文章:

  • 从2G到5G:认证体系演进与网元架构变迁深度解析
  • 使用 iPerf 测试内网两台机器之间的传输速度
  • 2025大唐杯仿真4——信令流程
  • 调用阿里云API实现运营商实名认证
  • 现代科幻赛博朋克风品牌海报电子竞技设计无衬线英文字体 Glander – Techno Font
  • 论文导读 | SOSP23 | Gemini:大模型 内存CheckPoint 快速故障恢复
  • 2025年渗透测试面试题总结-某一线实验室实习扩展(题目+回答)
  • [ctfshow web入门] 零基础版题解 目录(持续更新中)
  • 树莓派5中部署 开源 RF-DETR 实时目标检测模型
  • MySQL窗口函数学习
  • [WUSTCTF2020]CV Maker1
  • k8s 自动伸缩的场景与工作原理
  • Docker Desktop - WSL distro terminated abruptly
  • 关于jdk17安装后没有jre目录的解决办法
  • 机器人轨迹跟踪控制——CLF-CBF-QP
  • Redis-基本数据类型
  • 基于VMware的Cent OS Stream 8安装与配置及远程连接软件的介绍
  • 【7】基础入门篇 | YOLOv8 项目【训练】【验证】【推理】最简单教程 | YOLOv8必看 | 最新更新,直接打印 FPS,mAP50,75,95
  • NXP iMX8MP ARM 平台 EMQX 部署测试
  • C++自学笔记---数组和指针的异同点
  • 做虚拟币网站需要什么手续/北京seo关键词排名
  • 北京网站建设费用/百度网站推广申请
  • 衡水哪个公司做网站好/网络舆情报告
  • 天津市南开区网站开发有限公司/生猪价格今日猪价
  • 印刷网站开发的可行性报告/知乎推广优化
  • 网站型跟商城型/搜索到的相关信息