当前位置: 首页 > news >正文

PyTorch 构建神经网络

组件作用
层(Layer)网络基本单元,如卷积层(Conv2d)、线性层(Linear),负责张量数据变换
模型(Model)由多层按逻辑组合而成的整体,实现从输入到输出的映射
损失函数衡量预测值与真实值的差距,如交叉熵损失(CrossEntropyLoss),是参数优化的目标
优化器通过反向传播更新模型参数以最小化损失,如 Adam、SGD

PyTorch 模型构建

1. nn.Module:可训练参数的 “管理者”

特点:所有带可学习参数的层(如 Conv2d、Linear)均继承自 nn.Module,能自动追踪参数,支持与模型容器结合使用。

用法:自定义模型需继承 nn.Module,在__init__中定义层,在forward中实现前向传播逻辑。

示例:定义一个简单线性层模块

python运行

import torch.nn as nn
class SimpleLinear(nn.Module):def __init__(self, in_dim, out_dim):super().__init__()self.linear = nn.Linear(in_dim, out_dim)  # 可学习参数由nn.Module管理def forward(self, x):return self.linear(x)

 nn.functional:纯函数式工具

特点:无参数的 “纯函数” 集合,如激活函数(ReLU)、池化(max_pool2d),需手动传入参数(若有),无法与模型容器直接结合。

注意:dropout 操作若用 nn.functional 实现,需手动区分训练 / 测试模式;而 nn.Dropout(继承自 nn.Module)可通过model.eval()自动切换状态。

三种模型构建方法

1. 直接继承 nn.Module:最灵活

适用于复杂网络结构,需手动定义每一层的连接逻辑。例如构建含批归一化的全连接网络:

python运行

import torch.nn.functional as F
class FCModel(nn.Module):def __init__(self, in_dim=784, n_hidden=300, out_dim=10):super().__init__()self.flatten = nn.Flatten()  # 展平28*28图像self.linear1 = nn.Linear(in_dim, n_hidden)self.bn1 = nn.BatchNorm1d(n_hidden)  # 批归一化def forward(self, x):x = self.flatten(x)x = F.relu(self.bn1(self.linear1(x)))  # 前向传播逻辑return x

2. nn.Sequential:按序堆叠,快速高效

适合层与层按顺序连接的简单网络,支持三种定义方式:

可变参数:直接传入层实例,无需命名

python运行

seq = nn.Sequential(nn.Flatten(), nn.Linear(784, 300), nn.ReLU())
  • add_module:为每层指定名称,便于后续查看

    python运行

    seq = nn.Sequential()
    seq.add_module("flatten", nn.Flatten())
    seq.add_module("linear1", nn.Linear(784, 300))
    
  • OrderedDict:用有序字典定义,兼顾顺序与命名

    python运行

    from collections import OrderedDict
    seq = nn.Sequential(OrderedDict([("flatten", nn.Flatten()),("linear1", nn.Linear(784, 300))
    ]))
    

 nn.Sequential 封装残差块:

python运行

class ResBlockWrapper(nn.Module):def __init__(self):super().__init__()# 用nn.Sequential封装残差块内的层self.res_block = nn.Sequential(nn.Conv2d(64, 64, 3, padding=1),nn.BatchNorm2d(64))def forward(self, x):return F.relu(x + self.res_block(x))

从自定义模块到 ResNet18

1. 定义两种残差块

ResNet18 包含两种残差块,分别处理 “维度不变” 和 “维度下采样” 场景:

python运行

class RestNetBasicBlock(nn.Module):# 基础残差块:输入输出维度一致,无需额外调整def __init__(self, in_channels, out_channels, stride):super().__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, padding=1)self.bn1 = nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride, padding=1)self.bn2 = nn.BatchNorm2d(out_channels)def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))return F.relu(x + out)  # 残差连接class RestNetDownBlock(nn.Module):# 下采样残差块:用1×1卷积调整输入维度,适配残差连接def __init__(self, in_channels, out_channels, stride):super().__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride[0], padding=1)self.bn1 = nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride[1], padding=1)self.bn2 = nn.BatchNorm2d(out_channels)# 1×1卷积调整输入通道和分辨率self.extra = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, stride[0]),nn.BatchNorm2d(out_channels))def forward(self, x):extra_x = self.extra(x)  # 维度调整out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))return F.relu(extra_x + out)

2. 组合成 ResNet18 架构

基于两种残差块,按 “初始卷积→残差层→全局池化→全连接” 的顺序构建 ResNet18,适配 3 通道的人脸图像:

python运行

class RestNet18(nn.Module):def __init__(self, num_classes):  # num_classes:人脸类别数super().__init__()# 初始层:降维+下采样self.conv1 = nn.Conv2d(3, 64, 7, stride=2, padding=3)self.bn1 = nn.BatchNorm2d(64)self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)# 4个残差层:2个基础块+2个下采样块组合self.layer1 = nn.Sequential(RestNetBasicBlock(64, 64, 1), RestNetBasicBlock(64, 64, 1))self.layer2 = nn.Sequential(RestNetDownBlock(64, 128, [2,1]), RestNetBasicBlock(128, 128, 1))self.layer3 = nn.Sequential(RestNetDownBlock(128, 256, [2,1]), RestNetBasicBlock(256, 256, 1))self.layer4 = nn.Sequential(RestNetDownBlock(256, 512, [2,1]), RestNetBasicBlock(512, 512, 1))# 分类头:全局平均池化+全连接self.avgpool = nn.AdaptiveAvgPool2d((1,1))self.fc = nn.Linear(512, num_classes)def forward(self, x):# 前向传播:按层顺序执行x = self.bn1(self.conv1(x))x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = x.reshape(x.shape[0], -1)  # 展平为一维向量return self.fc(x)

总结

PyTorch 模型构建的核心在于 “灵活组合”:通过 nn.Module 管理可训练参数,用 nn.Sequential 等容器简化层连接,结合自定义模块(如残差块)可实现复杂架构。从基础全连接网络到 ResNet18

http://www.dtcms.com/a/398697.html

相关文章:

  • 人工智能医疗系统灰度上线与评估:技术框架实践分析python版(下)
  • 网站推广费用一般多少钱设计工作室logo
  • Eclipse配置tomcat+创建javaweb项目
  • 做国际网站找阿里西安市今天发生的重大新闻
  • 深圳工程建设交易服务中心网站郑州做网站zzmshl
  • Flink-SQL通过过滤-解析-去重-聚合计算写入到MySQL表
  • 公司网站建设记哪个科目网站建设对企业的要求
  • 汕头网页设计制作金华seo扣费
  • Vue电商数据分析大屏开发
  • 【开题答辩全过程】以 bilibili排行榜的数据分析与可视化为例,包含答辩的问题和答案
  • AI性能对决!蓝耘MaaS平台在2025大模型测评中如何脱颖而出
  • 新能源知识库(109)什么是频率死区?
  • Linux开发——开发板介绍及裸机程序设计
  • 百度网站推广关键词怎么查凡科微信小程序怎么样
  • 定制网站开发接活wordpress固定链接设置技巧
  • HTTP代理HTTP(S)、SOCKS5有哪些作用?
  • vue3+TS 前端调用海康摄像头视频流,后端用 Node.js 做 RTSP 转 WebSocket-FLV 转发,并且前后端优化延迟方案
  • 计算机视觉(opencv)练习——抠图(图像裁剪与轮廓提取)
  • 网站建设知识点的总结怎么做网站一个平台
  • 西安做网站的在网站后台设置wap模板目录
  • 软件行业|Parasoft与IAR的嵌入式DevOps测试集成
  • 设计模式-状态模式详解
  • 微信小程序通用弹窗组件封装与动画实现
  • 「日拱一码」099 数据处理——降维
  • 速通ACM省铜第十三天 赋源码(Watermelon)
  • 【C++进阶系列】:位图和布隆过滤器(附模拟实现的源码)
  • 洛阳网站建设建站系统怎么删除网站的死链
  • 山东省城乡建设厅网站wordpress academia
  • 广州番禺服装网站建设济南网站优化
  • 下载huggingface中数据集/模型