当前位置: 首页 > news >正文

深度学习中的三种Embedding技术详解

提纲

    • 背景介绍
    • 特征类型与Embedding方法
    • 1. ID类特征的Embedding处理
      • 1.1 标准Embedding方法
      • 1.2 IdHashEmbedding方法
    • 2. 数值型特征的Embedding处理
      • 2.1 RawEmbedding方法
    • 三种Embedding方法对比总结
    • 实践建议
    • 总结

背景介绍

在深度学习领域,Embedding(嵌入)技术是一种将高维稀疏数据转换为低维稠密向量表示的核心方法。它在推荐系统、自然语言处理、图像识别等多个领域中发挥着重要作用。

Embedding的主要目的是:

  1. 将离散的类别特征(如用户ID、商品类别)转换为连续的向量表示
  2. 在低维空间中捕获特征间的语义关系
  3. 提高模型的表达能力和泛化性能

正确使用Embedding技术对于模型性能至关重要。本文将详细解析三种不同的Embedding方法及其应用场景。

特征类型与Embedding方法

在实际项目中,我们通常会遇到两种主要的特征类型:

  1. ID类特征:如用户ID、骑手ID、商品类别等离散的标识符
  2. 数值型特征:如用户年龄、订单金额等连续数值

针对不同类型的特征,我们需要采用不同的Embedding策略来处理。

1. ID类特征的Embedding处理

1.1 标准Embedding方法

适用场景:适用于中低基数类别特征,如商品类别、用户性别、服务类型、地区编码等。

工作原理

  • 为每个唯一的特征值维护一个独立的向量表示
  • 直接使用特征值作为索引查找Embedding table中对应的Embedding向量
  • Embedding表的大小(行数)需要大于等于特征最大的特征值

Example
假设我们有一个item_category特征,有50个不同的类别,我们希望将其映射到16维的向量空间:

import torch
import torch.nn as nn
# 特征配置
feature_config = {'item_category': {'hash_bucket_size': 50, 'embedding_dimension': 16}
}# 创建Embedding层(table),维度为50X16的矩阵
embedding_layer = nn.Embedding(num_embeddings=50,     # 类别数量embedding_dim=16       # 目标维度
)# 输入样本:两个订单的类别分别为5和23
input_tensor = torch.tensor([5, 23])  # shape: [2]# Embedding过程就时vlookup过程,从embedding_layer 矩阵中选取第5行和第23行数据向量
embedded = embedding_layer(input_tensor)  # shape: [2, 16]print(f"输入形状: {input_tensor.shape}")    # [2]
print(f"输出形状: {embedded.shape}")       # [2, 16]

初始化方式

- PyTorch中nn.Embedding的默认初始化
- 权重矩阵形状为[50, 16],每个元素从标准正态分布N(0,1)中采样,weight[i, j] ~ N(0, 1) for all i in [0, 49] and j in [0, 15]

优缺点分析

  • 优点:无哈希冲突,每个特征值有独立表示,易于解释
  • 缺点:对于高基数特征内存消耗巨大,无法处理训练时未见过的特征值大于embedding table的情况,item_category例中,假设有一个item_category的数值为51,则运行时会报突破边界错,因为从embedding table中找不到第51行。

1.2 IdHashEmbedding方法

适用场景:适用于高基数类别特征,如用户ID、商品ID、骑手ID等具有大量唯一值的特征。

为什么需要IdHashEmbedding?
当面对高基数特征时,标准Embedding方法会遇到以下问题:

  1. 内存爆炸:如果用户ID有千万/亿级别,直接Embedding需要巨大的内存
  2. 未见特征值:遇到特征值大于embedding table行数时报错
  3. 训练效率低:大量参数需要更新,训练速度慢

工作原理

  • 通过哈希函数将原始特征值映射到固定大小的桶中
  • 使用哈希后的值作为索引从Embedding表中查找向量
  • 内存使用固定,不受原始特征基数影响

Example

class IdHashEmbedding(nn.Module):...def _build_embeddings(self):for feat_name, config in self.feature_config.items():self.embeddings[feat_name] = nn.Embedding(num_embeddings=config['hash_bucket_size'],embedding_dim=config['embedding_dimension'])def forward(self, features):embed_list = []for feat_name, config in self.feature_config.items():id_x = features[feat_name]hash_x = id_x.cpu().apply_(lambda x: hash_bucket(x, config['hash_bucket_size'])).long().to(self.device)     embed_list.append(self.embeddings[feat_name](hash_x))embeds = torch.cat(embed_list, dim=1)return embeds# 用户ID特征,原始ID可能达到千万级别
user_ids = torch.tensor([123456, 789012])# 使用IdHashEmbedding压缩到1000个桶中
id_hash_embedding = IdHashEmbedding({'user_id': {'hash_bucket_size': 1000, 'embedding_dimension': 32}
}, device)# 计算Id特征对应的IdHash值 
hash_bucket(123456, 1000) = 123456 % 1000 = 456, 
hash_bucket(789012, 1000) = 789012 % 1000 = 12
# 实际查找的是索引456和12,而非原始ID
embedded = id_hash_embedding({'user_id': user_ids})

优缺点分析

  • 优点:内存高效,可以处理任意基数的特征,能处理未见过的特征值
  • 缺点:可能产生哈希冲突,不同特征值可能映射到同一向量,比如对用户223456,其IdHash值也是456,和用户ID123456对应的Id Hash值相同,在模型训练时会当作相同的特征

2. 数值型特征的Embedding处理

2.1 RawEmbedding方法

适用场景:适用于连续数值特征,如用户年龄、注册时长、订单金额等。

工作原理

  • 使用线性变换将原始数值映射到目标维度
  • 相比查找表方式,更适合处理连续值

Example

class RawEmbedding(nn.Module):def __init__(self, input_dim, output_dim):super().__init__()self.input_dim = input_dimself.linear = nn.Linear(input_dim, output_dim)def forward(self, x):return self.linear(x)# 用户年龄特征
user_ages = torch.tensor([[25.0], [35.0]])  # shape: [2, 1]# 使用RawEmbedding将1维年龄映射到8维向量
raw_embedding = RawEmbedding(input_dim=1, output_dim=8)
embedded = raw_embedding(user_ages)  # shape: [2, 8]print(f"输入形状: {user_ages.shape}")   # [2, 1]
print(f"输出形状: {embedded.shape}")     # [2, 8]

优缺点分析

  • 优点:适合处理连续特征,计算简单高效,参数量少
  • 缺点:不适合处理未编码的类别特征,表达能力有限

三种Embedding方法对比总结

特征类型Embedding方法适用场景内存消耗处理未见特征哈希冲突表达能力计算复杂度
ID类特征Embedding中低基数类别特征(商品类别、服务类型等)
ID类特征IdHashEmbedding高基数类别特征(用户ID、骑手ID等)
数值特征RawEmbedding连续数值特征(年龄、服务时长等)

实践建议

在骑手等级转换预测项目中,我们建议:

  1. 高基数ID特征(如骑手ID、用户ID)使用IdHashEmbedding
  2. 中低基数ID特征(如服务类型、地区编码)使用Embedding
  3. 连续数值特征(如年龄、注册时长)使用RawEmbedding

合理选择和使用Embedding方法,可以显著提升模型性能,同时控制内存消耗。

总结

Embedding技术是深度学习模型中不可或缺的组成部分。通过合理选择IdHashEmbedding、Embedding和RawEmbedding,我们可以有效处理不同类型的特征,构建高性能的预测模型。理解每种方法的原理和适用场景,对于构建成功的机器学习系统具有重要意义。

在实际应用中,我们需要根据特征的类型、基数大小、内存限制以及业务需求来选择合适的Embedding方法,这样才能在模型效果和计算效率之间取得最佳平衡。

http://www.dtcms.com/a/312476.html

相关文章:

  • OSPF知识点整理
  • [Oracle] 获取系统当前日期
  • ABP VNext + Quartz.NET vs Hangfire:灵活调度与任务管理
  • 35.【.NET8 实战--孢子记账--从单体到微服务--转向微服务】--数据缓存
  • Petalinux 23.2 构建过程中常见下载错误及解决方法总结
  • 【从零开始学习Redis】初识Redis
  • Android 之 常用布局
  • OpenWrt | 如何在 ucode 脚本中打印日志
  • 评测PHOCR中文文本识别模型
  • MySQL半同步复制机制详解:AFTER_SYNC vs AFTER_COMMIT 的优劣与选择
  • Python 程序设计讲义(57):Python 的函数——可变参数的使用
  • 专网内网IP攻击防御:从应急响应到架构加固
  • 老电脑PE下无法读取硬盘的原因
  • 【LeetCode刷题指南】--二叉树的后序遍历,二叉树遍历
  • 7.14.散列表的基本概念(散列表又名哈希表,Hash Table)
  • 01.Redis 概述
  • 嵌入式通信协议解析(基于红外NEC通信协议)
  • 旧笔记本电脑如何安装飞牛OS
  • 前端工程化:npmvite
  • 解剖 .NET 经典:从 Component 到 BackgroundWorker
  • python基础语法6,简单文件操作(简单易上手的python语法教学)(课后习题)
  • Jetpack Compose for XR:构建下一代空间UI的完整指南
  • Hyper-V + Centos stream 9 搭建K8s集群(二)
  • MySQL 索引失效的场景与原因
  • k8s+isulad 国产化技术栈云原生技术栈搭建2-crictl
  • Linux进程启动后,监听端口几分钟后消失之问题分析
  • MySQL 事务原理 + ACID笔记
  • HiveMQ核心架构思维导图2024.9(Community Edition)
  • Educational Codeforces Round 171 (Rated for Div. 2)
  • 06.Redis 配置文件说明