当前位置: 首页 > news >正文

实现pytorch注意力机制-one demo

主要组成部分:

1. 定义注意力层

定义一个Attention_Layer类,接受两个参数:hidden_dim(隐藏层维度)和is_bi_rnn(是否是双向RNN)。

2. 定义前向传播:

定义了注意力层的前向传播过程,包括计算注意力权重和输出。

3. 数据准备

生成一个随机的数据集,包含3个句子,每个句子10个词,每个词128个特征。

4. 实例化注意力层:

实例化一个注意力层,接受两个参数:hidden_dim(隐藏层维度)和is_bi_rnn(是否是双向RNN)。

5. 前向传播

将数据传递给注意力层的前向传播方法。

6. 分析结果

获取第一个句子的注意力权重。

7. 可视化注意力权重

使用matplotlib库可视化了注意力权重。

**主要函数和类:**
Attention_Layer类:定义了注意力层的结构和前向传播过程。
forward方法:定义了注意力层的前向传播过程。
torch.from_numpy函数:将numpy数组转换为PyTorch张量。
matplotlib库:用于可视化注意力权重。
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# 定义注意力层
class Attention_Layer(nn.Module):
    def __init__(self, hidden_dim, is_bi_rnn):
        super(Attention_Layer,self).__init__()
        self.hidden_dim = hidden_dim
        self.is_bi_rnn = is_bi_rnn
        if is_bi_rnn:
            self.Q_linear = nn.Linear(hidden_dim * 2, hidden_dim * 2, bias = False)
            self.K_linear = nn.Linear(hidden_dim * 2, hidden_dim * 2, bias = False)
            self.V_linear = nn.Linear(hidden_dim * 2, hidden_dim * 2, bias = False)
        else:
            self.Q_linear = nn.Linear(hidden_dim, hidden_dim, bias = False)
            self.K_linear = nn.Linear(hidden_dim, hidden_dim, bias = False)
            self.V_linear = nn.Linear(hidden_dim, hidden_dim, bias = False)
        
    def forward(self, inputs, lens):
        # 获取输入的大小
        size = inputs.size()
        Q = self.Q_linear(inputs) 
        K = self.K_linear(inputs).permute(0, 2, 1)
        V = self.V_linear(inputs)
        max_len = max(lens)
        sentence_lengths = torch.Tensor(lens)
        mask = torch.arange(sentence_lengths.max().item())[None, :] < sentence_lengths[:, None]
        mask = mask.unsqueeze(dim = 1)
        mask = mask.expand(size[0], max_len, max_len)
        padding_num = torch.ones_like(mask)
        padding_num = -2**31 * padding_num.float()
        alpha = torch.matmul(Q, K)
        alpha = torch.where(mask, alpha, padding_num)
        alpha = F.softmax(alpha, dim = 2)
        out = torch.matmul(alpha, V)
        return out

# 准备数据
data = np.random.rand(3, 10, 128)  # 3个句子,每个句子10个词,每个词128个特征
lens = [7, 10, 4]  # 每个句子的长度

# 实例化注意力层
hidden_dim = 64
is_bi_rnn = True
att_L = Attention_Layer(hidden_dim, is_bi_rnn)

# 前向传播
att_out = att_L(torch.from_numpy(data).float(), lens)

# 分析结果
attention_weights = att_out[0, :, :].detach().numpy()  # 获取第一个句子的注意力权重

# 可视化注意力权重
plt.imshow(attention_weights, cmap='hot', interpolation='nearest')
plt.colorbar()
plt.show()

在这里插入图片描述

相关文章:

  • 业务干挂数据库,Oracle内存分配不足
  • css:position
  • K8s之存储卷
  • Prompt通用技巧
  • redis sentinel模式 与 redis 分片集群 配置
  • (五)Spring Boot学习——spring security +jwt使用(前后端分离模式)
  • iOS实现生物识别
  • git: 如何查询某个文件或者某个目录的更新历史
  • 服务器之连接简介(Detailed Explanation of Server Connection)
  • 网络编程01 - 速通计网知识点
  • python学习第十四天之机器学习名词介绍
  • RNN复兴!性能反超Transformer,训练速度提升1300倍!
  • 数据结构 栈和队列
  • 本地部署DeepSeek + AnythingLLM 搭建高效安全的个人知识库
  • 突破数据壁垒,动态住宅代理IP在数据采集中的高效应用
  • 系统思考—团队学习
  • SpringBoot的单机模式是否需要消息队列?分布式应用中消息队列如何和服务的发现与注册、配置中心、SpringMVC相配合
  • 有哪些免费的SEO软件优化工具
  • AGI时代的认知重塑:人类文明的范式转移与思维革命
  • Python多进程Logging
  • 网站建设套餐内容/双11销售数据
  • .net企业网站/怎么制作自己公司网站
  • 燕郊医院网站建设/原创代写文章平台
  • 本地用织梦做网站/网络营销软件下载
  • 外卖网站建设的策划方案/下载地图导航手机版免流量费用
  • 外发加工什么最好/关键词优化报价查询