当前位置: 首页 > news >正文

自注意力与交叉注意力的PyTorch 简单实现

自注意力(Self-Attention)与交叉注意力(Cross-Attention)PyTorch 简单实现

在深度学习中,注意力机制是现代 Transformer 架构的核心思想之一。本文将介绍两种常见的注意力机制:自注意力(Self-Attention)交叉注意力(Cross-Attention),并通过 PyTorch 给出简单实现与使用示例。


📦 必要导入

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from einops import rearrange, repeat
from inspect import isfunction

# 一些基础工具函数
def exists(val):
    return val is not None

def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

🔍 什么是注意力机制?

注意力机制允许模型在处理输入序列时自动聚焦于最相关的部分,从而增强建模能力。以 Transformer 为例,它通过注意力机制建立了序列中不同位置之间的信息关联。


🤖 自注意力(Self-Attention)

自注意力是指查询(Query)、键(Key)、值(Value)都来自同一个输入序列。这种机制允许序列中的每个元素关注其它所有位置的信息,是 BERT、GPT 等模型的基本构件。

✅ PyTorch 实现:

class SelfAttention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask=None):
        h = self.heads
        qkv = self.to_qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))

        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale

        if exists(mask):
            ...

        attn = sim.softmax(dim=-1)
        out = einsum("b i j, b j d -> b i d", attn, v)
        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
        return self.to_out(out)

🧪 使用示例:

attn = SelfAttention(dim=64)
x = torch.randn(1, 10, 64)
out = attn(x)
print(out.shape)  # torch.Size([1, 10, 64])

🔁 交叉注意力(Cross-Attention)

交叉注意力允许模型在处理一个序列时,从另一个序列中获取信息。常用于:

  • 编码器-解码器结构(如 Transformer 翻译模型)
  • 图文跨模态对齐
  • 条件生成任务

与自注意力的不同在于:

  • Query 来自当前输入(例如解码器)
  • Key 与 Value 来自另一个序列(例如编码器)

✅ PyTorch 实现:

class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        h = self.heads
        context = default(context, x)

        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
        sim = einsum("b i d, b j d -> b i j", q, k) * self.scale

        if exists(mask):
            ...

        attn = sim.softmax(dim=-1)
        out = einsum("b i j, b j d -> b i d", attn, v)
        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)

        return self.to_out(out)

🧪 使用示例:

ca = CrossAttention(query_dim=64, context_dim=77)
x = torch.randn(1, 10, 64)          # 解码器输入
context = torch.randn(1, 20, 77)    # 编码器输出
out = ca(x, context)
print(out.shape)  # torch.Size([1, 10, 64])

🧠 总结对比

模块Query 来自Key/Value 来自典型应用
Self-Attention当前输入当前输入BERT、GPT、自注意图像建模
Cross-Attention当前输入外部上下文编解码结构、跨模态、条件生成

🔗 参考链接

Stable Diffusion CrossAttention 代码实现

http://www.dtcms.com/a/106899.html

相关文章:

  • DAO 类的职责与设计原则
  • 绘制动态甘特图(以流水车间调度为例)
  • JWT(JSON Web Token)
  • Spring AI Alibaba 快速开发生成式 Java AI 应用
  • 每日总结4.2
  • 深入理解Python asyncio:从入门到实战,掌握异步编程精髓
  • 为什么你涨不了粉?赚不到技术圈的钱?
  • 教务系统ER图
  • 大模预测法洛四联症的全方位研究报告
  • 特征融合后通道维度增加,卷积层和线性层两种降维方式
  • Ubuntu交叉编译器工具链安装
  • SpringBoot集成OAuth2.0
  • [MySQL初阶]MySQL数据库基础
  • jdk21新特性详解使用总结
  • TypeScript extends 全面解析
  • work02_1 计算这两个日期之间相隔的天数
  • 手机改了IP地址,定位位置会改变吗?
  • Java面试黄金宝典29
  • 蓝桥备赛指南(13):填空签到题(1-1)
  • 车辆控制解决方案
  • 如何通过安当TDE透明加密实现MySQL数据库加密与解密:应用免改造,字段与整库加密全解析
  • MySQL主从复制(四)
  • WEB安全--文件上传漏洞--其他绕过方式
  • OpenLayers:封装Overlay的方法
  • WASM I/O 2025 | MoonBit获Kotlin核心开发,Golem Cloud CEO高度评价
  • 人工智能赋能管理系统,如何实现智能化决策?
  • 操作系统(中断 异常 陷阱) ─── linux第28课
  • 脑影像分析软件推荐 | JuSpace
  • 【kubernetes】pod拉取镜像的策略
  • 关于SQL子查询的使用策略