当前位置: 首页 > news >正文

卷积层里的多输入多输出通道

多输入多输出通道

  • 彩色图片可能有 RGB 三个通道
  • 转换为灰度会丢失信息
  • 每个通道都有一个卷积核,结果是所有通道卷积结果的和在这里插入图片描述

多个输入通道

  • 输入 X \mathbf{X} X : c i × n h × n w c_i \times n_h \times n_w ci×nh×nw
  • W \mathbf{W} W : c i × k h × k w c_i \times k_h \times k_w ci×kh×kw
  • 输出 Y \mathbf{Y} Y : m h × m w m_h \times m_w mh×mw

Y = ∑ i = 0 c i X i , : , : ⋆   m a t h b f W i , : , : \mathbf{Y} = \sum_{i=0}^{c_i} \mathbf{X}_{i,:,:} \star \ mathbf{W}_{i,:,:} Y=i=0ciXi,:,: mathbfWi,:,:

多个输出通道

  • 无论有多少输入通道,到目前为止我们只用到单输出通道

  • 我们可以有多个三维卷积核,每个核生成一个输出通道

  • 输入 X \mathbf{X} X: c i × n h × n w c_i \times n_h \times n_w ci×nh×nw

  • W \mathbf{W} W: c o × c i × k h × k w c_o \times c_i \times k_h \times k_w co×ci×kh×kw

  • 输出 Y \mathbf{Y} Y: c o × m h × m w c_o \times m_h \times m_w co×mh×mw

Y i , : , : = X ⋆ W i , : , : , : for  i = 1 , … , c o \mathbf{Y}_{i,:,:} = \mathbf{X} \star \mathbf{W}_{i,:,:,:} \quad \text{for } i = 1, \ldots, c_o Yi,:,:=XWi,:,:,:for i=1,,co

多个输入通道和输出通道

  • 每个输出通道可以认为是在识别特定模式
    在这里插入图片描述

  • 输入通道核识别并组合输入中的模式

1×1 卷积层

k h = k w = 1 k_h=k_w=1 kh=kw=1 是一个受欢迎的选择。不识别空间模式,只是相当于输入形状为 n h n w × c i n_hn_w×c_i nhnw×ci,权重为 c o × c i c_o×c_i co×ci 的全连接层。

卷积层的计算

  • 输入 X \mathbf{X} X: c i × n h × n w c_i \times n_h \times n_w ci×nh×nw
  • W \mathbf{W} W: c o × c i × k h × k w c_o \times c_i \times k_h \times k_w co×ci×kh×kw
  • 偏差 B \mathbf{B} B: c o × 1 c_o \times 1 co×1
  • 输出 Y \mathbf{Y} Y: c o × m h × m w c_o \times m_h \times m_w co×mh×mw

Y = X ⋆ W + B \mathbf{Y} = \mathbf{X} \star \mathbf{W} + \mathbf{B} Y=XW+B

  • 计算复杂度(浮点运算数 FLOP) O ( c i c o k h k w m h m w ) O(c_i c_o k_h k_w m_h m_w) O(cicokhkwmhmw)

    • c i = c o = 100 c_i = c_o = 100 ci=co=100
    • k h = k w = 5 ⇒ 1 GFLOP k_h = k_w = 5 \Rightarrow 1 \text{GFLOP} kh=kw=51GFLOP
    • m h = m w = 64 m_h = m_w = 64 mh=mw=64
  • 10 层,1M 样本,10 PFlops

    • CPU: 0.15 TF = 18h
    • GPU: 12 TF = 14min

总结

  • 输出通道数是卷积层的超参数
  • 每个输入通道有独立的二维卷积核,所有通道结果相加得到一个输出通道结果
  • 每个输出通道有独立的三维卷积核

代码实现

实现一下多输入通道互相关运算并验证互相关运算的输出

import torch
from d2l import torch as d2l
# 这个导入的配置我还是使用了很多的方法才解决的,但是还是有些版本不匹配的情况在这里面但是还是不影响我运行下面的代码

def corr2d_multi_in(X, K):
    # 先遍历“X”和“K”的第0个维度(通道维度),再把它们加在一起
    return sum(d2l.corr2d(x, k) for x, k in zip(X, K))
    # zip将输入张量x与内核张量k的对应通道配对

X = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],
               [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]])
# 表示X的形状为2*3*3
K = torch.tensor([[[0.0, 1.0], [2.0, 3.0]], [[1.0, 2.0], [3.0, 4.0]]])
# 表示内核K的形状为2*2*2

corr2d_multi_in(X, K) # 输出的形状为2*2,上面有图形进行解释

# output : 
tensor([[ 56.,  72.],
        [104., 120.]])

计算多个通道的输出的互相关函数

def corr2d_multi_in_out(X, K):
    # 迭代“K”的第0个维度,每次都对输入“X”执行互相关运算。
    # 最后将所有结果都叠加在一起
    return torch.stack([corr2d_multi_in(X, k) for k in K], 0)

K = torch.stack((K, K + 1, K + 2), 0)
K.shape

corr2d_multi_in_out(X, K)

在这里插入图片描述
下面验证一下 1*1 卷积等价于一个全连接

def corr2d_multi_in_out_1x1(X, K):
    c_i, h, w = X.shape
    c_o = K.shape[0]
    X = X.reshape((c_i, h * w)) # 将X拍成一个二维的矩阵
    K = K.reshape((c_o, c_i))
    # 全连接层中的矩阵乘法
    Y = torch.matmul(K, X)
    return Y.reshape((c_o, h, w)) #重新变回去

X = torch.normal(0, 1, (3, 3, 3))
K = torch.normal(0, 1, (2, 3, 1, 1))

Y1 = corr2d_multi_in_out_1x1(X, K)
Y2 = corr2d_multi_in_out(X, K)
assert float(torch.abs(Y1 - Y2).sum()) < 1e-6

QA 思考

Q1:怎么理解不识别空间模式?
A1:对于 1*1 卷积核来说只看了一个像素点,并没有看这个像素和边上像素的关系是啥。

后记

自己写了一点代码,主要是为了理解上述的操作:

import torch
from torch import nn


def corr2d(X, K):
    """
    手动实现二维互相关操作。
    X: 输入张量,形状为 (H, W)
    K: 卷积核,形状为 (kH, kW)
    返回: 互相关结果,形状为 (out_H, out_W)
    """
    H, W = X.shape
    kH, kW = K.shape
    # 输出的高度和宽度
    out_H = H - kH + 1
    out_W = W - kW + 1
    # 初始化输出张量
    Y = torch.zeros((out_H, out_W))
    # 计算互相关
    for i in range(out_H):
        for j in range(out_W):
            Y[i, j] = (X[i:i + kH, j:j + kW] * K).sum()
    return Y


def corr2d_multi_in(X, K):
    """
    多通道输入的二维互相关操作。
    X: 输入张量,形状为 (C, H, W),其中 C 是通道数
    K: 卷积核,形状为 (C, kH, kW)
    返回: 互相关结果,形状为 (out_H, out_W)
    """
    # 对每个通道分别计算互相关,并将结果相加
    return sum(corr2d(x, k) for x, k in zip(X, K))


# 输入张量 X,形状为 (2, 3, 3)
X = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],
                  [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]])

# 卷积核 K,形状为 (2, 2, 2)
K = torch.tensor([[[0.0, 1.0], [2.0, 3.0]], [[1.0, 2.0], [3.0, 4.0]]])

# 测试多通道互相关
result = corr2d_multi_in(X, K)
print(result)
print("======================================================================")


# 多输出通道的互相关函数
def corr2d_multi_in_out(X, K):
    # 迭代“K”的第0个维度,每次都对输入“X”执行互相关运算。
    # 最后将所有结果都叠加在一起
    return torch.stack([corr2d_multi_in(X, k) for k in K], 0)


# 通过将核张量`K`与`K+1`(`K`中每个元素加 1)和`K+2`连接起来,构造了一个具有 3 个输出通道的卷积核。
K = torch.stack((K, K + 1, K + 2), 0)
print(K.shape)

print(corr2d_multi_in_out(X, K))

print("======================================================================")


# define 1*1 Conv
def corr2d_multi_in_out_1x1(X, K):
    # X:输入张量,形状为 (C_in, H, W)
    # K:卷积核张量,形状为 (C_out, C_in, 1, 1)
    c_i, h, w = X.shape
    c_o = K.shape[0]
    # 输入张量展平
    X = X.reshape((c_i, h * w))  # 将X拍成一个二维的矩阵
    # 卷积核展平
    K = K.reshape((c_o, c_i))
    # 全连接层中的矩阵乘法
    Y = torch.matmul(K, X)
    return Y.reshape((c_o, h, w))  # 重新变回去

# (3, 3, 3),表示 3 个通道,每个通道是一个 3x3 的矩阵
X = torch.normal(0, 1, (3, 3, 3))
#  (2, 3, 1, 1),表示 2 个输出通道,每个通道对应一个 3x1x1 的卷积核
K = torch.normal(0, 1, (2, 3, 1, 1))

Y1 = corr2d_multi_in_out_1x1(X, K)
Y2 = corr2d_multi_in_out(X, K)

# 打印 Y1 和 Y2
print("Y1:")
print(Y1)
print("Y2:")
print(Y2)

# 计算差异
diff = torch.abs(Y1 - Y2).sum()
print("Difference between Y1 and Y2:", diff.item())

# 断言
assert diff < 1e-6

相关文章:

  • 论文笔记:ASTTN模型
  • LINUX 1
  • [Linux实战] Linux设备树原理与应用详解
  • 并发多线程八股
  • ML 聚类算法 dbscan|| OPTICS
  • 使用 glog 库的 CHECK 宏进行条件断言和错误检测
  • K-均值聚类
  • DeepBI如何探索流量种子,快速帮助产品扩展流量
  • 卷积神经网络(CNN)原理与实战:从LeNet到ResNet
  • C 语 言 --- 整 形 提 升
  • 第三章 devextreme-react/scheduler 定制属性学习
  • 第十届MathorCup高校数学建模挑战赛-A题:无车承运人平台线路定价问题
  • Oceanbase企业版安装(非生产环境)
  • MAC使用当前VScode总是报权限不足的错误,简单修改
  • 【Linux内核系列】:文件ELF格式详解
  • TypeScript 中 await 的详解
  • 通用目标检测技术选型分析报告--截止2025年4月
  • 从零构建大语言模型全栈开发指南:第四部分:工程实践与部署-4.2.2多模态数据处理:图像编码与文本对齐(实战代码示例)
  • OpenAI即将开源!DeepSeek“逼宫”下,AI争夺战将走向何方?
  • 人工智能基础知识笔记六:方差分析
  • 上海制造佳品汇大阪站即将启幕,泡泡玛特领潮出海
  • 万科再获深铁集团借款,今年已累计获股东借款近120亿元
  • 菲律宾中期选举初步结果出炉,杜特尔特家族多人赢得地方选举
  • “80后”德州市接待事务中心副主任刘巍“拟进一步使用”
  • 上海北外滩,未来五年将如何“长个子”“壮筋骨”?
  • 反制美国钢铝关税!印度拟对美国部分商品征收关税