卷积神经网络(CNN)入门实践及Sequential 容器封装
学习链接:https://www.bilibili.com/video/BV1hE411t7RN?t=1.4&p=22
推荐网站:CS231n 用于计算机视觉的深度学习
一、CNN 核心层的作用与原理
在搭建模型前,先明确 CNN 中各核心层的功能:
层类型 | 作用 | 关键参数示例 |
---|---|---|
卷积层(Conv2d) | 提取图像局部特征(如边缘、纹理),通过卷积核实现特征映射 | Conv2d(3, 32, 5, padding=2) (输入通道 3,输出通道 32,核大小 5×5,填充 2) |
池化层(MaxPool2d) | 下采样压缩特征图,减少计算量,同时保留关键特征 | MaxPool2d(2) (2×2 窗口做最大池化) |
展平层(Flatten) | 将二维特征图转换为一维向量,为全连接层做准备 | —— |
全连接层(Linear) | 对提取的特征做非线性变换,最终实现分类或回归任务 | Linear(1024, 64) (输入维度 1024,输出维度 64) |
二、两种 CNN 模型构建方式对比
我们可以用 “分步骤定义层”和“Sequential 容器封装” 两种方式构建完全等价的 CNN 模型。
方式 1:分步骤定义每一层
这种方式更直观,适合初学者理解每一层的执行顺序:
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linearclass Prayer(nn.Module):def __init__(self):super(Tudui, self).__init__()# 定义各层self.conv1 = Conv2d(3, 32, 5, padding=2) # 第一层卷积self.maxpool1 = MaxPool2d(2) # 第一层池化self.conv2 = Conv2d(32, 32, 5, padding=2) # 第二层卷积self.maxpool2 = MaxPool2d(2) # 第二层池化self.conv3 = Conv2d(32, 64, 5, padding=2) # 第三层卷积self.maxpool3 = MaxPool2d(2) # 第三层池化self.flatten = Flatten() # 展平层self.linear1 = Linear(1024, 64) # 第一层全连接self.linear2 = Linear(64, 10) # 第二层全连接(10类分类)def forward(self, x):# 按顺序执行各层x = self.conv1(x)x = self.maxpool1(x)x = self.conv2(x)x = self.maxpool2(x)x = self.conv3(x)x = self.maxpool3(x)x = self.flatten(x)x = self.linear1(x)x = self.linear2(x)return x# 测试模型
prayer = Prayer()
print(prayer)
input = torch.ones((64, 3, 32, 32)) # 模拟输入:批量64、3通道、32×32图像
output = prayer(input)
print(output.shape) # 输出应为torch.Size([64, 10])
方式 2:Sequential 容器封装(更简洁)
当层的执行顺序很明确时,用Sequential
把层 “打包”,代码更简洁:
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequentialclass Prayer(nn.Module):def __init__(self):super(Prayer, self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return x# 测试模型
# prayer = Prayer()
# print(tudui)
# input = torch.ones((64, 3, 32, 32)) # 模拟输入:批量64、3通道、32×32图像
# output = prayer(input)
# print(output.shape) # 输出应为torch.Size([64, 10])# 测试模型
prayer = Prayer()
print(prayer)
input = torch.ones((64, 3, 32, 32))
output = prayer(input)
print(output.shape) # 同样输出torch.Size([64, 10])
三、用 TensorBoard 可视化模型结构
from torch.utils.tensorboard import SummaryWriter# 初始化SummaryWriter,指定日志保存路径
writer = SummaryWriter("../logs_seq")
# 传入模型和测试输入,生成计算图
writer.add_graph(prayer, input)
writer.close()
运行代码后,在终端执行命令:
tensorboard --logdir=../logs_seq
然后打开浏览器访问http://localhost:6006
,就能看到模型的计算图了,每一层的连接关系一目了然~
四、模型应用场景
这个 CNN 模型的结构非常经典,适合作为图像分类任务的 “baseline(基准)”,比如:
- 对 CIFAR-10 数据集(10 类彩色小图像)做分类;
- 自定义小型图像数据集的分类任务;
- 作为更复杂 CNN 模型的 “基石”,在此基础上添加残差连接、注意力机制等模块。