当前位置: 首页 > news >正文

神经网络之CNN文本识别

1.参考我的第一篇文章了解CNN概念

神经网络之CNN图像识别(torch api 调用)-CSDN博客

2.框架

目前对NLP的研究分析应用最多的就是RNN系列的框架,比如RNN,GRU,LSTM等等,再加上Attention,基本可以认为是NLP的标配套餐了。但是在文本分类问题上,相比于RNN,CNN的构建和训练更为简单和快速,并且效果也不差,所以仍然会有一些研究。

那么,CNN到底是怎么应用到NLP上的呢?

不同于CV输入的图像像素,NLP的输入是一个个句子或者文档。句子或文档在输入时经过embedding(word2vec或者Glove)会被表示成向量矩阵,其中每一行表示一个词语,行的总数是句子的长度,列的总数就是维度。例如一个包含十个词语的句子,使用了100维的embedding,最后我们就有一个输入为10x100的矩阵。

在CV中,filters是以一个patch(任意长度x任意宽度)的形式滑过遍历整个图像,但是在NLP中,filters会覆盖到所有的维度,也就是形状为 [filter_size, embed_size]。更为具体地理解可以看下图,输入为一个7x5的矩阵,filters的高度分别为2,3,4,宽度和输入矩阵一样为5。每个filter对输入矩阵进行卷积操作得到中间特征,然后通过pooling提取最大值,最终得到一个包含6个值的特征向量

 无疑就是卷积核大小变了长方形,再进行池化

3.api实现

self.filter_sizes = (2, 3, 4)                                   # 卷积核尺寸
self.num_filters = 256                                          # 卷积核数量(channels数)

class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        if config.embedding_pretrained is not None:
            self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)
        else:
            self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, config.num_filters, (k, config.embed)) for k in config.filter_sizes])#300X2 300x3 300x4 要求nn.Conv2d 要求输入的张量形状为 (N, C_in, H, W)
        self.dropout = nn.Dropout(config.dropout)
        self.fc = nn.Linear(config.num_filters * len(config.filter_sizes), config.num_classes)

    def conv_and_pool(self, x, conv):
        x = F.relu(conv(x)).squeeze(3)#(batch_size, num_filters, new_sequence_length,1)
        x = F.max_pool1d(x, x.size(2)).squeeze(2)#(batch_size, num_filter,new_sequence_length)
        return x

    def forward(self, x):#return (x, seq_len)
        #print (x[0].shape)
        out = self.embedding(x[0])#(batch_size, sequence_length, embed_dim)
        out = out.unsqueeze(1) #(batch_size, 1,sequence_length, embed_dim)
        out = torch.cat([self.conv_and_pool(out, conv) for conv in self.convs], 1)
        out = self.dropout(out)
        out = self.fc(out)
        return out

相关文章:

  • LeetCode 热题100 3. 无重复字符的最长子串
  • LabVIEW DataSocket 通信库详解
  • 基于DeepSeek 的图生文最新算法 VLM-R1
  • Go开发框架Sponge+AI助手协同配合重塑企业级开发范式
  • 论文阅读:CAN GENERATIVE LARGE LANGUAGE MODELS PERFORM ASR ERROR CORRECTION?
  • 【C语言显示Linux系统参数】
  • c++面试常见问题:虚表指针存在于内存哪个分区
  • LeetCodehot 力扣热题100 组合总和
  • 【C】初阶数据结构8 -- 链式二叉树
  • 计算机毕业设计SpringBoot+Vue.js人力资源管理系统(源码+文档+PPT+讲解)
  • MCP与RAG:增强大型语言模型的两种路径
  • 【算法】【并查集】acwing算法基础837. 连通块中点的数量
  • 每日一题——接雨水
  • 制作安装win10系统U盘详细步骤
  • 深入解析HDFS:定义、架构、原理、应用场景及常用命令
  • 【C++并发编程实战】第1章 你好,C++的并发世界!
  • Golang语言特性
  • C语言:51单片机 常用电子元器件讲解(带英文名称)
  • Java-servlet(一)Web应用与服务端技术概念知识讲解
  • Linux top 常用参数记录
  • 上海国际电影节将于6月3日公布排片表,6月5日中午开票
  • 零跑汽车一季度营收破百亿元:净亏收窄至1.3亿元,毛利率14.9%创新高
  • 世卫大会中国代表团:中国深入参与全球卫生治理,为构建人类卫生健康共同体贡献中国力量
  • 肖钢:一季度证券业金融科技投资强度在金融各子行业中居首
  • 内蒙古赤峰市城建集团董事长孙广通拟任旗县区党委书记
  • 俄乌刚谈完美国便筹划与俄乌领导人通话,目的几何?