【科研日常】使用tensorflow实现自注意力机制和交叉注意力机制
自注意力机制代码
import tensorflow as tf
import numpy as np# 定义自注意力层
class Attention_Self(tf.keras.layers.Layer):def __init__(self, units, **kwargs):super(Attention_Self, self).__init__(**kwargs)self.units = units # 特征维度self.wq = tf.keras.layers.Dense(units, use_bias=False) # 不使用偏置self.wk = tf.keras.layers.Dense(units, use_bias=False)self.wv = tf.keras.layers.Dense(units, use_bias=False)def call(self, features):Q = self.wq(features) # 查询K = self.wk(features) # 键V = self.wv(features) # 值matmul_qk = tf.matmul(Q, K, transpose_b=True)dk = tf.cast(tf.shape(K)[-1], tf.float32)scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)output = tf.matmul(attention_weights, V)return outputdef get_config(self):config = super(Attention_Self, self).get_config()config.update({'units': self.units})return config
自注意力机制测试数据
# 创建模拟输入数据
batch_size = 10
seq_length = 20
feature_dim = 64# 随机生成输入数据
inputs = tf.random.normal((batch_size, seq_length, feature_dim))
inputs.shape
TensorShape([10, 20, 64])
# 实例化自注意力层
attention_layer = Attention_Self(units=feature_dim)
# 应用自注意力层到输入数据
outputs = attention_layer(inputs)
# 打印输出结果
print(outputs.shape)
(10, 20, 64)
写法2
outputs2 = Attention_Self(units=feature_dim)(inputs)
print(outputs2.shape)
(10, 20, 64)
相应解析
这里
get_config
方法首先调用父类tf.keras.layers.Layer
的get_config
方法来继承基础配置信息,然后通过更新字典来加入该层特有的配置信息。在这个例子中,我们添加了units
,这是我们定义的一个关键参数,它指定了内部Dense
层的单元数量。
当使用像tf.keras.models.save_model
和tf.keras.models.load_model
这样的函数来保存和加载模型时,get_config
方法允许 TensorFlow 正确重构你的自定义层。这对于实现复杂的序列化逻辑和确保模型可以跨不同的环境移植是非常重要的。
交叉注意力机制代码
import tensorflow as tfclass CrossAttention(tf.keras.layers.Layer):def __init__(self, units, **kwargs):super(CrossAttention, self).__init__(**kwargs)self.units = units # 特征维度# A 模态的层self.wq_a = tf.keras.layers.Dense(units, use_bias=False)self.wv_a = tf.keras.layers.Dense(units, use_bias=False)# B 模态的层self.wq_b = tf.keras.layers.Dense(units, use_bias=False)self.wv_b = tf.keras.layers.Dense(units, use_bias=False)# 键层是交叉的,A查询B,B查询Aself.wk_a = tf.keras.layers.Dense(units, use_bias=False)self.wk_b = tf.keras.layers.Dense(units, use_bias=False)def call(self, features_a, features_b):# A 查询 BQ_a = self.wq_a(features_a)K_b = self.wk_b(features_b)V_a = self.wv_a(features_a)# B 查询 AQ_b = self.wq_b(features_b)K_a = self.wk_a(features_a)V_b = self.wv_b(features_b)print(self.wq_a.get_weights()[0].shape)print(self.wk_b.get_weights()[0].shape)print(self.wv_a.get_weights()[0].shape)print(Q_a.shape)print(K_b.shape)print(V_a.shape)# 计算注意力分数attention_scores_ab = tf.matmul(Q_a, K_b, transpose_b=True) # 当设置 transpose_b=True 时,这意味着在执行矩阵乘法之前会对第二个输入张量(B)进行转置。attention_scores_ba = tf.matmul(Q_b, K_a, transpose_b=True)print(attention_scores_ab.shape)print(attention_scores_ba.shape)# 缩放dk = tf.cast(tf.shape(K_b)[-1], tf.float32) # cast是类型转换函数scaled_attention_ab = attention_scores_ab / tf.math.sqrt(dk) # sqrt是开方函数scaled_attention_ba = attention_scores_ba / tf.math.sqrt(dk)# softmax 应用到最后一个轴attention_weights_ab = tf.nn.softmax(scaled_attention_ab, axis=-1)attention_weights_ba = tf.nn.softmax(scaled_attention_ba, axis=-1)# 应用注意力权重output_a = tf.matmul(attention_weights_ab, V_a)output_b = tf.matmul(attention_weights_ba, V_b)return output_a, output_bdef get_config(self):config = super(CrossAttention, self).get_config()config.update({'units': self.units})return config
交叉注意力机制测试
# 创建模拟数据
batch_size = 5
seq_length = 10
feature_dim = 64featuresA = tf.random.normal((batch_size, seq_length, feature_dim))
featuresB = tf.random.normal((batch_size, seq_length, feature_dim))# featuresA = tf.random.normal((batch_size, feature_dim))
# featuresB = tf.random.normal((batch_size, feature_dim))# 实例化并应用交叉注意力层
cross_attention = CrossAttention(units=feature_dim)
output_a, output_b = cross_attention(featuresA, featuresB)# 输出结果
print("Output A:", output_a)
print("Output B:", output_b)
(64, 64)
(64, 64)
(64, 64)
(5, 10, 64)
(5, 10, 64)
(5, 10, 64)
(5, 10, 10)
(5, 10, 10)
Output A: tf.Tensor( [[[-5.7970786e-01 -6.0719752e-01 1.9909297e-01 … -3.2587045e-01 4.7336593e-02 3.5937220e-01] [ 2.9775929e-01 9.1934420e-02 -5.2563709e-01 … 7.4571893e-03 -1.0422341e+00 1.5034429e+00] [-5.7887845e-02 -6.6692615e-01 -8.3360665e-02 … 2.1067673e-01 2.0477527e-01 1.7619665e-01] … [ 3.0940935e-01 -2.0744744e-01 -4.0154967e-01 … -1.9383176e-01 -2.4514772e-02 -4.9369448e-01] [-6.2394276e-02 -4.5148364e-01 -7.8599244e-02 … -4.9863800e-02 -1.3933992e-01 3.8461825e-01] [-4.4355023e-01 -6.3579077e-01 -7.1974456e-02 … -1.1960831e-01 -5.1002228e-01 1.1245629e+00]] [[ 6.5766849e-02 -3.0112099e-02 2.3346621e-01 … -2.5223029e-01 7.3245126e-01 -1.3654667e-01]… [ 0.35888967 0.20865138 0.0092712 … -0.01233579 0.3880379 -0.45383584] [-0.18597096 0.44148418 -0.2378662 … 0.0344183 -0.18934116 -0.8733301 ]]], shape=(5, 10, 64), dtype=float32)
1. 查询矩阵 Q_a 的生成
- Q_a 通过
featuresA
和self.wq_a
层的操作得到。 featuresA
的形状为(batch_size, seq_length, feature_dim)
。self.wq_a
是一个 无偏置 的Dense
层,其权重矩阵形状为(feature_dim, units)
,其中units == feature_dim
。- 执行
self.wq_a(featuresA)
时,等价于批量矩阵乘法:对featuresA
中的 每一个序列(对应于每一个 batch 和每个时间步)与self.wq_a
的权重矩阵相乘。
2. 矩阵乘法的细节
featuresA
中的每个(seq_length, feature_dim)
矩阵与self.wq_a
的(feature_dim, feature_dim)
权重矩阵相乘。- 结果是一个新的 (batch_size, seq_length, feature_dim) 形状的张量 Q_a,其中每个元素都是原始特征与权重的线性组合。
- 该操作 不改变序列长度 (seq_length),但对每个时间步的特征进行变换。
3. 操作的目的
- 通过线性变换,将原始特征映射到一个可能有利于后续计算的新空间(此处输入与输出维度相同)。
- 在注意力机制中,这种变换用于准备计算点积或其他形式的相似度,以决定不同特征之间的交互重要性。