当前位置: 首页 > news >正文

深度学习-卷积神经网络LeNet

1 背景

        LeNet 由AT&T贝尔实验室的研究员Yann LeCun在1989年提出,是最早发布的卷积神经网络之一。当时主要是为了应用于识别手写数字,并且取得了与支持向量机媲美的结果。这项工作代表了十多年来神经网络研究开发的成果。

2 原理

        LeNet的结构如上图所示,输入一张32*32大小的单通道图像,首先经过一次卷积得到6通道28*28大小的特征图,经过下采样(即汇聚层)得到6通道14*14的特征图,然后经过第二次卷积得到16通道10*10的特征图,经过第二次下采样得到16通道5*5的特征图。接下来将所有特征图展平得到长度为16*5*5=400的张量,将这个张量作为全连接层的输入,经过两个全连接层和一个高斯连接层得到长度为10的输出。

        这个模型中几乎所有的部分我们已经在前面的章节学过,只有最后一层需要补充介绍一下。高斯连接层(Gaussian Connection Layer)是一种在神经网络中使用的连接方式,其主要特点是使用固定参数的欧式径向基函数(Euclidean Radial Basis Function)进行计算。与常规的全连接层不同,高斯连接层的输出不经过激活函数,而是直接作为输出结果。这种设计使得高斯连接层在某些特定任务中能够更有效地处理数据。说白了就是最后一层不再像常规的全连接层那样训练,而是作者手工标注出了数字0~9的特征,然后通过计算输入的值与这10种特征向量的欧氏距离得到10个输出。

3 实现

3.1 模型定义

        下面的代码展示了模型各层具体结构,为了让模型更加通用,书中没有实现最后特制的高斯连接层。

import torch
from torch import nn
from d2l import torch as d2lnet = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10))

        下面展示了每一层输出的形状(B,C,H,W):

X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape: \t',X.shape)

3.2 模型训练

        这里数据集选用Fashion-MNIST数据集。

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)

        为了使用GPU加速训练,书中重写了精度计算函数和训练函数,其主要修改的地方在于需要在GPU中参与计算的数据全部用to(device)函数复制到了显存中。

def evaluate_accuracy_gpu(net, data_iter, device=None): #@save"""使用GPU计算模型在数据集上的精度"""if isinstance(net, nn.Module):net.eval()  # 设置为评估模式if not device:device = next(iter(net.parameters())).device# 正确预测的数量,总预测的数量metric = d2l.Accumulator(2)with torch.no_grad():for X, y in data_iter:if isinstance(X, list):# BERT微调所需的(之后将介绍)X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)metric.add(d2l.accuracy(net(X), y), y.numel())return metric[0] / metric[1]
#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型(在第六章定义)"""def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)net.to(device)optimizer = torch.optim.SGD(net.parameters(), lr=lr)loss = nn.CrossEntropyLoss()animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):# 训练损失之和,训练准确率之和,样本数metric = d2l.Accumulator(3)net.train()for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()optimizer.step()with torch.no_grad():metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop()train_l = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))test_acc = evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')
lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

        训练结果如下:


参考文献

[1]《动手学深度学习》,https://zh-v2.d2l.ai/

[2] LeCun, Y., Bottou, L., Bengio, Y., Haffner, P., & others. (1998). Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11), 2278–2324.

http://www.dtcms.com/a/573729.html

相关文章:

  • Ubuntu误删libaudit.so.1 导致系统无法正常使用、崩溃
  • 【深度学习5】多层感知机
  • 通过fluent HEC 来发送数据到splunk
  • 二叉树深度解析:核心概念与算法实现
  • 考研408--操作系统--day3--调度调度算法
  • 东莞做网站首选企业铭wordpress 4.5.4 漏洞
  • 消防做ccc去那个网站微信网页版客户端下载
  • 项目实战 | 新建校区网络安全项目:从搭建到交付
  • MHAF-YOLO:用于精确目标检测的多分支异构辅助融合YOLO
  • 从零到上线:Spring Boot 3 + Spring Cloud Alibaba + Vue 3 构建高可用 RBAC 微服务系统(超详细实战)
  • 优秀企业网站模板下载企业网络推广方案怎么做
  • Spring国际化语言切换不生效
  • 跨境S2B2C供应链系统推荐:核货宝外贸分销S2B2C平台深度赋能B端、极致服务C端
  • 【OS笔记24】:存储管理3-分页管理-页表与快表
  • 城乡和住房建设厅网站首页深圳网站建站的公司
  • 湖湘杯网络安全技能大赛参与形式
  • 网站怎么上传模板优化设计七年级上册英语答案
  • C++-19-类和对象
  • 深度学习_原理和进阶_PyTorch入门(2)后续语法2
  • C++ 中string的用法
  • 山东卓商网站建设公司做网站的广告词
  • uView2开发APP实现悬浮按钮
  • 让人做网站需要注意什么条件绍兴建设公司网站
  • OCSSA-VMD-Transformer-LSTM-Adaboost轴承故障诊断MATLAB代码实现
  • 工业园区废水除重金属镍
  • 自动化深度研究智能体-deep research实战
  • 制作网站培训学校网站建设优化服务方案模板
  • 计算机操作系统:文件保护
  • 卸载——通用方法
  • 【Java】异常