
rust-candle Learning Notes 11: Implementing a Simple Self-Attention

Reference: about-pytorch

Define the ScaledDotProductAttention struct:

use candle_core::{Result, Device, Tensor};
use candle_nn::{Linear, Module, linear_no_bias, VarMap, VarBuilder, ops};

struct ScaledDotProductAttention {
    wq: Linear,      // query projection
    wk: Linear,      // key projection
    wv: Linear,      // value projection
    d_model: Tensor, // scale denominator, stored as a tensor for broadcast_div
    device: Device,
}
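
The three Linear layers hold the query (wq), key (wk), and value (wv) projection matrices. d_model stores the scale denominator as a scalar tensor on the target device so the forward pass can apply it with broadcast_div.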

Implement the new method for ScaledDotProductAttention:

impl ScaledDotProductAttention {
    fn new(vb: VarBuilder, embedding_dim: usize, out_dim: usize, device: Device) -> Result<Self> {
        Ok(Self {
            wq: linear_no_bias(embedding_dim, out_dim, vb.pp("wq"))?,
            wk: linear_no_bias(embedding_dim, out_dim, vb.pp("wk"))?,
            wv: linear_no_bias(embedding_dim, out_dim, vb.pp("wv"))?,
            // Scale by sqrt(d_k): d_k is the key/query dimension, i.e. out_dim.
            d_model: Tensor::new(out_dim as f32, &device)?,
            device,
        })
    }
}
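
Note that vb.pp("wq") pushes a name prefix onto the VarBuilder, so each projection's weight is registered in the VarMap under its own path (wq.weight, wk.weight, wv.weight); this is what keeps the three parameter sets distinct.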

Implement the Module trait's forward method for the struct:

impl Module for ScaledDotProductAttention {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Project the input into query, key, and value spaces.
        let q = self.wq.forward(xs)?;
        let k = self.wk.forward(xs)?;
        let v = self.wv.forward(xs)?;
        // Attention scores: Q K^T, scaled by sqrt(d_k).
        let attn_score = q.matmul(&k.t()?)?;
        let attn_score = attn_score.broadcast_div(&self.d_model.sqrt()?)?;
        // Softmax over the last dimension turns scores into weights.
        let dim = attn_score.rank() - 1;
        let attn_weights = ops::softmax(&attn_score, dim)?;
        // Weighted sum of the value vectors.
        let attn_output = attn_weights.matmul(&v)?;
        Ok(attn_output)
    }
}
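
For reference, the forward pass above computes the standard scaled dot-product attention:

\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V

where d_k is the key/query dimension (out_dim here), and the softmax runs over the last dimension so that each query's weights sum to 1.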

Fused QKV implementation: instead of three separate projections, a single Linear of width 3*out_dim produces Q, K, and V in one matmul, and the output is split afterwards.

Define the ScaledDotProductAttentionFusedQKV struct:

struct ScaledDotProductAttentionFusedQKV {
    w_qkv: Linear,   // fused projection producing q, k, v concatenated on the last dim
    d_model: Tensor, // scale denominator, as above
    device: Device,
}

Implement the new method for the struct:

impl ScaledDotProductAttentionFusedQKV {
    fn new(vb: VarBuilder, embedding_dim: usize, out_dim: usize, device: Device) -> Result<Self> {
        Ok(Self {
            // One projection of width 3*out_dim instead of three of width out_dim.
            w_qkv: linear_no_bias(embedding_dim, 3 * out_dim, vb.pp("w_qkv"))?,
            // Same sqrt(d_k) scale as the unfused version.
            d_model: Tensor::new(out_dim as f32, &device)?,
            device,
        })
    }
}
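
The fused layer has exactly the same parameter count as the three separate projections (one embedding_dim × 3·out_dim matrix versus three embedding_dim × out_dim matrices); fusing them just means Q, K, and V come out of a single matmul, which reduces the number of kernel launches.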

Implement the Module trait's forward method for the struct:

impl Module for ScaledDotProductAttentionFusedQKV {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let qkv = self.w_qkv.forward(xs)?;
        let (batch_size, seq_len, _) = qkv.dims3()?;
        // View the last dimension (3*out_dim) as [3, out_dim] and pull out q, k, v.
        let qkv = qkv.reshape((batch_size, seq_len, 3, ()))?;
        let q = qkv.get_on_dim(2, 0)?;
        let q = q.reshape((batch_size, seq_len, ()))?;
        let k = qkv.get_on_dim(2, 1)?;
        let k = k.reshape((batch_size, seq_len, ()))?;
        let v = qkv.get_on_dim(2, 2)?;
        let v = v.reshape((batch_size, seq_len, ()))?;
        // From here on, identical to the unfused forward pass.
        let attn_score = q.matmul(&k.t()?)?;
        let attn_score = attn_score.broadcast_div(&self.d_model.sqrt()?)?;
        let dim = attn_score.rank() - 1;
        let attn_weights = ops::softmax(&attn_score, dim)?;
        let attn_output = attn_weights.matmul(&v)?;
        Ok(attn_output)
    }
}
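
The reshape to (batch_size, seq_len, 3, ()) relies on the fused output's last dimension being laid out as [q | k | v], so indexing dim 2 with get_on_dim recovers each part. As a minimal alternative sketch (not from the original post), candle's Tensor::chunk can perform the same split directly on the last dimension:

use candle_core::{D, Device, Result, Tensor};

// Split a fused (batch, seq, 3*d) tensor into q, k, v of shape (batch, seq, d).
fn split_qkv(qkv: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
    // chunk(3, D::Minus1) cuts the last dimension into three equal slices:
    // columns [0, d), [d, 2*d), [2*d, 3*d), the same layout the reshape assumes.
    let parts = qkv.chunk(3, D::Minus1)?;
    Ok((parts[0].clone(), parts[1].clone(), parts[2].clone()))
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    let qkv = Tensor::arange(0f32, 36.0, &device)?.reshape((2, 3, 6))?;
    let (q, k, v) = split_qkv(&qkv)?;
    println!("{:?} {:?} {:?}", q.shape(), k.shape(), v.shape()); // each (2, 3, 2)
    Ok(())
}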

Test:

fn main() -> Result<()> {
    let device = Device::cuda_if_available(0)?;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);
    // Two identical batches of 6 tokens, each token a 3-dim embedding.
    let input = Tensor::from_vec(
        vec![
            0.43f32, 0.15, 0.89,
            0.55, 0.87, 0.66,
            0.57, 0.85, 0.64,
            0.22, 0.58, 0.33,
            0.77, 0.25, 0.10,
            0.05, 0.80, 0.55,
            0.43, 0.15, 0.89,
            0.55, 0.87, 0.66,
            0.57, 0.85, 0.64,
            0.22, 0.58, 0.33,
            0.77, 0.25, 0.10,
            0.05, 0.80, 0.55,
        ],
        (2, 6, 3),
        &device,
    )?;
    // let model = ScaledDotProductAttention::new(vb.clone(), 3, 2, device.clone())?;
    let model = ScaledDotProductAttentionFusedQKV::new(vb.clone(), 3, 2, device.clone())?;
    let output = model.forward(&input)?;
    println!("output: {:?}\n", output);
    println!("output: {:?}\n", output.to_vec3::<f32>()?);
    Ok(())
}
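
With this toy input (batch size 2, sequence length 6, embedding dim 3) and out_dim = 2, both modules produce an output of shape (2, 6, 2). The weights come from a freshly created VarMap and are randomly initialized, so the printed values differ from run to run; uncommenting the ScaledDotProductAttention line swaps in the unfused module, which takes the same arguments.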
