当前位置: 首页 > news >正文

卷积神经网络搭建及应用

代码实现:

import torch
print(torch.__version__)
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
training_data = datasets.MNIST(root='data',train=True,download=True,transform=ToTensor())
test_data = datasets.MNIST(root='data',train=False,download=True,transform=ToTensor())train_dataloader = DataLoader(training_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)for X,y in test_dataloader:print(f"Shape of X[N,C,H,W]:{X.shape}")print(f"Shape of y: {y.shape} {y.dtype}")breakdevice = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f"Using {device} device")class CNN(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1,out_channels=16,kernel_size=5,stride=1,padding=2),nn.ReLU(),nn.MaxPool2d(kernel_size=2),)self.conv2 = nn.Sequential(nn.Conv2d(16,64,5,1,2),nn.ReLU(),nn.Conv2d(64, 128,5, 1, 2),nn.ReLU(),nn.MaxPool2d(2),)self.conv3 = nn.Sequential(nn.Conv2d(128,256,5,1,2),nn.ReLU(),)self.out = nn.Linear(256*7*7,10)def forward(self,x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0),-1)output = self.out(x)return outputmodel = CNN().to(device)
print(model)def train(dataloader,model,loss_fn,optimizer):model.train()batch_size_num = 1for X ,y in dataloader:X,y = X.to(device),y.to(device)pred = model(X)loss = loss_fn(pred,y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value = loss.item()if batch_size_num % 100 ==0:print(f"loss:{loss_value:>7f} [number:{batch_size_num}]")batch_size_num +=1
def test(dataloader,model,loss_fn):size = len(dataloader.dataset)num_batches= len(dataloader)model.eval()test_loss = 0correct = 0with torch.no_grad():for X ,y in dataloader:X,y = X.to(device),y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_pj_loss = test_loss / num_batchestest_acy = correct / size * 100print(f"Avg loss: {test_pj_loss:>7f} \n Accuray: {test_acy:>5.2f}%")
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=0.01)
# train(train_dataloader,model,loss_fn,optimizer)
# test(test_dataloader,model,loss_fn)
i=10
for j in range(i):print(f"Epoch {j+1}\n----------")train(train_dataloader, model,loss_fn,optimizer)
print("Done!")
test(test_dataloader,model,loss_fn)

这段代码是一个基于 PyTorch 实现的卷积神经网络 (CNN) 训练流程,用于对 MNIST 手写数字数据集进行分类任务。整体结构清晰,包含了数据加载、模型定义、训练 / 测试函数实现和完整的训练流程。以下是代码解析:

1. 库导入与数据集加载

import torch
print(torch.__version__)  # 打印PyTorch版本
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

导入核心库:torch(PyTorch 核心)、nn(神经网络模块)、DataLoader(数据加载工具)、datasets(内置数据集)、ToTensor(图像转张量的转换工具)。

# 加载MNIST数据集
training_data = datasets.MNIST(root='data',        # 数据存储路径train=True,         # 加载训练集download=True,      # 若本地无数据则自动下载transform=ToTensor() # 将图像转为PyTorch张量(并归一化到[0,1])
)
test_data = datasets.MNIST(root='data', train=False,        # 加载测试集download=True, transform=ToTensor()
)

MNIST 是经典手写数字数据集(0-9),包含 60000 张训练图和 10000 张测试图,每张图是 28×28 的灰度图(单通道)。

transform=ToTensor()将 PIL 图像(0-255)转为张量(0.0-1.0),方便模型处理。

2. 数据加载器(DataLoader)

train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

DataLoader用于将数据集按batch_size(这里是 64)分批加载,支持自动打乱数据、多进程加载等功能。

后续训练时,模型将按批次处理数据(一次处理 64 张图),提高效率。

3. 数据格式验证

for X, y in test_dataloader:print(f"Shape of X[N,C,H,W]: {X.shape}")  # 输出:(64, 1, 28, 28)print(f"Shape of y: {y.shape} {y.dtype}")  # 输出:(64,) torch.int64break

验证数据维度:

X(输入图像):形状为[N, C, H, W],其中N=64(批次大小)、C=1(单通道灰度图)、H=28W=28(图像尺寸)。

y(标签):形状为[64],每个元素是 0-9 的整数(对应数字类别)。

4. 设备选择(CPU/GPU)

device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f"Using {device} device")

自动选择计算设备:优先使用 NVIDIA GPU(cuda),其次是 Apple 芯片 GPU(mps),最后是 CPU。

目的是利用 GPU 加速训练(CNN 计算量大,GPU 比 CPU 快得多)。

5. CNN 模型定义

class CNN(nn.Module):def __init__(self):super().__init__()# 第一个卷积块:卷积+激活+池化self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),nn.ReLU(),nn.MaxPool2d(kernel_size=2),)# 第二个卷积块:两个卷积+激活+池化self.conv2 = nn.Sequential(nn.Conv2d(16, 64, 5, 1, 2),  # 省略参数名,顺序为:输入通道、输出通道、核大小、步长、填充nn.ReLU(),nn.Conv2d(64, 128, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2),)# 第三个卷积块:卷积+激活(无池化)self.conv3 = nn.Sequential(nn.Conv2d(128, 256, 5, 1, 2),nn.ReLU(),)# 全连接层:输出10类(0-9)self.out = nn.Linear(256*7*7, 10)  # 输入维度需与卷积输出展平后一致def forward(self, x):x = self.conv1(x)       # 经过第一个卷积块x = self.conv2(x)       # 经过第二个卷积块x = self.conv3(x)       # 经过第三个卷积块x = x.view(x.size(0), -1)  # 展平为向量(batch_size, 特征数)output = self.out(x)    # 全连接层输出return output

模型继承自nn.Module(PyTorch 中所有神经网络的基类)。

卷积层设计逻辑

卷积操作通过nn.Conv2d实现,kernel_size=5(5×5 卷积核)、stride=1(步长 1)、padding=2(填充 2),保证卷积后图像尺寸不变((H - kernel + 2*padding)/stride + 1 = H)。

MaxPool2d(kernel_size=2)将图像尺寸缩小一半(28→14→7)。

维度变化过程(以输入 28×28 为例):

conv1:28×28 → 28×28(卷积)→ ReLU → 14×14(池化),通道数从 1→16。

conv2:14×14 → 14×14(16→64)→ ReLU → 14×14(64→128)→ ReLU → 7×7(池化),通道数 128。

conv3:7×7 → 7×7(128→256)→ ReLU,通道数 256。

展平后维度:256 * 7 * 7 = 12544,通过全连接层out输出 10 类(对应 0-9)。

6. 模型初始化

model = CNN().to(device)  # 初始化模型并移动到指定设备(CPU/GPU)
print(model)  # 打印模型结构

模型实例化后,通过.to(device)将所有参数移动到选定的计算设备(确保数据和模型在同一设备上)。

7. 训练函数(train)

def train(dataloader, model, loss_fn, optimizer):model.train()  # 设为训练模式(启用 dropout/batchnorm等训练特有的层)batch_size_num = 1  # 批次计数器for X, y in dataloader:X, y = X.to(device), y.to(device)  # 数据移到设备# 前向传播:计算预测值pred = model(X)# 计算损失(预测值与真实标签的差距)loss = loss_fn(pred, y)# 反向传播+参数更新optimizer.zero_grad()  # 清空上一轮梯度loss.backward()        # 计算梯度(反向传播)optimizer.step()       # 更新模型参数# 每100个批次打印一次损失loss_value = loss.item()if batch_size_num % 100 == 0:print(f"loss: {loss_value:>7f} [number: {batch_size_num}]")batch_size_num += 1

核心功能:实现单轮训练(遍历所有训练数据一次)。

关键步骤:前向传播计算预测→计算损失→反向传播求梯度→优化器更新参数。

model.train():启用训练模式(例如,BatchNorm 层会更新均值和方差)。

8. 测试函数(test)

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)  # 测试集总样本数(10000)num_batches = len(dataloader)   # 测试集批次数model.eval()  # 设为评估模式(关闭 dropout/batchnorm等)test_loss = 0correct = 0with torch.no_grad():  # 关闭梯度计算(节省内存,加速计算)for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)  # 前向传播(无反向传播)test_loss += loss_fn(pred, y).item()  # 累计损失# 计算正确预测数(取预测概率最大的类别与真实标签比较)correct += (pred.argmax(1) == y).type(torch.float).sum().item()# 计算平均损失和准确率test_avg_loss = test_loss / num_batchestest_acc = correct / size * 100print(f"Avg loss: {test_avg_loss:>7f} \n Accuracy: {test_acc:>5.2f}%")

核心功能:评估模型在测试集上的性能(损失和准确率)。

关键步骤:关闭梯度计算(torch.no_grad())→ 前向传播→计算总损失和正确数→计算平均指标。

model.eval():切换到评估模式(例如,BatchNorm 层使用训练时的均值 / 方差,不更新)。

9. 训练配置与执行

# 定义损失函数和优化器
loss_fn = nn.CrossEntropyLoss()  # 交叉熵损失(适用于分类任务,内置SoftMax)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # Adam优化器,学习率0.01# 训练10个epoch
epochs = 10
for epoch in range(epochs):print(f"Epoch {epoch+1}\n----------")train(train_dataloader, model, loss_fn, optimizer)  # 训练一轮
print("Done!")# 最终测试
test(test_dataloader, model, loss_fn)

损失函数CrossEntropyLoss是分类任务的常用损失,直接接收模型输出(未经过 SoftMax)和标签。

优化器Adam是常用的自适应学习率优化器,收敛较快。

训练循环:迭代 10 个epoch(每个 epoch 遍历一次训练集),最后用测试集评估最终性能。

总结

这段代码是一个完整的 CNN 训练 pipeline,从数据加载到模型定义,再到训练 / 测试流程,逻辑清晰,可直接运行(需安装 PyTorch 和 torchvision)。在 MNIST 数据集上,合理调参后准确率通常可达到 99% 以上。

http://www.dtcms.com/a/356006.html

相关文章:

  • 对象之间属性拷贝(Bean Mapping)的工具MapStruct 和 BeanUtils
  • 多据点协作下的数据库权限与版本管理实战
  • BeforeEach与AfterEach注解的使用
  • React学习教程,从入门到精通, ReactJS - 安装:初学者指南(3)
  • iPhone17新品曝光!未来已来主题发布会即将登场
  • CSS入门学习
  • Vim 相关使用
  • Dify 从入门到精通(第 61/100 篇):Dify 的监控与日志分析(进阶篇)
  • 笔记本电脑蓝牙搜索不到设备-已解决
  • LoRA加入嵌入层、及输出头解析(63)
  • 实测阿里图像编辑模型Qwen-Image-Edit:汉字也能无痕修改(附实测案例)
  • 【 MYSQL | 基础篇 函数与约束 】
  • 响应式编程之Flow框架
  • cmd 中设置像 linux 一样设置别名(alias)
  • Xshell自动化脚本大赛实战案例及深度分析
  • 谷歌RecLLM,大模型赋能对话推荐算法系统
  • TUN模式端口冲突 启动失败如何解决?
  • hintcon2025No Man‘s Echo
  • 【Web安全】反序列化安全漏洞全解析:从原理到实战测试指南
  • Vue3 Pinia 中 store.$dispose()的用法说明
  • Vue3组件加载顺序
  • vue项目运行后自动在浏览器打开
  • 使用npm init vue@latest 基于vite创建的vue项目
  • 特色领域数据集:以数据之力,赋能多元行业发展
  • three 点位图
  • HT338立体声D类音频功放
  • 消息推送与 WebSocket 学习
  • Node.js终极文本转图指南
  • 基于SpringBoot的学科竞赛管理系统
  • 请详细介绍RuntimeInit.java中的MethodAndArgsCaller类