当前位置: 首页 > news >正文

如何用熵正则化控制注意力分数的分布

先写一个CrossAttention模块,

# input: Q(B, L, d), KV(B, N, d)
# output: (B, L, dim)
# 0<alpha<=ln(N), alpha越接近0, 注意力分数越逼近one-hot分布
class CrossAttention(layers.Layer):def __init__(self, num_head, dim, alpha,**kwargs):super().__init__(**kwargs)self.alpha = alphaself.num_head = num_headself.dim = dimself.layernorm = layers.LayerNormalization()def build(self, input_shape):self.qdk = self.add_weight(name='query_dense_kernel', shape=[input_shape[0][-1], self.num_head, self.dim])self.kdk = self.add_weight(name='key_dense_kernel', shape=[input_shape[1][-1], self.num_head, self.dim])self.vdk = self.add_weight(name='value_dense_kernel', shape=[input_shape[1][-1], self.num_head, self.dim])self.odk = self.add_weight(name='output_dense_kernel', shape=[self.num_head, self.dim, self.dim])self.odb = self.add_weight(name='output_dense_bais', shape=[self.dim])def call(self, inputs, *args, **kwargs):Q, KV = inputsquery = tf.einsum("abc, cde->abde", Q, self.qdk)key = tf.einsum("abc, cde->abde", KV, self.kdk)value = tf.einsum("abc, cde->abde", KV, self.vdk)query = tf.multiply(query, 1.0 / tf.math.sqrt(float(self.dim)))attention_scorces = tf.math.softmax(tf.einsum("abcd, aecd->acbe", query, key))self.add_loss(tf.reduce_mean((-tf.reduce_sum(attention_scorces * tf.math.log(attention_scorces + 1e-07), axis=-1) - self.alpha)**2))attention_output = tf.einsum("abcd, aceb->aecd", value, attention_scorces)output = tf.einsum("abcd, cdd->abd", attention_output, self.odk) + self.odbreturn self.layernorm(output + Q), attention_scorces

损失函数包含两种类型:prediction loss和regularization loss。

regularization loss需要add_loss方法进行添加,add_loss方法添加的损失值可以通过model.losses进行访问,返回一个集合,集合每个元素对应一个正则损失。

regularization loss被add_loss方法添加后,需要被tf.GradientTape()的作用域包含

def fit(x, y, epochs, model):optimizer = tf.keras.optimizers.Adam()for epoch in range(epochs):print("\nStart of epoch %d" % (epoch,))with tf.GradientTape() as tape:logits = model(x, training=True)[0]# Compute the loss value for this minibatch.loss_value = tf.keras.losses.binary_crossentropy(y, logits)print(model.losses)loss_value += sum(model.losses)grads = tape.gradient(loss_value, model.trainable_weights)# Run one step of gradient descent by updating# the value of the variables to minimize the loss.optimizer.apply_gradients(zip(grads, model.trainable_weights))print("attention scores entropy loss: %s" % (sum(model.losses)))print("loss" % loss_value)

接下来,设置一个简单的任务和数据看看熵正则化的效果,

class model(tf.keras.Model):def __init__(self):super().__init__()self.CA = CrossAttention(3, 16, 0.2)def build(self, input_shape):self.k = self.add_weight(name="predict_kernel", shape=[input_shape[0][-2], 16, 2])def call(self, inputs, *args, **kwargs):x, scores = self.CA(inputs)return tf.math.sigmoid(tf.einsum("abc, bcd->ad", x, self.k)), scoresif __name__ == '__main__':Q = tf.random.uniform((1, 4, 16))KV = tf.random.uniform((1, 6, 16))labels = tf.constant([[0., 1]])model = model()fit((Q, KV), labels, epochs=1000, model=model)print(model((Q, KV)))

模型训练好后,打印注意力分数的分布情况,可以发现每一行注意力分数都接近one-hot分布。

一个概率分布的信息熵最小值为0,最大值为logk。最小值对应熵最小的one-hot分布,最大值对应熵最大的均匀分布。在这里设置的熵正则化损失函数为(Entropy(scorces)-alpha)^2,通过调整alpha的大小,可以控制注意力分数逼近one-hot分布的程度。

<tf.Tensor: shape=(1, 3, 4, 6), dtype=float32, numpy=
array([[[[3.27099487e-03, 2.51237601e-02, 9.59103048e-01, 6.17671618e-03, 4.41661570e-03, 1.90874794e-03],[3.13497148e-03, 2.45431308e-02, 9.60317731e-01, 5.98660251e-03, 4.21733223e-03, 1.80015271e-03],[3.40689556e-03, 2.57629622e-02, 9.57871556e-01, 6.44017057e-03, 4.52079810e-03, 1.99752697e-03],[1.10661890e-03, 1.27607975e-02, 9.81542170e-01, 2.43383530e-03, 1.57478068e-03, 5.81745524e-04]],[[3.10348310e-02, 8.87579299e-05, 4.77730093e-04, 1.46521628e-03, 9.52835977e-01, 1.40975416e-02],[2.63910089e-02, 5.77273713e-05, 3.39534425e-04, 1.06126023e-03, 9.60522711e-01, 1.16277682e-02],[3.12446002e-02, 9.48462111e-05, 5.04475611e-04, 1.46811537e-03, 9.52523112e-01, 1.41648324e-02],[1.50452955e-02, 1.21340790e-05, 9.36727956e-05, 3.67841218e-04, 9.78706181e-01, 5.77491429e-03]],[[2.43717595e-03, 2.99123675e-02, 9.57738578e-01, 8.62001721e-03, 6.63190091e-04, 6.28605427e-04],[2.37493636e-03, 2.91979928e-02, 9.58939075e-01, 8.24017916e-03, 6.44378248e-04, 6.03430963e-04],[2.84763589e-03, 3.26210111e-02, 9.53170657e-01, 9.78391431e-03, 8.07495147e-04, 7.69288861e-04],[1.05220091e-03, 1.88548081e-02, 9.75066125e-01, 4.56135161e-03, 2.38224253e-04, 2.27281591e-04]]]],dtype=float32)>

http://www.dtcms.com/a/359979.html

相关文章:

  • 让你的App与众不同打造独特品牌展示平台
  • Scikit-learn Python机器学习 - 类别特征提取- OneHotEncoder
  • 编写Linux下usb设备驱动方法:disconnect函数中要完成的任务
  • 【数学建模学习笔记】异常值处理
  • RAG(检索增强生成)技术的核心原理与实现细节
  • 【Unity开发】Unity核心学习(三)
  • macos自动安装emsdk4.0.13脚本
  • 在Ubuntu系统上安装和配置JMeter和Ant进行性能测试
  • 基于SpringBoot + Vue 的宠物领养管理系统
  • 【Spring Cloud微服务】7.拆解分布式事务与CAP理论:从理论到实践,打造数据一致性堡垒
  • ANR InputDispatching TimeOut超时判断 - android-15.0.0_r23
  • 拆分TypeScript项目的学习收获:处理编译缓存和包缓存,引用本地项目,使用相对路径
  • 配置 Kubernetes Master 节点不可调度的标准方法
  • 【51单片机】【protues仿真】基于51单片机音乐喷泉系统
  • 记录测试环境hertzbeat压测cpu高,oom问题排查。jvm,mat,visulavm
  • opencv 梯度提取
  • [Android] UI进阶笔记:从 Toolbar 到可折叠标题栏的完整实战
  • 掩码语言模型(Masked Language Model, MLM)
  • android-studio 安装
  • 基于计算机视觉的海底图像增强系统:技术详述与实现
  • 如何正确校正电脑时间?
  • 【开源】AI模型接口管理与分发系统开源项目推荐
  • Redis八股小记
  • 人工智能学习:机器学习相关面试题(二)
  • 【开题答辩全过程】以 基于vue+springboot的校园疫情管理系统的设计与实现为例,包含答辩的问题和答案
  • 企业级开发模型:从软件生命周期到 Git 分支管理
  • 【C++ 】string类:深拷贝与浅拷贝解析
  • DSPFilters实现低通滤波器(QT)
  • 电力电子技术知识学习-----晶闸管
  • 前端组件拆分与管理实战:如何避免 props 地狱,写出高可维护的项目