当前位置: 首页 > news >正文

【科研日常】使用tensorflow实现自注意力机制和交叉注意力机制

自注意力机制代码

import tensorflow as tf
import numpy as np# 定义自注意力层
class Attention_Self(tf.keras.layers.Layer):def __init__(self, units, **kwargs):super(Attention_Self, self).__init__(**kwargs)self.units = units  # 特征维度self.wq = tf.keras.layers.Dense(units, use_bias=False) # 不使用偏置self.wk = tf.keras.layers.Dense(units, use_bias=False)self.wv = tf.keras.layers.Dense(units, use_bias=False)def call(self, features):Q = self.wq(features)  # 查询K = self.wk(features)  # 键V = self.wv(features)  # 值matmul_qk = tf.matmul(Q, K, transpose_b=True)dk = tf.cast(tf.shape(K)[-1], tf.float32)scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)output = tf.matmul(attention_weights, V)return outputdef get_config(self):config = super(Attention_Self, self).get_config()config.update({'units': self.units})return config

自注意力机制测试数据

# 创建模拟输入数据
batch_size = 10
seq_length = 20
feature_dim = 64# 随机生成输入数据
inputs = tf.random.normal((batch_size, seq_length, feature_dim))
inputs.shape

TensorShape([10, 20, 64])

# 实例化自注意力层
attention_layer = Attention_Self(units=feature_dim)
# 应用自注意力层到输入数据
outputs = attention_layer(inputs)
# 打印输出结果
print(outputs.shape)

(10, 20, 64)

写法2

outputs2 = Attention_Self(units=feature_dim)(inputs)
print(outputs2.shape)

(10, 20, 64)

相应解析

这里 get_config 方法首先调用父类 tf.keras.layers.Layerget_config 方法来继承基础配置信息,然后通过更新字典来加入该层特有的配置信息。在这个例子中,我们添加了 units,这是我们定义的一个关键参数,它指定了内部 Dense 层的单元数量。
当使用像 tf.keras.models.save_modeltf.keras.models.load_model 这样的函数来保存和加载模型时,get_config 方法允许 TensorFlow 正确重构你的自定义层。这对于实现复杂的序列化逻辑和确保模型可以跨不同的环境移植是非常重要的。

交叉注意力机制代码

import tensorflow as tfclass CrossAttention(tf.keras.layers.Layer):def __init__(self, units, **kwargs):super(CrossAttention, self).__init__(**kwargs)self.units = units  # 特征维度# A 模态的层self.wq_a = tf.keras.layers.Dense(units, use_bias=False)self.wv_a = tf.keras.layers.Dense(units, use_bias=False)# B 模态的层self.wq_b = tf.keras.layers.Dense(units, use_bias=False)self.wv_b = tf.keras.layers.Dense(units, use_bias=False)# 键层是交叉的,A查询B,B查询Aself.wk_a = tf.keras.layers.Dense(units, use_bias=False)self.wk_b = tf.keras.layers.Dense(units, use_bias=False)def call(self, features_a, features_b):# A 查询 BQ_a = self.wq_a(features_a)K_b = self.wk_b(features_b)V_a = self.wv_a(features_a)# B 查询 AQ_b = self.wq_b(features_b)K_a = self.wk_a(features_a)V_b = self.wv_b(features_b)print(self.wq_a.get_weights()[0].shape)print(self.wk_b.get_weights()[0].shape)print(self.wv_a.get_weights()[0].shape)print(Q_a.shape)print(K_b.shape)print(V_a.shape)# 计算注意力分数attention_scores_ab = tf.matmul(Q_a, K_b, transpose_b=True) # 当设置 transpose_b=True 时,这意味着在执行矩阵乘法之前会对第二个输入张量(B)进行转置。attention_scores_ba = tf.matmul(Q_b, K_a, transpose_b=True)print(attention_scores_ab.shape)print(attention_scores_ba.shape)# 缩放dk = tf.cast(tf.shape(K_b)[-1], tf.float32) # cast是类型转换函数scaled_attention_ab = attention_scores_ab / tf.math.sqrt(dk) # sqrt是开方函数scaled_attention_ba = attention_scores_ba / tf.math.sqrt(dk)# softmax 应用到最后一个轴attention_weights_ab = tf.nn.softmax(scaled_attention_ab, axis=-1)attention_weights_ba = tf.nn.softmax(scaled_attention_ba, axis=-1)# 应用注意力权重output_a = tf.matmul(attention_weights_ab, V_a)output_b = tf.matmul(attention_weights_ba, V_b)return output_a, output_bdef get_config(self):config = super(CrossAttention, self).get_config()config.update({'units': self.units})return config

交叉注意力机制测试

# 创建模拟数据
batch_size = 5
seq_length = 10
feature_dim = 64featuresA = tf.random.normal((batch_size, seq_length, feature_dim))
featuresB = tf.random.normal((batch_size, seq_length, feature_dim))# featuresA = tf.random.normal((batch_size, feature_dim))
# featuresB = tf.random.normal((batch_size, feature_dim))# 实例化并应用交叉注意力层
cross_attention = CrossAttention(units=feature_dim)
output_a, output_b = cross_attention(featuresA, featuresB)# 输出结果
print("Output A:", output_a)
print("Output B:", output_b)

(64, 64)
(64, 64)
(64, 64)
(5, 10, 64)
(5, 10, 64)
(5, 10, 64)
(5, 10, 10)
(5, 10, 10)
Output A: tf.Tensor( [[[-5.7970786e-01 -6.0719752e-01 1.9909297e-01 … -3.2587045e-01 4.7336593e-02 3.5937220e-01] [ 2.9775929e-01 9.1934420e-02 -5.2563709e-01 … 7.4571893e-03 -1.0422341e+00 1.5034429e+00] [-5.7887845e-02 -6.6692615e-01 -8.3360665e-02 … 2.1067673e-01 2.0477527e-01 1.7619665e-01] … [ 3.0940935e-01 -2.0744744e-01 -4.0154967e-01 … -1.9383176e-01 -2.4514772e-02 -4.9369448e-01] [-6.2394276e-02 -4.5148364e-01 -7.8599244e-02 … -4.9863800e-02 -1.3933992e-01 3.8461825e-01] [-4.4355023e-01 -6.3579077e-01 -7.1974456e-02 … -1.1960831e-01 -5.1002228e-01 1.1245629e+00]] [[ 6.5766849e-02 -3.0112099e-02 2.3346621e-01 … -2.5223029e-01 7.3245126e-01 -1.3654667e-01]… [ 0.35888967 0.20865138 0.0092712 … -0.01233579 0.3880379 -0.45383584] [-0.18597096 0.44148418 -0.2378662 … 0.0344183 -0.18934116 -0.8733301 ]]], shape=(5, 10, 64), dtype=float32)

1. 查询矩阵 Q_a 的生成

  • Q_a 通过 featuresAself.wq_a 层的操作得到。
  • featuresA 的形状为 (batch_size, seq_length, feature_dim)
  • self.wq_a 是一个 无偏置Dense 层,其权重矩阵形状为 (feature_dim, units),其中 units == feature_dim
  • 执行 self.wq_a(featuresA) 时,等价于批量矩阵乘法:对 featuresA 中的 每一个序列(对应于每一个 batch 和每个时间步)与 self.wq_a 的权重矩阵相乘。

2. 矩阵乘法的细节

  • featuresA 中的每个 (seq_length, feature_dim) 矩阵与 self.wq_a(feature_dim, feature_dim) 权重矩阵相乘。
  • 结果是一个新的 (batch_size, seq_length, feature_dim) 形状的张量 Q_a,其中每个元素都是原始特征与权重的线性组合。
  • 该操作 不改变序列长度 (seq_length),但对每个时间步的特征进行变换。

3. 操作的目的

  • 通过线性变换,将原始特征映射到一个可能有利于后续计算的新空间(此处输入与输出维度相同)。
  • 在注意力机制中,这种变换用于准备计算点积或其他形式的相似度,以决定不同特征之间的交互重要性。
http://www.dtcms.com/a/331577.html

相关文章:

  • Java中Record的应用
  • Flink Stream API 源码走读 - socketTextStream
  • Spark Shuffle机制原理
  • STM32HAL 快速入门(七):GPIO 输入之光敏传感器控制蜂鸣器
  • 深入理解管道(下):括号命令 `()`、`-ExpandProperty` 与 AD/CSV 实战
  • Java 大视界 -- Java 大数据在智能家居能耗监测与节能优化中的应用探索(396)
  • 【漏洞复现】WinRAR 目录穿越漏洞(CVE-2025-8088)
  • JavaScript 解构赋值语法详解
  • iOS Sqlite3
  • Playwright初学指南 (3):深入解析交互操作
  • 【完整源码+数据集+部署教程】肾脏病变实例分割系统源码和数据集:改进yolo11-CARAFE
  • 基于机器学习的文本情感极性分析系统设计与实现
  • 华为宣布云晰柔光屏技术迎来重大升级
  • 生产环境sudo配置详细指南
  • 机器学习学习总结
  • 如何选择适合工业场景的物联网网关?
  • 相较于传统AR作战环境虚拟仿真系统,其优势体现在哪些方面?
  • Python小程序1.0版本
  • C++类与对象核心知识点全解析(中)【六大默认成员函数详解】
  • Perforce P4 Git 连接器
  • 随身 Linux 开发环境:使用 cpolar 内网穿透服务实现 VSCode 远程访问
  • Activity + fragment的页面结构,fragment始终无法显示问题
  • AI 赋能的软件工程全生命周期应用
  • 第16届蓝桥杯C++中高级选拔赛(STEMA)2024年10月20日真题
  • 【C#】PNG 和 JPG、JPEG的应用以及三种格式的区别?
  • Oracle commit之后做了什么
  • 【20-模型诊断调优】
  • BSCI认证对企业的影响,BSCI认证的重要性,BSCI审核的核心内容
  • 信息vs知识:人类学习与AI规则提取
  • 设计模式笔记_行为型_状态模式