
Week 21: Deep Learning Addendum: ViT Overview and Hand-Coding Multi-Head Attention

Table of Contents

  • Week 21: Deep Learning Addendum: ViT Overview and Hand-Coding Multi-Head Attention
    • Abstract
    • 1. ViT Overview
      • 1.1 Data Flow
      • 1.2 Discussion
    • 2. Hands-On Dissection - Multi-Head Attention
      • 2.1 Code Overview
      • 2.2 Projection Matrix
      • 2.3 Pre-Computation Setup
      • 2.4 Matrix Multiplication and Mask
    • Summary

Week 21: Deep Learning Addendum: ViT Overview and Hand-Coding Multi-Head Attention

Abstract

This week, I reviewed the ViT original paper and conducted a hands-on experiment with the multi-head attention mechanism. This reinforced my understanding of how multi-head attention interacts with dimensionality changes. I spent considerable time resolving complex issues related to dimensionality in practical applications, successfully bridging theory and practice. The experience proved highly rewarding.

1. ViT Overview

ViT is short for Vision Transformer. Unlike a CNN, which extracts local features with convolution operations, the idea of ViT is to split an image into multiple fixed-size patches, convert them into tokens analogous to those of natural language, use a pure Transformer architecture to model the image globally, and rely on the attention mechanism to capture long-range dependencies.

ViT Structure

1.1 Data Flow

  1. Split the input image into fixed-size patches
  2. Flatten each patch into a vector
  3. Add position embeddings
  4. Feed the sequence into the Transformer encoder
  5. Pass the result through the MLP head
  6. Obtain the classification result (a minimal code sketch of this flow is given after the list)
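
To make the flow above concrete, here is a minimal PyTorch sketch of the ViT front end. The concrete sizes (224x224 input, 16x16 patches, d_model = 768, 12 layers and heads), the class name ViTSketch, and the use of a strided convolution for patchify are illustrative assumptions on my part, not details stated above.

import torch
import torch.nn as nn

class ViTSketch(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, d_model=768, num_classes=1000):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        # Steps 1-2: cut into patches and flatten, implemented here as one strided convolution
        self.patch_embed = nn.Conv2d(in_chans, d_model, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        # Step 3: learnable position embeddings (+1 slot for the class token)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, d_model))
        # Step 4: Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead=12, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=12)
        # Step 5: MLP head used as the classifier
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x):                      # x: (batch, 3, 224, 224)
        x = self.patch_embed(x)                # (batch, d_model, 14, 14)
        x = x.flatten(2).transpose(1, 2)       # (batch, 196, d_model)
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls, x], dim=1) + self.pos_embed
        x = self.encoder(x)                    # (batch, 197, d_model)
        return self.head(x[:, 0])              # Step 6: classify from the class token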

1.2 Discussion

As the pioneer of vision Transformers, ViT has a rather simple, brute-force overall structure.

The input image is cut into patches; the flattened patches are passed through a linear projection layer, position embeddings are added, and the sequence goes straight into the Transformer encoder. Finally, an MLP head serves as the classifier and directly produces the classification output.

Its main contribution is an approach for processing images with a Transformer. However, because interaction between patches relies entirely on global attention and the network structure is rather simple and brute-force, its performance is somewhat limited, in particular for texture extraction.

2. Hands-On Dissection - Multi-Head Attention

Starting this week, to deepen my understanding of the key concepts, I began hand-coding several classic network structures and carefully analyzing their dimension changes and implementation details.

2.1 Code Overview

import math
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads) -> None:
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        # Q, K, V projection matrices
        self.W_q = nn.Linear(d_model, d_model)  # d_model -> d_model
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        # Output projection matrix
        self.W_o = nn.Linear(d_model, d_model)
        # Softmax
        self.softmax = nn.Softmax(dim=-1)  # dim=-1 means last dimension

    def forward(self, q, k, v):
        batch, time, dimension = q.shape  # q.shape = (batch, time, d_model)
        n_d = self.d_model // self.num_heads  # d_model is the model dimension, n_d is the dimension of each head
        # arranging dimensions evenly across heads
        q, k, v = self.W_q(q), self.W_k(k), self.W_v(v)
        # projecting q, k, v to the target space
        q = q.view(batch, time, self.num_heads, n_d).permute(0, 2, 1, 3)
        k = k.view(batch, time, self.num_heads, n_d).permute(0, 2, 1, 3)
        v = v.view(batch, time, self.num_heads, n_d).permute(0, 2, 1, 3)
        # .view() reshapes the tensor without changing the data
        # .permute() rearranges the dimensions of the tensor
        # view: (batch, time, d_model) -> (batch, time, num_heads, n_d)
        # permute: (batch, time, num_heads, n_d) -> (batch, num_heads, time, n_d)
        # each head's query and key vectors have dimension n_d
        # reshape for parallel computation
        score = q @ k.transpose(2, 3) / math.sqrt(n_d)
        # @ is matrix multiplication (dot product); * is element-wise multiplication
        # / math.sqrt(n_d) is the scaling factor that normalizes the dot product
        mask = torch.tril(torch.ones(time, time, dtype=bool))
        # torch.tril() keeps the lower-triangular elements; torch.triu() would keep the upper-triangular ones
        # torch.ones() creates a tensor of ones
        score = score.masked_fill(mask == 0, float('-inf'))
        score = self.softmax(score) @ v
        score = score.permute(0, 2, 1, 3).contiguous().reshape(batch, time, dimension)
        # .contiguous() makes the tensor contiguous in memory
        return self.W_o(score)
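
As a quick sanity check, the module above can be exercised with random tensors. The sizes below (batch 2, sequence length 5, d_model 16, 4 heads) are arbitrary values chosen for illustration and assume the imports and class definition above.

mha = MultiHeadAttention(d_model=16, num_heads=4)
x = torch.randn(2, 5, 16)      # (batch, time, d_model)
out = mha(x, x, x)             # self-attention: q = k = v
print(out.shape)               # torch.Size([2, 5, 16]), same shape as the input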

2.2 Projection Matrix

self.W_q = nn.Linear(d_model, d_model) # d_model -> d_model
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)

Three linear layers, W_q, W_k, and W_v, are defined in __init__; they represent the three weight matrices.

A detail worth noting here is that d_model is the model dimension, and in fact d_model = num_heads * n_d, i.e. n_d = d_model // num_heads (assuming d_model is divisible by num_heads). So the three linear layers W_q, W_k, and W_v are really the concatenation of the independent linear projection matrices owned by all the attention heads.
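
This combined-projection view can be checked directly: slicing the weight and bias of the single nn.Linear recovers each head's own n_d-dimensional projection. The sizes below (d_model = 8, 2 heads, 3 tokens) are arbitrary assumptions for illustration.

import torch
import torch.nn as nn

d_model, num_heads = 8, 2
n_d = d_model // num_heads
W_q = nn.Linear(d_model, d_model)
x = torch.randn(3, d_model)                  # 3 token vectors

full = W_q(x)                                # combined projection: (3, d_model)
h = 1                                        # pick one head
# nn.Linear computes x @ weight.T + bias, so rows h*n_d:(h+1)*n_d of the weight
# form head h's independent projection matrix
per_head = x @ W_q.weight[h*n_d:(h+1)*n_d].T + W_q.bias[h*n_d:(h+1)*n_d]
print(torch.allclose(full[:, h*n_d:(h+1)*n_d], per_head))   # expected: True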

2.3 Pre-Computation Setup

q, k, v = self.W_q(q), self.W_k(k), self.W_v(v)
# projecting q, k, v to the target space
q = q.view(batch, time, self.num_heads, n_d).permute(0, 2, 1, 3)
k = k.view(batch, time, self.num_heads, n_d).permute(0, 2, 1, 3)
v = v.view(batch, time, self.num_heads, n_d).permute(0, 2, 1, 3)
# .view() reshapes the tensor without changing the data
# .permute() rearranges the dimensions of the tensor
# view: (batch, time, d_model) -> (batch, time, num_heads, n_d)
# permute: (batch, time, num_heads, n_d) -> (batch, num_heads, time, n_d)
# each head's query and key vectors have dimension n_d
# reshape for parallel computation

First, the projections are applied, mapping the input Q, K, and V into their respective spaces.

Then a reshape is performed: (batch, time, d_model) is unfolded into (batch, time, num_heads, n_d), and the indices are then permuted to (batch, num_heads, time, n_d). This way, after the matrix multiplication the score dimensions become (batch, num_heads, time, time).
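
These shape changes can be verified with dummy tensors; the sizes (batch 2, time 5, d_model 16, 4 heads) are arbitrary assumptions.

import torch

batch, time, d_model, num_heads = 2, 5, 16, 4
n_d = d_model // num_heads

q = torch.randn(batch, time, d_model)
q = q.view(batch, time, num_heads, n_d)      # (2, 5, 4, 4) = (batch, time, num_heads, n_d)
q = q.permute(0, 2, 1, 3)                    # (2, 4, 5, 4) = (batch, num_heads, time, n_d)
k = torch.randn(batch, num_heads, time, n_d)

score = q @ k.transpose(2, 3)                # (2, 4, 5, 5) = (batch, num_heads, time, time)
print(q.shape, score.shape)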

2.4 Matrix Multiplication and Mask

score = q @ k.transpose(2, 3) / math.sqrt(n_d)
# @ is matrix multiplication (dot product); * is element-wise multiplication
# / math.sqrt(n_d) is the scaling factor that normalizes the dot product
mask = torch.tril(torch.ones(time, time, dtype=bool))
# torch.triu() keeps the upper-triangular elements of a matrix
# torch.tril() keeps the lower-triangular elements of a matrix
# torch.ones() creates a tensor of ones
score = score.masked_fill(mask == 0, float('-inf'))
score = self.softmax(score) @ v
score = score.permute(0, 2, 1, 3).contiguous().reshape(batch, time, dimension)
# .contiguous() makes the tensor contiguous in memory
return self.W_o(score)

The meaning of the important functions has already been given in the code comments above, so it is not repeated here; only the key mathematical ideas are described.

k.transpose(2, 3) transposes dimensions 2 and 3 of k, i.e. it changes (batch, num_heads, time, n_d) into (batch, num_heads, n_d, time). After score = q @ k.transpose(2, 3) / math.sqrt(n_d) is executed, the dot product of q and k is complete and has been scaled.

torch.tril(torch.ones(time, time, dtype=bool)) generates a lower-triangular matrix of shape (time, time) filled with boolean ones. Then .masked_fill(mask == 0, float('-inf')) sets every score entry whose mask value is 0 to negative infinity, -inf.

In the first row of the mask only position (0, 0) is 1, so every other entry of that score row becomes -inf; after the softmax these entries become 0, which completes the masking: at time step 0 the attention can only see position 0.
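
A tiny demonstration of the mask's effect on a single head, with an arbitrary sequence length of 4:

import torch

time = 4
score = torch.randn(time, time)                       # unmasked attention scores
mask = torch.tril(torch.ones(time, time, dtype=bool))
score = score.masked_fill(mask == 0, float('-inf'))
weights = torch.softmax(score, dim=-1)
print(weights)
# Row 0 is [1, 0, 0, 0]: at time step 0 the attention can only look at position 0;
# every later row puts nonzero weight only on the current and earlier positions.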

Finally, .permute(0, 2, 1, 3) restores the original index order, and a reshape back to the input size produces the output, which then goes through W_o.

Summary

This week I read the ViT paper, began trying to understand the model's structure and how it might be implemented, and analyzed ViT's historical contribution as well as its limitations. I also spent some time hand-coding Multi-Head Attention, and I will continue to hand-code the more important mechanisms to strengthen my understanding. During this week's study I found that the dimension changes of some computations that are only expressed briefly in the math are much harder to grasp in practice, so they really need to be understood in depth by combining theory with practice. Overall it was a rewarding week.
