当前位置: 首页 > news >正文

rust-candle学习笔记10-使用Embedding

参考:about-pytorch

candle-nn提供embedding()初始化Embedding方法:

pub fn embedding(in_size: usize, out_size: usize, vb: crate::VarBuilder) -> Result<Embedding> {let embeddings = vb.get_with_hints((in_size, out_size),"weight",crate::Init::Randn {mean: 0.,stdev: 1.,},)?;Ok(Embedding::new(embeddings, out_size))
}

 candle Embedding初体验:

其中Tokenizer和dataset的构造详情参考:rust-candle学习笔记9-使用tokenizers加载qwen3分词,使用分词器处理文本

use candle_nn::{embedding, Embedding, Module, VarBuilder, VarMap};fn main() -> Result<()> {let tokenizer = Tokenizer::from_file("assets/qwen3/tokenizer.json")?;let vocab_size = tokenizer.get_vocab_size(true);let text = read_txt("assets/the-verdict.txt")?;let device = Device::cuda_if_available(0)?;let dataset = TokenDataset::new(text, tokenizer, 32, 16, device.clone())?;let (inputs, targets) = dataset.get_item(0)?;println!(" inputs: {:?}\n", inputs);println!(" targets: {:?}\n", targets);let len = dataset.len();println!("{:?}", len);let varmap = VarMap::new();let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);let embedding = embedding(vocab_size, 5, vb)?;let x_embedding = embedding.forward(&inputs)?;let y_embedding = embedding.forward(&targets)?;println!(" x_embedding: {:?}\n", x_embedding);println!("{:?}", x_embedding.to_vec2::<f32>()?);println!(" y_embedding: {:?}\n", y_embedding);println!("{:?}", y_embedding.to_vec2::<f32>()?);Ok(())
}

实现正余弦位置编码:

struct PositionEmbedding {pos_embedding: Tensor,device: Device
}
impl PositionEmbedding {fn new(seq_len: usize, embedding_dim: usize, device: Device) -> Result<Self> {if embedding_dim % 2 != 0 {return Err(Box::new(candle_core::Error::msg("embedding_dim must be even")));}let mut pos_embedding_vec: Vec<f32> = Vec::with_capacity(seq_len * embedding_dim);let w_const: f32 = 10000.0;for t in 0..seq_len {let i_max = embedding_dim / 2;for i in 0..i_max {let denominator = w_const.powf(2.0 * i as f32 / embedding_dim as f32);let pos_sin_i = (t as f32 / denominator).sin();let pos_cos_i = (t as f32 / denominator).cos();pos_embedding_vec.push(pos_sin_i);pos_embedding_vec.push(pos_cos_i);}}let pos_embedding = Tensor::from_vec(pos_embedding_vec, (seq_len, embedding_dim), &device)?;Ok(Self { pos_embedding, device })}
}

测试:

注意:candle 不同维度tensor相加直接用+会报错,要显示的调用广播加,高维tensor和低维tensor谁加谁都可以

fn main() -> Result<()> {let tokenizer = Tokenizer::from_file("assets/qwen3/tokenizer.json")?;let vocab_size = tokenizer.get_vocab_size(true);let text = read_txt("assets/the-verdict.txt")?;let device = Device::cuda_if_available(0)?;let seq_len = 32;let dataset = TokenDataset::new(text, tokenizer, seq_len, 16, device.clone())?;let batch_size: usize = 6;let mut loader = DataLoader::new(dataset, batch_size, true);loader.reset();let (x, y) = loader.next().unwrap()?;let varmap = VarMap::new();let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);let embedding_dim: usize = 256;let embedding = embedding(vocab_size, embedding_dim, vb)?;let x_embedding = embedding.forward(&x)?;let y_embedding = embedding.forward(&y)?;println!(" x_embedding: {:?}\n", x_embedding);println!(" y_embedding: {:?}\n", y_embedding);let pos_embedding = PositionEmbedding::new(seq_len, embedding_dim, device.clone())?;let pos_emb = pos_embedding.pos_embedding;// candle 不同维度tensor相加直接用+会报错,// 广播加要显示的调用// 下面两种方式都可以let x_input = x_embedding.broadcast_add(&pos_emb)?;// let x_input = pos_emb.broadcast_add(&x_embedding)?;println!(" x_input: {:?}\n", x_input);Ok(())
}

相关文章:

  • QT6(35)4.8定时器QTimer 与QElapsedTimer:理论,例题的界面搭建,与功能的代码实现。
  • 请求从发送到页面渲染的全过程
  • vscode 配置doxygen注释和snippet
  • 大模型备案环节如何评估模型的安全性
  • 简易版无人机飞控
  • C++ Dll创建与调用 查看dll函数 MFC 单对话框应用程序(EXE 工程)改为 DLL 工程
  • Spring Boot快速开发:从零开始搭建一个企业级应用
  • 《工业计算机硬件技术支持手册》适用于哪些人群?
  • STM32GPIO输入实战-key按键easy_button库移植
  • ES6新增Set、Map两种数据结构、WeakMap、WeakSet举例说明详细。(含DeepSeek讲解)
  • Qt开发经验 --- 避坑指南(10)
  • 使用Java实现HTTP协议服务:从自定义服务器到内置工具
  • MySQL 8.0(主从复制)
  • 如何删除豆包本地大模型
  • 操纵杆支架加工工艺及钻3φ11孔夹具设计
  • L48.【LeetCode题解】904. 水果成篮
  • 《P1177 【模板】排序》
  • 高质量老年生活:从主动健康管理到预防医学的社会价值
  • 一种安全不泄漏、高效、免费的自动化脚本平台
  • C++学习-入门到精通-【5】类模板array和vector、异常捕获
  • 宇树科技王兴兴:第一桶金来自上海,欢迎上海的年轻人加入
  • 泰特现代美术馆25年:那些瞬间,让艺术面向所有人
  • 人民日报钟声:平等对话是解决大国间问题的正确之道
  • 独家丨刘家琨获普利兹克奖感言:守护原始的感悟力
  • “上海之帆”巡展在日本大阪开幕,松江区组织企业集体出展
  • 深入贯彻中央八项规定精神学习教育中央第一指导组指导督导河北省见面会召开