当前位置: 首页 > news >正文

PyTorch中卷积层torch.nn.Conv2d

在 PyTorch 中,卷积层主要由 torch.nn.Conv1dtorch.nn.Conv2d 和 torch.nn.Conv3d 实现,分别对应一维、二维和三维卷积操作。以下是详细说明:

1. 二维卷积 (Conv2d) - 最常用

import torch.nn as nn

# 基本参数
conv = nn.Conv2d(
    in_channels=3,    # 输入通道数 (如RGB图像为3)
    out_channels=16,  # 输出通道数/卷积核数量
    kernel_size=3,    # 卷积核大小 (可以是int或tuple, 如(3,3))
    stride=1,         # 步长 (默认1)
    padding=1,        # 边界填充 (默认0)
    dilation=1,       # 空洞卷积参数 (默认1)
    groups=1,         # 分组卷积参数 (默认1)
    bias=True         # 是否使用偏置 (默认True)
)
计算输出尺寸:

比如:高度

 

2. 使用示例 

import torch

# 输入张量 (batch_size=4, 通道=3, 高=32, 宽=32)
x = torch.randn(4, 3, 32, 32)

# 卷积层
conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
output = conv(x)
print(output.shape)  # torch.Size([4, 16, 32, 32])

3. 特殊卷积类型

(1) 空洞卷积 (Dilated Convolution)
nn.Conv2d(3, 16, kernel_size=3, dilation=2)  # 扩大感受野
(2) 分组卷积 (Grouped Convolution)
nn.Conv2d(16, 32, kernel_size=3, groups=4)  # 将输入/输出通道分为4组
(3) 深度可分离卷积 (Depthwise Separable)
# 等价于 groups=in_channels
depthwise = nn.Conv2d(16, 16, kernel_size=3, groups=16)
pointwise = nn.Conv2d(16, 32, kernel_size=1)  # 1x1卷积

4. 一维和三维卷积

Conv1d (时序数据/文本)
conv1d = nn.Conv1d(in_channels=256, out_channels=100, kernel_size=3)
# 输入形状: (batch, channels, sequence_length)
Conv3d (视频/体积数据)
conv3d = nn.Conv3d(1, 32, kernel_size=(3,3,3))
# 输入形状: (batch, channels, depth, height, width)

5. 转置卷积 (反卷积)

nn.ConvTranspose2d(16, 8, kernel_size=2, stride=2)  # 常用于上采样

6. 初始化权重

# 常用初始化方法
nn.init.kaiming_normal_(conv.weight, mode='fan_out')
nn.init.constant_(conv.bias, 0.1)

7. 可视化卷积核

import matplotlib.pyplot as plt

weights = conv.weight.detach().cpu()
plt.figure(figsize=(10,5))
for i in range(16):
    plt.subplot(4,4,i+1)
    plt.imshow(weights[i,0], cmap='gray')
plt.show()

8. 总结:

  1. 卷积核参数共享,大大减少参数量

  2. padding='same' 可保持输入输出尺寸相同 (PyTorch 1.9+)

  3. 通常配合 BatchNorm 和 ReLU 使用

  4. 使用 print(conv) 可查看层结构

实际应用中,卷积层常与池化层交替使用构建CNN架构,如:

self.conv_block = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding=1),
    nn.BatchNorm2d(32),
    nn.ReLU(),
    nn.MaxPool2d(2)
)

http://www.dtcms.com/a/108218.html

相关文章:

  • Android 切换prefer APN后建立PDN的日志分析
  • ubuntu改用户权限
  • AI调研 | Omnisql模型家族调研与实测
  • ‌Windows 与 Linux网络命令速查表,含常用场景及参数说明
  • 使用高德api实现天气查询
  • 多电机显示并排序
  • WHAT - 如何理解中间件
  • WPF学习路线
  • 关于Gstreamer+MPP硬件加速推流问题:视频输入video0被占用
  • MYSQL实现获取某个经纬度区域内的数据
  • Cesium系列:从入门到实践,打造属于你的3D地球应用
  • 为 Jenkins Agent 添加污点(Taint)容忍度(Toleration)
  • Dubbo分布式框架学习(1)
  • vue省市区懒加载,用el-cascader 新增和回显
  • 多模态大模型笔记
  • Compressed串行端口终端应用程序(MAC 、WIN、LINUX)打包下载
  • 高级java每日一道面试题-2025年3月19日-Web篇-防止表单重复提交的方法有哪些?
  • MySQL联合查询
  • vector的学习使用(1)
  • Cjson的创建和解析
  • 【Python】KNN:K-NearestNeighbor 学习指南
  • Vue3+Cesium+vite 入门- 项目搭建
  • HAL库 通过USB Boot进行APP程序升级
  • window11 通过cmd命令行安装 oh my zsh 的教程
  • VMware上的windows虚拟机安装使用Docker方法
  • MySQL篇(二): 核心知识深度聚簇解析:索引、非聚簇索引、回表查询、覆盖索引、超大分页处理、索引创建原则与索引失效场景
  • TDengine 权限管理与安全配置实战(二)
  • Redhat8.10 离线安装Snipe-IT v8.0.4 版本
  • 计算机网络中科大 - 第1章 结构化笔记(详细解析)
  • PostgreSQL pg_repack 重新组织表并释放表空间