当前位置: 首页 > news >正文

深度学习------专题《图像处理项目》

目录

一、数据加载与可视化:先看清 “要分类的图长啥样”

1. 数据预处理:用transforms给图像 “标准化”

2. 加载数据集:Dataset+DataLoader的 “黄金组合”

3. 可视化验证:确认数据没 “读错”

二、搭建 CNN 网络:卷积、池化、全连接的 “接力赛”

三、训练模型:梯度下降的 “迭代游戏”

1. 损失函数与优化器选择

2. 训练循环:耐心 + 细节

四、测试模型:看看分类准不准

1. 单批数据测试:直观看预测结果

2. 整体准确率统计:量化模型性能

五、总结:从入门到踩坑,我的实战心得


从零搭建 CNN!用 PyTorch 搞定 CIFAR-10 图像分类(附完整流程 + 踩坑记录)

      今天跟着老师完成了 CIFAR-10 图像分类的实战项目,从 “数据加载” 到 “模型训练测试”,终于把卷积神经网络(CNN)的落地流程摸透了!过程中踩了不少小坑,也总结出一些新手易上手的经验,分享给刚入门深度学习的朋友~

一、数据加载与可视化:先看清 “要分类的图长啥样”

      做图像分类的第一步,是把数据集 “喂” 给模型。CIFAR-10 包含 10 类物体(飞机、汽车、鸟、猫等),每类有 6000 张 32×32 的彩色图。PyTorch 的torchvision工具包让数据处理变得简单,但第一次操作还是有不少细节要注意。

1. 数据预处理:用transforms给图像 “标准化”

      图像不能直接丢给模型,得先做预处理。我用transforms.Compose把 “转 Tensor” 和 “标准化” 两个操作串成流水线:

import torchvision.transforms as transformstransform = transforms.Compose([transforms.ToTensor(),  # PIL转Tensor,像素值缩到[0,1]transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 三通道分别标准化
])

踩坑提醒:最开始我漏了Normalize,结果模型训练时损失下降特别慢。后来才知道,标准化能让数据分布更均匀,模型收敛速度会快很多~

2. 加载数据集:Dataset+DataLoader的 “黄金组合”

torchvision.datasets.CIFAR10加载数据集,再用DataLoader批量管理数据:

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2
)testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform
)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2
)classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
  • shuffle=True:训练时打乱数据,避免模型 “死记硬背” 数据顺序;

  • batch_size=4:每次给模型喂 4 张图(笔记本显存小,设太小会慢,太大容易爆显存);

  • num_workers=2:用多进程加速数据加载(Windows 下别设太大,容易报错)。

3. 可视化验证:确认数据没 “读错”

为了确保数据加载正确,我写了个imshow函数可视化图像:

import matplotlib.pyplot as plt
import numpy as npdef imshow(img):img = img / 2 + 0.5  # 反标准化(因为之前缩到[-1,1]了)npimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))  # Tensor是(C,H,W),转成(H,W,C)才能显示plt.show()# 取一批数据可视化
dataiter = iter(trainloader)
images, labels = next(dataiter)
imshow(torchvision.utils.make_grid(images))
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))

      运行后能看到 4 张拼接的图,还有对应的标签(比如我这次看到 “bird car cat plane”),确认数据加载没问题~

二、搭建 CNN 网络:卷积、池化、全连接的 “接力赛”

CNN 的核心逻辑是 “卷积提取特征→池化缩小尺寸→全连接分类”。我搭了一个 “两层卷积 + 两层池化 + 两层全连接” 的基础网络,代码如下:

import torch.nn as nn
import torch.nn.functional as Fclass CNNNet(nn.Module):def __init__(self):super(CNNNet, self).__init__()# 第一层卷积:3个输入通道(彩色图),16个卷积核,5×5大小self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1)self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # 池化后尺寸缩小一半# 第二层卷积:16个输入通道,36个卷积核,3×3大小self.conv2 = nn.Conv2d(in_channels=16, out_channels=36, kernel_size=3, stride=1)self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)# 全连接层:先把卷积输出展平,再接128个神经元,最后分类到10类self.fc1 = nn.Linear(36 * 6 * 6, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):# 卷积→激活→池化x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))# 展平(把三维特征图转成一维向量)x = x.view(-1, 36 * 6 * 6)# 全连接→激活→最终分类(最后一层不用激活,损失函数会处理)x = F.relu(self.fc1(x))x = self.fc2(x)return xnet = CNNNet()

关键理解

  • 卷积层的out_channels是 “卷积核数量”,数量越多,能提取的特征越丰富;

  • kernel_size(卷积核大小):5×5 适合先抓 “大特征”,3×3 适合细化特征;

  • 池化层(MaxPool2d):缩小特征图尺寸、减少计算量,还能防止过拟合;

  • view(-1, ...):把三维的特征图 “展平” 成一维向量,才能输入全连接层。

我还特意算了模型参数总数(用sum(x.numel() for x in net.parameters())),结果是 173742 个参数,笔记本也能轻松运行~

三、训练模型:梯度下降的 “迭代游戏”

训练模型的核心是 “前向传播→算损失→反向传播→更新参数”。PyTorch 把这些步骤封装得很简洁,但几个细节容易踩坑。

1. 损失函数与优化器选择

分类任务用CrossEntropyLoss(它自带 Softmax,不用手动加),优化器选 “带动量的 SGD”(比普通 SGD 收敛更顺滑):

import torch.optim as optimcriterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
  • lr=0.001(学习率):太大容易 “跳过” 最优解,太小则收敛慢;

  • momentum=0.9(动量):让梯度更新更平稳,减少震荡。

2. 训练循环:耐心 + 细节

训练的核心循环长这样:

for epoch in range(10):  # 总共训练10轮running_loss = 0.0for i, data in enumerate(trainloader, 0):# 取出数据并放到设备(CPU/GPU)inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)# 梯度清零(超重要!否则梯度会累加,导致训练爆炸)optimizer.zero_grad()# 前向传播 → 计算损失 → 反向传播 → 更新参数outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 打印损失,看训练趋势running_loss += loss.item()if i % 2000 == 1999:  # 每2000批打印一次平均损失print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 2000:.3f}')running_loss = 0.0print('Finished Training')

踩坑记录:最开始我忘了写optimizer.zero_grad(),结果损失值 “一路飙升”,模型完全不收敛。原来每次反向传播后,梯度会留在参数里,必须手动清零才能正确更新!

训练过程中,损失从 2.2 左右慢慢降到 1.3 左右,说明模型在 “学习” 了~

四、测试模型:看看分类准不准

训练完要在测试集验证效果,步骤是 “加载测试数据→模型预测→对比真实标签”。

1. 单批数据测试:直观看预测结果

先取一批测试图,看看模型的分类效果:

dataiter = iter(testloader)
images, labels = next(dataiter)# 显示测试图
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))# 模型预测
outputs = net(images)
_, predicted = torch.max(outputs, 1)  # 取概率最大的类别
print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5s}' for j in range(4)))

我运行后,真实标签是 “cat ship ship plane”,但模型预测为 “frog dog deer horse”—— 前几个没猜对,说明模型还有优化空间(比如增加训练轮数、调整学习率),但流程是对的。

2. 整体准确率统计:量化模型性能

如果要评估模型在所有测试样本上的表现,可以写个循环统计正确率:

correct = 0
total = 0
with torch.no_grad():  # 测试时不用算梯度,节省内存for data in testloader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'模型在10000张测试图上的准确率:{100 * correct // total}%')

我跑出来准确率大概 50% 左右,比 “随机猜(10%)” 好,但和论文里的高性能模型(比如 ResNet 能到 90%+)差距还大,说明这个基础 CNN 还有很大优化空间~

五、总结:从入门到踩坑,我的实战心得

今天的项目让我对 PyTorch 图像处理流程有了清晰认知:

  1. 数据是基础transforms预处理、Dataset+DataLoader加载、可视化验证,每一步都不能少;

  2. CNN 有规律:卷积(提特征)→池化(降维)→全连接(分类),层与层的维度要对应好;

  3. 训练看细节:梯度清零、学习率调整、设备(CPU/GPU)部署,细节决定训练成败;

  4. 测试见真章:单批预测看直观效果,整体准确率量化性能,这样才知道模型有没有真的 “学会”。

接下来打算试试增加训练轮数、换用 Adam 优化器,或者加深网络结构,看看能不能提高准确率~如果有同样在学 PyTorch 的朋友,欢迎交流踩坑经验呀~

http://www.dtcms.com/a/412336.html

相关文章:

  • 阿里云iis放网站织梦系统怎么做单页网站
  • 企业网站推广 知乎wordpress媒体库在哪
  • 关于红黑树删除节点操作的完整推导
  • 深圳做网站报价 网站
  • git reset --soft <commit>和 git revert <commit>的区别
  • Unity-角色控制器
  • 比价网站源码网站关键词优化方法
  • 模板网站的域名是什么意思网络服务器租赁
  • Linux第二十一讲:网络层
  • 【FreeRTOS】第七课(3):任务间的通信——使用队列集优化程序架构
  • SQL语句详细使用说明 - 适合小白入门
  • 天水网站建设惠普网站暂时关闭 seo
  • 做网站如何对接支付gpu服务器租用价格
  • 检查一个字符串是否包含所有长度为K的二进制子串
  • 做网站需要多少空间芜湖网站建设兼职
  • 森动网网站建设好吗自己开发app要钱吗
  • 携程网站建设在阿里巴巴上做网站需要什么条件
  • 北京建设部网站职称网站建设硬件投入表
  • 视频防录屏软件为什么受欢迎?---以点盾云为例
  • 开源AI工具Mobile-Use
  • 做课件ppt网站上海十大网站建设
  • 新乡网站建设策划ftp网站怎么看后台的代码
  • numpy谨慎升级
  • 微信公众号登录wordpress网站吗免费自助小型网站
  • 站长之家官网做co的网站
  • 网站开发语言哪个好聊天网站模板
  • 河南省住房城乡和建设厅网站首页货代到哪个网站开发客户
  • 一元夺宝网站建设2017安徽安庆怎么样
  • 个人网站建设费用做植物提取物的专业网站
  • 做网站的周记10个免费的黑科技网站