当前位置: 首页 > news >正文

PyTorch实现多输入输出通道的卷积操作

本文通过代码示例详细讲解如何在PyTorch中实现多输入通道和多输出通道的卷积运算,并对比传统卷积与1x1卷积的实现差异。


1. 多输入通道互相关运算

当输入包含多个通道时,卷积核需要对每个通道分别进行互相关运算,最后将结果相加。以下是实现代码:

import torch
from d2l import torch as d2l

def corr2d_multi_in(X, K):
    return sum(d2l.corr2d(x, k) for x, k in zip(X, K))

验证输出
输入一个2通道的3x3张量和一个2通道的2x2卷积核,输出结果为2x2张量:

X = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],
                 [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]])
K = torch.tensor([[[0.0, 1.0], [2.0, 3.0]], [[1.0, 2.0], [3.0, 4.0]]])

print(corr2d_multi_in(X, K))

输出结果:

tensor([[ 56.,  72.],
        [104., 120.]])

2. 多输出通道互相关运算

通过堆叠多个卷积核,可以实现多输出通道。以下代码展示了如何生成3个输出通道:

def corr2d_multi_in_out(X, K):
    return torch.stack([corr2d_multi_in(X, k) for k in K], 0)

K = torch.stack((K, K+1, K+2), 0)  # 堆叠3个卷积核
print("卷积核形状:", K.shape)

输出结果:

卷积核形状: torch.Size([3, 2, 2, 2])

运行多通道卷积:

print(corr2d_multi_in_out(X, K))

输出结果:

tensor([[[ 56.,  72.],
         [104., 120.]],

        [[ 76., 100.],
         [148., 172.]],

        [[ 96., 128.],
         [192., 224.]]])

3. 1x1卷积的优化实现

1x1卷积可通过矩阵乘法高效实现,尤其适用于通道维度调整。以下是对比传统卷积与1x1卷积的代码:

def corr2d_multi_in_out_1x1(X, K):
    c_i, h, w = X.shape
    c_o = K.shape[0]
    X = X.reshape((c_i, h * w))       # 展平空间维度
    K = K.reshape((c_o, c_i))        # 展平卷积核
    Y = torch.matmul(K, X)           # 矩阵乘法
    return Y.reshape((c_o, h, w))    # 恢复形状

# 生成随机输入和卷积核
X = torch.normal(0, 1, (3, 3, 3))    # 3通道3x3输入
K = torch.normal(0, 1, (2, 3, 1, 1)) # 2输出通道的1x1卷积核

# 验证两种方法结果一致
Y1 = corr2d_multi_in_out_1x1(X, K)
Y2 = corr2d_multi_in_out(X, K)
assert float(torch.abs(Y1 - Y2).sum()) < 1e-6  # 误差极小

总结

  • 多输入通道:对每个通道独立进行卷积后求和。

  • 多输出通道:通过堆叠多个卷积核实现不同输出。

  • 1x1卷积:本质是通道间的线性组合,可通过矩阵乘法高效实现。

通过上述代码示例,读者可以深入理解多通道卷积的实现原理,并掌握优化技巧。


注意:运行代码需安装PyTorch和d2l库。完整代码请参考文中示例。

相关文章:

  • 非 root 用户运行 Docker 容器和同步主机和容器权限
  • vue入门:插槽
  • AI 重构 Java 遗留系统:从静态方法到 Spring Bean 注入的自动化升级
  • ocr python库
  • 《深度剖析分布式软总线:软时钟与时间同步机制探秘》
  • git清理已经删除的远程分支
  • 大模型在儿童急性淋巴细胞白血病(ALL)-初治患者诊疗中应用的研究报告
  • git commit时自动生成Change-ID
  • XTuner学习
  • WHAT - Typescript 定义元素类型
  • 大数据(7.2)Kafka万亿级数据洪流下的架构优化实战:从参数调优到集群治理
  • 数据结构与算法之ACM Fellow-算法3.4 散列表
  • Unity 设置弹窗Tips位置
  • LLaMA-Factory从安装到微调全流程
  • Linux上搭建NFS共享存储
  • SpringBoot项目集成Seata 2.0.0
  • Kubernetes核心架构:从组件协同到工作原理
  • LED恒流驱动驱动电路原理图 LM3406HV-Q1
  • SpringBoot 为何启动慢
  • 第1课:MCP服务协议核心架构解析
  • 开发软件开发/刷关键词排名seo软件
  • 旅游网站设计与建设论文/软文大全
  • 仿牌外贸网站制作/国家卫生健康委
  • 除了昵图网还有什么做图网站/环球网疫情最新消息
  • wordpress 报表图形/网站是怎么优化的
  • 互联网网站制作公司/360搜索引擎优化