当前位置: 首页 > news >正文

【CVPR 2022】面向2020年代的卷积神经网络

文章目录

  • 一、论文信息
  • 二、论文概要
  • 三、实验动机
  • 四、创新之处
  • 五、实验分析
  • 六、核心代码
    • 源代码
    • 注释版本
  • 七、实验总结

一、论文信息

  • 论文题目:A ConvNet for the 2020s
  • 中文题目:面向2020年代的卷积神经网络
  • 论文链接:点击跳转
  • 代码链接:点击跳转
  • 作者:Zhuang Liu, Hanzi Mao, Chao-Yuan Wu, Christoph Feichtenhofer, Trevor Darrell, Saining Xie(刘壮、毛汉孜、吴超元、Christoph Feichtenhofer、Trevor Darrell、谢赛宁)
  • 单位:Facebook AI Research (FAIR), UC Berkeley
  • 核心速览:提出ConvNeXt,通过现代化ResNet设计,使纯CNN在多个视觉任务上媲美甚至超越Swin Transformer。

二、论文概要

该论文提出了一种新的卷积神经网络(ConvNet)架构,命名为ConvNeXt,旨在重新审视传统卷积网络,并改进其设计以在当前的视觉任务中竞争力更强。研究通过逐步对标准ResNet进行“现代化”改造,探索了多个关键设计元素的影响,最终得出了ConvNeXt,这一基于标准卷积神经网络模块构建的网络能够在多个计算机视觉任务上与Transformers架构相竞争,表现出优异的准确性和可扩展性。

三、实验动机

  • Vision Transformer(ViT)及其变体(如Swin)在多个视觉任务上表现优异,许多人认为其成功源于自注意力机制的优势。

  • 作者质疑这一观点,认为许多Transformer的成功设计实际上可被卷积网络吸收。

  • 希望通过“现代化”ResNet,探索纯CNN的潜力,挑战“Transformer必然优于CNN”的成见。

四、创新之处

  • ConvNeXt架构设计:通过逐步“现代化”ResNet,采用了更大的卷积核、更有效的网络宽度分配、GELU激活函数替代ReLU等方式,增强了卷积神经网络的表达能力和可扩展性。

  • 设计元素结合:融合了Transformers架构的设计理念(如大卷积核、分阶段计算、逐步加宽网络等),但不使用注意力机制,保持了ConvNets的简洁性和高效性。

  • 性能优化:ConvNeXt在多个视觉任务中超越了传统的卷积网络,并与ViT和Swin Transformer等先进的视觉Transformer模型在性能上相媲美,且在推理速度和内存使用上具有优势。

五、实验分析

实验表明,ConvNeXt在多个计算机视觉任务上表现优异,尤其是在ImageNet分类、COCO物体检测、ADE20K语义分割等任务中,能够与Swin Transformer等视觉Transformer架构竞争,并且具有较高的推理吞吐量和较低的内存消耗。此外,ConvNeXt也表现出了良好的可扩展性,尤其是在使用大规模数据集(如ImageNet-22K)进行预训练时,其性能有显著提升。

六、核心代码

源代码

class ConvNeXt(nn.Module):r""" ConvNeXtA PyTorch impl of : `A ConvNet for the 2020s`  -https://arxiv.org/pdf/2201.03545.pdfArgs:in_chans (int): Number of input image channels. Default: 3num_classes (int): Number of classes for classification head. Default: 1000depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]drop_path_rate (float): Stochastic depth rate. Default: 0.layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1."""def __init__(self, in_chans=3, num_classes=1000, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., layer_scale_init_value=1e-6, head_init_scale=1.,):super().__init__()self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layersstem = nn.Sequential(nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),LayerNorm(dims[0], eps=1e-6, data_format="channels_first"))self.downsample_layers.append(stem)for i in range(3):downsample_layer = nn.Sequential(LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),)self.downsample_layers.append(downsample_layer)self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocksdp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] cur = 0for i in range(4):stage = nn.Sequential(*[Block(dim=dims[i], drop_path=dp_rates[cur + j], layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])])self.stages.append(stage)cur += depths[i]self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layerself.head = nn.Linear(dims[-1], num_classes)self.apply(self._init_weights)self.head.weight.data.mul_(head_init_scale)self.head.bias.data.mul_(head_init_scale)def _init_weights(self, m):if isinstance(m, (nn.Conv2d, nn.Linear)):trunc_normal_(m.weight, std=.02)nn.init.constant_(m.bias, 0)def forward_features(self, x):for i in range(4):x = self.downsample_layers[i](x)x = self.stages[i](x)return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)def forward(self, x):x = self.forward_features(x)x = self.head(x)return xclass LayerNorm(nn.Module):r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width)."""def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):super().__init__()self.weight = nn.Parameter(torch.ones(normalized_shape))self.bias = nn.Parameter(torch.zeros(normalized_shape))self.eps = epsself.data_format = data_formatif self.data_format not in ["channels_last", "channels_first"]:raise NotImplementedError self.normalized_shape = (normalized_shape, )def forward(self, x):if self.data_format == "channels_last":return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)elif self.data_format == "channels_first":u = x.mean(1, keepdim=True)s = (x - u).pow(2).mean(1, keepdim=True)x = (x - u) / torch.sqrt(s + self.eps)x = self.weight[:, None, None] * x + self.bias[:, None, None]return x

注释版本


七、实验总结

  • ConvNeXt通过结合Transformer的先进设计思想,成功地提升了传统卷积神经网络的性能,同时保持了其结构的简洁性和高效性。
  • 实验表明,在多个视觉任务中,ConvNeXt与当前最强的视觉Transformer模型(如Swin Transformer)相比,不仅表现出相似的性能,而且在推理速度和内存使用上具有优势,展示了卷积神经网络在新时代的强大潜力。

文章转载自:

http://mtO0e11N.ptLwt.cn
http://rW8HZvR1.ptLwt.cn
http://zPb46Tu1.ptLwt.cn
http://pJuE7GD1.ptLwt.cn
http://V5MxNgsY.ptLwt.cn
http://XFdpjrbt.ptLwt.cn
http://OCbBEccI.ptLwt.cn
http://cdcBPtfo.ptLwt.cn
http://WuilB2d5.ptLwt.cn
http://VfGxasEf.ptLwt.cn
http://QsdRsm5j.ptLwt.cn
http://usRhCSSo.ptLwt.cn
http://Q55hGaLt.ptLwt.cn
http://Zo0HhgP8.ptLwt.cn
http://cmnjX1wU.ptLwt.cn
http://D0yrB7Pu.ptLwt.cn
http://UqZrhZzi.ptLwt.cn
http://fFmr2gGf.ptLwt.cn
http://WFKy9n8n.ptLwt.cn
http://dEvu3eAx.ptLwt.cn
http://CindkpC2.ptLwt.cn
http://r86ptuI2.ptLwt.cn
http://ddcymsUy.ptLwt.cn
http://vO7hdHmR.ptLwt.cn
http://0FHNITd8.ptLwt.cn
http://QPTDXkHW.ptLwt.cn
http://U0ErQ2zb.ptLwt.cn
http://qs7MngZQ.ptLwt.cn
http://whfQIXZX.ptLwt.cn
http://ANLlIKcx.ptLwt.cn
http://www.dtcms.com/a/376588.html

相关文章:

  • 图神经网络介绍
  • FPGA入门到进阶:可编程逻辑器件的魅力
  • 【解决问题】Ubuntu18上无法运行arm-linux-gcc
  • 嵌入式学习day47-硬件-imx6ull-LED
  • 深入体验—Windows从零到一安装KingbaseES数据库
  • 力扣习题——电话号码的字母组合
  • Linux环境下爬虫程序的部署难题与系统性解决方案
  • 深入解析ThreadLocal:线程数据隔离利器
  • D01-【计算机二级】Python(1)基本操作第41题
  • API开发工具postman、国内xxapi和SmartApi的性能对比
  • Scikit-learn Python机器学习 - 分类算法 - 线性模型 逻辑回归
  • SciKit-Learn 全面分析 digits 手写数据集
  • 《sklearn机器学习——数据预处理》标准化或均值去除和方差缩放
  • 保序回归Isotonic Regression的sklearn实现案例
  • 《sklearn机器学习——数据预处理》离散化
  • 无人机桨叶转速技术要点与突破
  • GPFS存储服务如何使用及运维
  • ELK 日志采集与解析实战
  • BI数据可视化:驱动数据价值释放的关键引擎
  • FinChat-金融领域的ChatGPT
  • OpenTenBase日常操作锦囊(新手上路DML)
  • Dart 中的 Event Loop(事件循环)
  • C++/Java编程小论——方法设计与接口原则总结
  • Java-Spring入门指南(四)深入IOC本质与依赖注入(DI)实战
  • 线扫相机采集图像起始位置不正确原因总结
  • JVM 对象创建的核心流程!
  • 秋日私语:一片落叶,一个智能的温暖陪伴
  • springCloud之配置/注册中心及服务发现Nacos
  • 第1讲 机器学习(ML)教程
  • Ubuntu 系统 YOLOv8 部署教程(GPU CPU 一键安装)