当前位置: 首页 > news >正文

PyTorch 实战:从 0 开始搭建 Transformer

  1. 导入必要的库

python

import math
import torch
import torch.nn as nn
from LabmL_helpers.module import Module
from labml_n.utils import clone_module_List
from typing import Optional, List
from torch.utils.data import DataLoader, TensorDataset
from torch import optim
import torch.nn.functional as F
  1. Transformer 模型概述
    Transformer 是一种序列到序列的模型,通过自注意力机制并行处理整个序列,能同时考虑序列中的所有元素,并学习上下文之间的关系。其架构包括编码器和解码器部分,每部分都由多个相同的层组成,这些层包含自注意力机制、前馈神经网络,以及归一化和 Dropout 步骤。
  2. 核心公式
    • 自注意力计算:Attention(Q,K,V)=softmax(dk​​QKT​)V,其中,Q、K、V分别是查询(Query)、键(Key)和值(Value)矩阵,dk​是键的维度。
    • 多头注意力:将输入分割为多个头,分别计算注意力,然后将结果拼接起来。
    • 位置编码:由于 Transformer 不使用循环结构,因此引入位置编码来保留序列中的位置信息。
  3. 自注意力机制
    • 核心原理:计算句子在编码过程中每个位置上的注意力权重,然后以权重和的方式来计算整个句子的隐含向量表示。公式中,首先将 query 与 key 的转置做点积,然后将结果除以dk​​ ,再进行 softmax 计算,最后将结果与 value 做矩阵乘法得到 output。除以dk​​是为了防止QKT过大导致 softmax 计算溢出,且可使QKT结果满足均值为 0,方差 1 的分布。QKT计算本质上是余弦相似度,可表示两个向量在方向上的相似度。
    • 实现

python

import numpy as np
from math import sqrt
import torch
from torch import nnclass Self_Attention(nn.Module):# input : batch_size * seq_len * input_dim# q : batch_size * input_dim * dim_k# k : batch_size * input_dim * dim_k# v : batch_size * input_dim * dim_vdef __init__(self, input_dim, dim_k, dim_v):super(Self_Attention, self).__init__()self.q = nn.Linear(input_dim, dim_k)self.k = nn.Linear(input_dim, dim_k)self.v = nn.Linear(input_dim, dim_v)self._norm_fact = 1 / sqrt(dim_k)def forward(self, x):Q = self.q(x)  # Q: batch_size * seq_len * dim_kK = self.k(x)  # K: batch_size * seq_len * dim_kV = self.v(x)  # V: batch_size * seq_len * dim_v# Q * K.T() # batch_size * seq_len * seq_lenatten = nn.Softmax(dim=-1)(torch.bmm(Q, K.permute(0, 2, 1))) * self._norm_fact# Q * K.T() * V # batch_size * seq_len * dim_voutput = torch.bmm(atten, V)return outputX = torch.randn(4, 3, 2)
print(X)
self_atten = Self_Attention(2, 4, 5)  # input_dim:2, k_dim:4, v_dim:5
res = self_atten(X)
print(res.shape)  # [4,3,5]

  1. 多头注意力机制
    不同于只使用一个注意力池化,将输入x拆分为h份,独立计算h组不同的线性投影来得到各自的 QKV,然后并行计算注意力,最后将h个注意力池化拼接起来并通过另一个可学习的线性投影进行变换以产生输出。每个头可能关注输入的不同部分,可表示更复杂的函数。

python

from math import sqrt
import torch
import torch.nn as nnclass Self_Attention_Muti_Head(nn.Module):# input : batch_size * seq_len * input_dim# q : batch_size * input_dim * dim_k# k : batch_size * input_dim * dim_k# v : batch_size * input_dim * dim_vdef __init__(self, input_dim, dim_k, dim_v, nums_head):super(Self_Attention_Muti_Head, self).__init__()assert dim_k % nums_head == 0assert dim_v % nums_head == 0self.q = nn.Linear(input_dim, dim_k)self.k = nn.Linear(input_dim, dim_k)self.v = nn.Linear(input_dim, dim_v)self.nums_head = nums_headself.dim_k = dim_kself.dim_v = dim_vself._norm_fact = 1 / sqrt(dim_k)def forward(self, x):Q = self.q(x).reshape(-1, x.shape[0], x.shape[1], self.dim_k //self.nums_head)K = self.k(x).reshape(-1, x.shape[0], x.shape[1], self.dim_k //self.nums_head)V = self.v(x).reshape(-1, x.shape[0], x.shape[1], self.dim_v //self.nums_head)print(x.shape)print(Q.size())atten = nn.Softmax(dim=-1)(torch.matmul(Q, K.permute(0, 1, 3, 2)))  # Q * K.T() # batch_size * seq_len * seq_lenoutput = torch.matmul(atten, V).reshape(x.shape[0], x.shape[1], -1)  # Q * K.T() * V # batch_size * seq_len * dim_vreturn outputx = torch.rand(1, 3, 4)
print(x)
atten = Self_Attention_Muti_Head(4, 4, 4, 2)
y = atten(x)
print(y.shape)

  1. 视觉注意力机制
    attention 机制本质是利用相关特征图学习权重分布,再用学出来的权重施加在原特征图上最后进行加权求和。计算机视觉上的注意力机制主要分为三种:空间域、通道域、混合域。
    • 空间域:将图片中的空间域信息做对应的空间变换,提取关键信息,对空间进行掩码的生成并打分,代表是 Spatial attention module。
    • 通道域:给每个通道上的信号增加一个权重,代表该通道与关键信息的相关度,权重越大相关度越高。对通道生成掩码 mask 进行打分,代表是 senet、channel attention module。
    • 混合域:空间域的注意力忽略了通道域中的信息,将每个通道的图片特征同等处理,这种做法会将空间域变换方法局限在原始特征提取阶段。
  2. 通道域注意力(SENet)
    通过全局池化提取通道权重,然后对特征图进行改变,得到加强后的特征图。

python

class SELayer(nn.Module):def __init__(self, channel, reduction=16):super(SELayer, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(channel // reduction, channel, bias=False),nn.Sigmoid())def forward(self, x):b, c, _, _ = x.size()y = self.avg_pool(x).view(b, c)  # 对应Squeeze操作y = self.fc(y).view(b, c, 1, 1)  # 对应Excitation操作return x * y.expand_as(x)

  1. 门控注意力机制(GCT,Gated Channel Transformation)
    GCT 是一种简单有效的通道间建模关系体系结构,能显著提高卷积网络在视觉任务的泛化能力。论文发现将门控机制放在 Conv 层前面训练效果最好。GCT 包含三个部分:
    • Global Context Embedding:设计了一种全局上下文嵌入模块,用于每个通道的全局上下文信息汇聚,公式为sc​=αc​∥xc​∥2​=αc​{[∑i=1H​∑j=1W​(xci,j​)2]+ϵ}21​。
    • Channel Normalization:对第一步计算的 L2 进行规范化来构建神经元竞争关系,使用跨通道的特征规范化,公式为s^c​=∥s∥2​C​sc​​=[(∑c=1C​sc2​)+ϵ]21​C​sc​​。
    • Gating Adaptation:加入门限机制,公式为x^c​=xc​[1+tanh(γc​s^c​+βc​)] 。

python

class GCT(nn.Module):def __init__(self, num_channels, epsilon=1e-5, mode='l2', after_relu=False):super(GCT, self).__init__()self.alpha = nn.Parameter(torch.ones(1, num_channels, 1, 1))self.gamma = nn.Parameter(torch.zeros(1, num_channels, 1, 1))self.beta = nn.Parameter(torch.zeros(1, num_channels, 1, 1))self.epsilon = epsilonself.mode = modeself.after_relu = after_reludef forward(self, x):if self.mode == 'l2':embedding = (x.pow(2).sum((2, 3), keepdim=True) +self.epsilon).pow(0.5) * self.alphanorm = self.gamma / \(embedding.pow(2).mean(dim=1, keepdim=True) +self.epsilon).pow(0.5)elif self.mode == 'l1':if not self.after_relu:_x = torch.abs(x)else:_x = xembedding = _x.sum((2, 3), keepdim=True) * self.alphanorm = self.gamma / \(torch.abs(embedding).mean(dim=1, keepdim=True) + self.epsilon)gate = 1. + torch.tanh(embedding * norm + self.beta)return x * gate

GCT 建议添加在 Conv 层前,一般可以先冻结原来的模型,来训练 GCT,然后解冻再进行微调。

相关文章:

  • 按句子切分文本、保留 token 对齐信息、**适配 tokenizer(如 BERT)**这种需求
  • 数据中台-常用工具组件:DataX、Flink、Dolphin Scheduler、TensorFlow和PyTorch等
  • 计算机视觉与深度学习 | 基于Transformer的低照度图像增强技术
  • 从知识图谱到精准决策:基于MCP的招投标货物比对溯源系统实践
  • 【银河麒麟高级服务器操作系统】服务器外挂存储ioerror分析及处理分享
  • flinksql bug : Max aggregate function does not support type: CHAR
  • Debian系统详解
  • UV使用官网
  • 【C语言】--指针超详解(二)
  • 基于Kubernetes的Apache Pulsar云原生架构解析与集群部署指南(上)
  • 408考研逐题详解:2009年第10题
  • 美化IDEA注释:Idea 中快捷键 Ctrl + / 自动注释的缩进(避免添加注释自动到行首)以及 Ctrl + Alt + l 全局格式化代码的注释缩进
  • 从0到1:用Lask/Django框架搭建个人博客系统(4/10)
  • IT/OT 融合架构下的工业控制系统安全攻防实战研究
  • 美化cmd窗格----添加背景图
  • 一文读懂Nginx应用之 HTTP负载均衡(七层负载均衡)
  • 软考知识点汇总
  • 【C++】手搓一个STL风格的string容器
  • 数字孪生市场格局生变:中国2025年规模214亿,工业制造领域占比超40%
  • SpringAI实现AI应用-使用redis持久化聊天记忆
  • 洲际酒店:今年第一季度全球酒店平均客房收入同比增长3.3%
  • 梅花奖在上海|第六代“杨子荣”是怎样炼成的?
  • 大四本科生已发14篇SCI论文?重庆大学:成立工作组核实
  • 2024年上市公司合计实现营业收入71.98万亿元
  • 两部门部署中小学幼儿园教师招聘工作:吸纳更多高校毕业生从教
  • 光大华夏:近代中国私立大学遥不可及的梦想