当前位置: 首页 > news >正文

[随笔] nn.Embedding的前向传播与反向传播

nn.Embedding的前向传播与反向传播

nn.Embedding的前向计算过程

embedding module 的前向过程其实是一个索引(查表)的过程
表的形式是一个 matrix(embedding.weight, learnable parameters)

matrix.shape: (v, h)
v:vocabulary size=num_embedding
h:hidden dimension=embedding_dim

仅从数学的角度来说(方便推导模型),具体索引的过程,可以通过 one hot + 矩阵乘法的形式实现的

input.shape: (b, s)
> b:batch size
> s:seq len

当执行下行代码时,会进行如下计算 
 embed = embedding(input) 

> input.shape(b,s) 	e.g [[0, 2, 2,1]]
> 最终的维度变化情况:(b, s) ==> (b, s, h)

1.(b, s) 经过 one hot => (b, s, v)
inputs: [[0, 2, 2, 1 , 1]] 
inputs One-Hot: 数值分类(0-4 => 五分类) 0:[1,0,0,0,0]
    [[[1,0,0,0,0],
    [0,0,1,0,0],
    [0,0,1,0,0],
    [0,1,0,0,0],
    [0,1,0,0,0]]]
matrix(embedding.weight):
    [[ 1.0934,  1.7521, -1.9529, -1.0145,  0.5770],
    [-0.4371, -0.4270, -0.4908, -0.3988,  0.9695],
    [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
    [ 0.7268, -0.4491, -0.8089,  0.7516,  1.2716],
    [ 0.7785, -0.4336, -0.7542, -0.1953,  0.9711]]

2.(b, s, v) @ (v, h) ==> (b, s, h)

x(b, s, h):
    [[[ 1.0934,  1.7521, -1.9529, -1.0145,  0.5770],
    [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
    [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
    [-0.4371, -0.4270, -0.4908, -0.3988,  0.9695],
    [-0.4371, -0.4270, -0.4908, -0.3988,  0.9695]]]

本质上,embedding(input) 是一个内存寻址的过程:

假设 inputs 是一个包含词索引的张量,i是数值化文本 inputs上的Token,weight 是嵌入矩阵。
	对于每个索引 i,嵌入向量 v_i 对应的计算过程是:
		v_i=embedding.weight[i]
nn.Embedding的反向传播过程

只有前向传播中用到的索引会接收梯度。

假如反向传播过来的梯度是 [0.1,0.1,0.3] ,原始的embedding矩阵= [[1. ,1. ,1.],[1. ,1. ,1.]] , lr=0.1

那么 反向传播以后embedding的参数就为 [[1. ,1. ,1.],[1. ,1. ,1.]] - 1 * [[0.1,0.1,0.3],[0.,0.,0.]]

即 [[0.99. ,0.99 ,0.97],[1. ,1. ,1.]]

相关文章:

  • Spring Boot项目中结合MyBatis实现MySQL的自动主从切换
  • 快排算法 (分治实现)
  • 11. Langchain输出解析(Output Parsers):从自由文本到结构化数据
  • 【后端开发】Spring MVC-常见使用、Cookie、Session
  • 分析下HashMap容量和负载系数,它是怎么扩容的?
  • 底盘---全向轮(Omni Wheel)
  • 重温Java - Java基础二
  • 无人设备遥控器之通信链路管理篇
  • C++ 创建静态数组出现栈满程序崩溃的问题
  • 【虚拟机栈中的栈帧是什么?有什么作用?局部变量表、操作数栈、动态链接和方法返回地址是什么?有什么作用?为什么要放在栈帧里?】
  • Ubuntu24.04 编译 Qt 源码
  • 一个可以在Android手机上运行的Linux高仿window10的应用
  • Python中的AdaBoost分类器:集成方法与模型构建
  • VT01N/VT02N进行交货的时候,对装运点加权限控制的增强
  • 原生SSE实现AI智能问答+Vue3前端打字机流效果
  • 【语法】C++的list
  • 模糊测试究竟在干什么
  • 41、web前端开发之Vue3保姆教程(五 实战案例)
  • 结合大语言模型整理叙述并生成思维导图的思路
  • C语言--常用的链表操作
  • 网站开发ceac证/网站流量统计分析工具
  • 拉萨网站建设价格/公关公司提供的服务有哪些
  • 河南省汝州市建设网站/百度数据研究中心官网
  • 济南正规网站建设公司哪家好/百度推广登陆平台
  • 威海环翠疫情最新消息/seo一般包括哪些内容
  • 南山做棋牌网站建设/seo基础知识培训