当前位置: 首页 > news >正文

五、CV_ResNet

五、ResNet

随着层数加深,特征图的尺寸是逐渐变小的,其通道是逐渐增多的

网络退化问题:理论上,网络越深,获取的信息就越多,特征也就越丰富。但在实践中,随着网络的加深,优化效果反而越差,测试数据和训练数据的准确率反而降低了

1.残差块

(1)作用

  • 缓解网络退化问题
  • 在模型中是用来降维的

(2)概念

F(x)F(x)F(x)代表某个只包含两层的映射函数,xxxF(x)F(x)F(x)具有相同维度。在训练过程中,我们的目标是修改F(x)F(x)F(x)中的wwwbbb逼近H(x)H(x)H(x),变换一下思路,用F(x)F(x)F(x)来逼近,则最终得到的输H(x)−xH(x)-xH(x)x出就变为F(x)+xF(x)+xF(x)+x,这里将直接从输入连接到输出的结构称为,shortcutshortcutshortcut整个结构就是残差块,ResNetResNetResNet的基础模块。

  • F(x)F(x)F(x)代表预测值
  • H(x)H(x)H(x)代表真实值
  • F(x)+xF(x)+xF(x)+x:这里的加指的是对应位置上的元素相加,也就是$element - wise $additionadditionaddition

ResNetResNetResNet沿用了VGG全3×33 \times 33×3卷积层的设计。残差块里首先有2个具有相同输出通道数的3×33 \times 33×3卷积层。每个卷积层后接BN层和ReLU激活函数,然后将输入直接加在最后的ReLU激活函数前,这种结构用于层数较少的神经网络中(resnet18, resnet34)

如果输入通道数(eg.eg.eg.图中有256通道数)比较多,就需要引入1×11 \times 11×1卷积层来调整输入的通道数,这种结构也叫做瓶颈模块,通常用关于网络层数较多的结构中(resnet50,resnet101,resnet152)

下面右图残差块的实现如下,可以设定输出通道数,是否使用1×11 \times 11×1的卷积及卷积层的步幅

import tensorflow as tf
from tensorflow.keras import layers, activations# 定义ResNet
class Residual(tf.keras.Model):# 指明残差块的通道数,是否使用1*1卷积,步长def __init__(self, num_channels,use_1x1convs = False, strides = 1):super(Residual, self).__init__()# 卷积层:指明卷积核个数,padding,卷积核大小,步长self.cov1 = layers.Conv2D(num_channels, padding = 'same', kernel_size = 3,strides = strides)# 卷积层:指明卷积核个数,padding,卷积核大小,步长self.conv2 = layers.Conv2D(num_channels,strides = 1,kernel_size = 3,padding = 'same')if use_1x1conv:self.conv3 = layers.Conv2D(num_channels,kernel_size = 1,strides = strides)else:self.conv3 = None# 指明BN层self.bn1 = layers.BatchNormalization()self.bn2 = layers.BatchNormalization()# 定义正向传播过程def call(self, X):# 卷积,BN, 激活Y = activations.relu(self.bn1(self.conv1(x)))# 卷积,BNY = self.bn2(self.conv2(Y))# 对输出数据进行1*1卷积保证通道数相同if self.conv3:X = self.conv3(X)# 返回与输入相加后激活的结果return activation.relu(Y + X)
  • 1×11\times 11×1卷积是用来调整通道数
  • 降维
    • pooling层
    • 设置卷积 strides = 2

2.ResNet模型

ResNet模型构成如下:

ResNet网络中按照残差块的通道数分为不同的模块。第一个模块使用了步幅为2的最大池化层。则无需减小宽和高(第一个模块需进行特殊处理。即需要进行下采样,降维)。之后每个模块在第一个残差块里将上一个模块的通道数翻倍,并将高和宽减半(每个模块间均需进行通道调整)

(1)定义残差模块

  • 第一个模块做了特别处理
class ResnetBlock(tf.keras.layers.Layer):def __init__(self, num_channels, num_res, first_block = False):super(ResnetBlock, self).__init__()# 存储残差块self.listLayers = []# 遍历残差数目生成模块for i in range(num_res):if i == 0 and not first_block:self.listLayers.append(Residual(num_channels, use_1x1conv = True, strides = 2))else:self.listLayers.append(Residual(num_channels))# 前向传播def call(self, x):for layer in self.listLayers:x = layer(x)return x                

(2)构建Resnet网络

  • ResNet的前两层跟之前介绍的GoogLeNet中一样:在输出通道数为64,步幅为2的7×77 \times 77×7卷积层后接步幅为2的3×33\times33×3的最大池化层。不同之处在于ResNet每个卷积层后增加了BN层,接着是所有残差模块,最后,与GoogLeNet一样,加入全局平均池化层(GPA)后接上全连接层输出
class Resnet(tf.keras.Model):# 定义网络的构成def __init__(self, num_blocks):super(Resnet, self).__init__()# 输入层self.conv = layers.Conv2D(filter = 64,kernel_size = 7,padding = 'same',strides = 2)# BN层self.bn = layers.BatchNormalization()# 激活层self.relu = layers.Activation('relu')# 池化self.mp = layers.MaxPool2D(pool_size = 3, strides = 2, padding = 'same')# 残差模块self.res_block1 = ResnetBlock(64, num_blocks[0], first_block = True)self.res_block2 = ResnetBlock(128, num_blocks[1])self.res_block3 = ResnetBlock(256, num_blocks[2])self.res_block4 = ResnetBlock(512, num_blocks[3])# GAPself.gap = layers.GlobalAvgPool2D()# 全连接层self.fc = layers.Dense(units = 10, activation = tf.keras.activations.softmax)# 定义前向传播过程def call(self,x):# 输入部分传输过程x = self.conv(x)x = self.bn(x)x = self.relu(x)x = self.mp(x)# blockx = self.res_block1(x)x = self.res_block2(x)x = self.res_block3(x)x = self.res_block4(x)# 输出部分的传输x = self.gap(x)x = self.fc(x)return x     

这里每个模块里有4个卷积层(不计算11卷积层),加上最开始的卷积层和最后的全连接层,共计18层。这个模型被称为ResNet-18。通过配置不同的通道数给模块里的残差块数可以得到不同的ResNet模型。虽然ResNet的主体架构跟GoogLeNet的类似,但ResNet结构更简单,修改也更方便。

# 实例化
mynet = Resnet([2, 2, 2, 2])
x = tf.random.uniform((1, 224, 224, 1))
y = mynet(x)
mynet.summary()

最终可以得到ResNet的架构

3.手写数字识别

(1)数据读取

获取数据并进行维度调整

import numpy as np
from tensorflow.keras.datasets import mnist(train_images, train_labels), (test_images, test_labels) = mnist.load_data()# N H W C
train_images = np.reshape(train_images, (train_images.shape[0], train_images.shape[1], train_images.shape[2], 1))test_images = np.reshape(test_images, (test_images.shape[0], test_images.shape[1], test_images.shape[2], 1))

定义两个方法获取部分数据

# 定义两个方法随机抽取部分样本演示def get_train(size):index = np.random.randint(0, np.shape(train_images)[0], size)resize_images = tf.image.resize_with_pad(train_images[index], 224, 224, )return resize_images.numpy(), train_labels[index]def get_test(size):index = np.random.randint(0, np.shape(test_images)[0], size)resize_images = tf.image.resize_with_pad(test_images[index], 224, 224, )return resize_images.numpy(), test_labels[index]
# 获取训练样本和测试样本
train_image, train_label = get_train(256)
test_image, test_label = get_test(128)

(2)模型编译

# 指定优化器,损失函数和评价指标
optimizer = tf.keras.optimizers.SGD(learning_rate = 0.01, momentum = 0.0)mynet.compile(optimizer = optimizer,loss = 'sparse_categorical_crossentropy',metrics = ['accuracy']
)

(3)模型训练

# 模型训练:指定训练数据集,batchsize, epoch, 验证集
mynet.fit(train_images, train_labels, batch_size = 128, epochs = 3, verbose = 1, # 显示整个训练的logvalidation_split = 0.2) # 验证集

(4)模型评估

mynet.evaluate(test_images, test_labels, verbose = 1)
http://www.dtcms.com/a/322118.html

相关文章:

  • Redis的Linux安装
  • python-操作mysql数据库(增删改查)
  • 医疗设备专用电源滤波器的安全设计与应用价值|深圳维爱普
  • Python接口测试实战之搭建自动化测试框架
  • 文学主题的演变
  • 智慧养老场景识别率↑91%!陌讯轻量化模型在独居监护的落地优化
  • 根据ASTM D4169-23e1标准,如何选择合适的流通周期进行测试?
  • 大语言模型的过去与未来——GPT-5发布小谈
  • 机器学习概念1
  • 数据库基础--多表关系,多表查询
  • 中小型企业ERP实施成本高?析客ERP系统独立部署+模块化配置的务实解决方案
  • ENET_GetRxFrame vs ENET_ReadFrame
  • MySQL 正则表达式详细说明
  • AI摄像头动捕:让动作分析更自由、更智能、更高效
  • 刚刚,GPT-5 炸裂登场!可免费使用
  • Word中怎样插入特殊符号
  • seo-使用nuxt定义页面标题和meta等信息
  • 11_Mybatis 是如何进行DO类和数据库字段的映射的?
  • HTTP/HTTPS代理,支持RSA和SM2算法
  • 消防通道占用识别误报率↓79%:陌讯动态区域感知算法实战解析
  • 自签名证书实现HTTPS协议
  • 17.14 CogVLM-17B多模态模型爆肝部署:4-bit量化+1120px高清输入,A100实战避坑指南
  • 登上Nature子刊,深度学习正逐渐接管基础模型
  • NY128NY133美光固态闪存NY139NY143
  • 智驭全球波动:跨境量化交易系统2025解决方案
  • Linux系统:Ext系列文件系统(硬件篇)
  • 专题二_滑动窗口_将x减到0的最小操作数
  • Dart 单例模式:工厂构造、静态变量与懒加载
  • 频谱图学习笔记
  • python 通过Serper API联网搜索并大模型整理内容