当前位置: 首页 > news >正文

层在init中只为创建线性层,forward的对线性层中间加非线性运算。且分层定义是为了把原本一长个代码的初始化和运算放到一个组合中。

init注意有几个层,这里有四个层,接下来在forward函数中会把这四个层都用上

model = SimpleViT().to(device),创建模型的时候,会弄出所有层

当outputs = model(images),模型被使用的时候才调用各个定义层时的forward函数,调用顺序如下,注意,每个forward都完全使用了这些层,且以mha多头注意力层为例,其中包含把输入变成qkv大矩阵的性层和其他线性层,线性层之间还有其他的非线性操作,会在forward使用,所以forward不是简单的把线形层组合,其中层与层之间的非线性运算就在这里

为什么要定义这么多类和forward函数呢?就是为了把一长串代码中的初始化放到整个类的初始化中,然后操作放到这个类定义的方法中

举例

对于这个类

class MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads):super().__init__()assert d_model % num_heads == 0self.head_dim = d_model // num_headsself.num_heads = num_headsself.qkv = nn.Linear(d_model, d_model * 3)self.out = nn.Linear(d_model, d_model)def forward(self, x, *_):B, seq_len, d_model = x.shapeqkv = self.qkv(x)qkv = qkv.view(B, seq_len, 3, self.num_heads, self.head_dim)q, k, v = qkv.unbind(dim=2)q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)attn = torch.softmax(scores, dim=-1)context = torch.matmul(attn, v)context = context.transpose(1, 2).reshape(B, seq_len, d_model)return self.out(context)

可以直接使用

mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
y_class = mha(x)

这等效为代码,

# 线性层
qkv_linear = nn.Linear(d_model, d_model*3)
out_linear = nn.Linear(d_model, d_model)# ----------------------------
# Multi-Head Attention 流程
# ----------------------------# 1. 线性映射得到 QKV
qkv = qkv_linear(x)  # (B, seq_len, 3*d_model)# 2. reshape 成 (B, seq_len, 3, num_heads, head_dim)
qkv = qkv.view(B, seq_len, 3, num_heads, head_dim)# 3. 拆分 q, k, v
q, k, v = qkv.unbind(dim=2)  # 每个 (B, seq_len, num_heads, head_dim)# 4. 转置到 (B, num_heads, seq_len, head_dim)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)# 5. 注意力分数
scores = torch.matmul(q, k.transpose(-2, -1)) / (head_dim ** 0.5)  # (B, num_heads, seq_len, seq_len)# 6. softmax 得到注意力权重
attn = torch.softmax(scores, dim=-1)# 7. 加权求和得到上下文
context = torch.matmul(attn, v)  # (B, num_heads, seq_len, head_dim)# 8. 转置回 (B, seq_len, num_heads, head_dim)
context = context.transpose(1, 2)# 9. 拼回 d_model
context = context.reshape(B, seq_len, d_model)# 10. 输出线性映射
out = out_linear(context)  # (B, seq_len, d_model)

http://www.dtcms.com/a/341139.html

相关文章:

  • B站 韩顺平 笔记 (Day 24)
  • C++ std::optional 深度解析与实践指南
  • 当 AI 开始 “理解” 情绪:情感计算如何重塑人机交互的边界
  • linux报permission denied问题
  • Advanced Math Math Analysis |01 Limits, Continuous
  • uniapp打包成h5,本地服务器运行,路径报错问题
  • PyTorch API 4
  • 使数组k递增的最少操作次数
  • 路由器的NAT类型
  • 确保测试环境一致性与稳定性 5大策略
  • AI 效应: GPT-6,“用户真正想要的是记忆”
  • 获取本地IP地址、MAC地址写法
  • SQL 中大于小于号的表示方法总结
  • Bitcoin有升值潜力吗
  • 《代码沙盒深度实战:iframe安全隔离与实时双向通信的架构设计与落地策略》
  • 在SQL中使用大模型时间预测模型TimesFM
  • Mybatis执行SQL流程(五)之MapperProxy与MapperMethod
  • zoho crm api 无法修改富文本字段的原因:api 版本太低
  • 23种设计模式——构建器模式(Builder Pattern)详解
  • Spring Boot Controller 使用 @RequestBody + @ModelAttribute 接收请求
  • 车联网(V2X)中万物的重新定义---联网汽车新时代
  • Dubbo 的 Java 项目间调用的完整示例
  • 分析NeRF模型中颜色计算公式中的参数
  • Paraformer实时语音识别中的碎碎念
  • RuntimeError: Dataset scripts are no longer supported, but found wikipedia.py
  • 车辆订单状态管理的优化方案:状态机设计模式
  • 从ioutil到os:Golang在线客服聊天系统文件读取的迁移实践
  • 从零开发Java坦克大战Ⅱ(上) -- 从单机到联机(架构演进与设计模式剖析)
  • 音频大模型学习笔记
  • CS+ for CC编译超慢的问题该如何解决