当前位置: 首页 > news >正文

rust-candle学习笔记13-实现多头注意力

参考:about-pytorch

定义结构体:

use core::f32;use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{embedding, linear_no_bias, linear, ops, Dropout, Linear, Module, VarBuilder, VarMap};struct MultiHeadAttention {w_qkv: Linear,dropout: Dropout, d_model: Tensor,mask: Tensor,out_proj: Linear,device: Device,out_dim: usize,num_heads: usize,head_dim: usize,
}

定义初始化方法:

impl MultiHeadAttention {fn new(vb: VarBuilder, embedding_dim: usize, out_dim: usize, seq_len: usize, num_heads: usize, drop_p: f32, device: Device) -> Result<Self> {if out_dim % num_heads != 0 {return Err(candle_core::Error::msg("out_dim must be divisible by num_heads"));}Ok(Self { w_qkv: linear_no_bias(embedding_dim, 3*out_dim, vb.pp("w_qkv"))?, dropout: Dropout::new(drop_p), d_model: Tensor::new(embedding_dim as f32, &device)?, mask: Tensor::tril2(seq_len, DType::U32, &device)?, out_proj: linear(out_dim, out_dim, vb.pp("out_proj"))?, device, out_dim, num_heads, head_dim: out_dim / num_heads, })}
}

定义forward方法:

fn forward(&self, x: &Tensor, train: bool) -> Result<Tensor> {let qkv = self.w_qkv.forward(x)?;let (batch_size, seq_len, _) = qkv.dims3()?;let qkv = qkv.reshape((batch_size, seq_len, 3, self.num_heads, self.head_dim))?;let q = qkv.get_on_dim(2, 0)?;// Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)let q = q.transpose(1, 2)?.contiguous()?;let k = qkv.get_on_dim(2, 0)?;let k = k.transpose(1, 2)?.contiguous()?;let v = qkv.get_on_dim(2, 0)?;let v = v.transpose(1, 2)?.contiguous()?;let attn_scores = q.matmul(&k.transpose(2, 3)?)?;let mask = self.mask.broadcast_as(attn_scores.shape())?;let attn_scores = masked_fill(&attn_scores, &mask, f32::NEG_INFINITY)?;let attn_scores = attn_scores.broadcast_div(&self.d_model.sqrt()?)?;let softmax_dim = attn_scores.rank() - 1;// let attn_weights = ops::softmax_last_dim(&attn_scores)?;  //如果是cpu,可以用这个let attn_weights = ops::softmax(&attn_scores, softmax_dim)?;let attn_weights = self.dropout.forward(&attn_weights, train)?;let attn_output = attn_weights.matmul(&v)?;let attn_output = attn_output.transpose(1, 2)?;let attn_output = attn_output.reshape(&[batch_size, seq_len, self.num_heads*self.head_dim])?;let attn_output = self.out_proj.forward(&attn_output)?;Ok(attn_output)}

测试:

fn main() -> Result<()> {let device = Device::cuda_if_available(0)?;let varmap = VarMap::new();let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);let input = Tensor::from_vec(vec![0.43f32, 0.15, 0.89, 0.55, 0.87, 0.66,0.57, 0.85, 0.64,0.22, 0.58, 0.33,0.77, 0.25, 0.10,0.05, 0.80, 0.55, 0.43, 0.15, 0.89, 0.55, 0.87, 0.66,0.57, 0.85, 0.64,0.22, 0.58, 0.33,0.77, 0.25, 0.10,0.05, 0.80, 0.55], (2, 6, 3), &device)?;let model = MultiHeadAttention::new(vb.clone(), 3, 4, 6, 2, 0.1, device.clone())?;let output = model.forward(&input, true)?;println!("output: {:?}\n", output);println!("output: {:?}\n", output.to_vec3::<f32>()?);Ok(())
}

相关文章:

  • 嵌入式STM32学习——继电器
  • 大模型微调算法原理:从通用到专用的桥梁
  • 解决mybatisplus主键无法自增的问题
  • Spring之AOP
  • Windows中安装nacos-server-2.4.2
  • webpack和vite区别
  • 《Python星球日记》 第52天:反向传播与优化器
  • MySQL事务和JDBC中的事务操作
  • Veins同时打开SUMO和OMNeT++的GUI界面
  • Visual Studio 2022 远程调试
  • C++字符串操作 2024年信息素养大赛复赛 C++小学/初中组 算法创意实践挑战赛 真题详细解析
  • 蓝桥杯嵌入式第十一届省赛真题
  • `RotationTransition` 是 Flutter 中的一个动画组件,用于实现旋转动画效果
  • 仓库管理系统,Java+Vue,含源码及文档,高效管理仓库物资,实现入库、存储、出库全流程数字化精准管控
  • 睿思量化小程序
  • Redis 哨兵
  • AI 入门资源:微软 AI-For-Beginners 项目指南
  • #Redis黑马点评#(四)优惠券秒杀
  • 基于定制开发开源AI智能名片S2B2C商城小程序的公私域流量融合运营策略研究
  • mac u盘重装mac10.15Catalina系统
  • 数说母亲节|全球11亿女性进入更年期,“不是忍忍就好”
  • 种罂粟喂鸡防病?四川广元一村民非法种植毒品原植物被罚​
  • 宇树科技王兴兴:第一桶金来自上海,欢迎上海的年轻人加入
  • 碧桂园:砸锅卖铁保交房、持续推进保主体,尽快让公司恢复正常经营
  • 中国词学研究会原会长、华东师大教授马兴荣逝世,享年101岁
  • 近4小时会谈、3项联合声明、20多份双边合作文本,中俄元首今年首次面对面会晤成果颇丰