
2025 Tencent Advertising Algorithm Competition: Baseline Project Walkthrough

Project Overview

The baseline for the 2025 Tencent Advertising Algorithm Competition is a simple sequential recommendation system. It models user-item interaction sequences and uses multimodal features (text and image embeddings, among others) to improve recommendation quality.

Core Files

1. main.py - main training script

  • Drives the overall model-training workflow
  • Handles argument parsing, data loading, model initialization, and the training loop
  • Supports resuming from checkpoints and an inference-only mode
  • Logs training metrics to TensorBoard
main.py code
import argparse
import json
import os
import time
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from dataset import MyDataset
from model import BaselineModel


def get_args():
    parser = argparse.ArgumentParser()

    # Train params
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--lr', default=0.001, type=float)
    parser.add_argument('--maxlen', default=101, type=int)

    # Baseline Model construction
    parser.add_argument('--hidden_units', default=32, type=int)
    parser.add_argument('--num_blocks', default=1, type=int)
    parser.add_argument('--num_epochs', default=3, type=int)
    parser.add_argument('--num_heads', default=1, type=int)
    parser.add_argument('--dropout_rate', default=0.2, type=float)
    parser.add_argument('--l2_emb', default=0.0, type=float)
    parser.add_argument('--device', default='cuda', type=str)
    parser.add_argument('--inference_only', action='store_true')
    parser.add_argument('--state_dict_path', default=None, type=str)
    parser.add_argument('--norm_first', action='store_true')

    # MMemb Feature ID
    parser.add_argument('--mm_emb_id', nargs='+', default=['81'], type=str, choices=[str(s) for s in range(81, 87)])

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    Path(os.environ.get('TRAIN_LOG_PATH')).mkdir(parents=True, exist_ok=True)
    Path(os.environ.get('TRAIN_TF_EVENTS_PATH')).mkdir(parents=True, exist_ok=True)
    log_file = open(Path(os.environ.get('TRAIN_LOG_PATH'), 'train.log'), 'w')
    writer = SummaryWriter(os.environ.get('TRAIN_TF_EVENTS_PATH'))
    # global dataset
    data_path = os.environ.get('TRAIN_DATA_PATH')

    args = get_args()
    dataset = MyDataset(data_path, args)
    train_dataset, valid_dataset = torch.utils.data.random_split(dataset, [0.9, 0.1])
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0, collate_fn=dataset.collate_fn
    )
    valid_loader = DataLoader(
        valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0, collate_fn=dataset.collate_fn
    )
    usernum, itemnum = dataset.usernum, dataset.itemnum
    feat_statistics, feat_types = dataset.feat_statistics, dataset.feature_types

    model = BaselineModel(usernum, itemnum, feat_statistics, feat_types, args).to(args.device)

    for name, param in model.named_parameters():
        try:
            torch.nn.init.xavier_normal_(param.data)
        except Exception:
            pass

    model.pos_emb.weight.data[0, :] = 0
    model.item_emb.weight.data[0, :] = 0
    model.user_emb.weight.data[0, :] = 0

    for k in model.sparse_emb:
        model.sparse_emb[k].weight.data[0, :] = 0

    epoch_start_idx = 1

    if args.state_dict_path is not None:
        try:
            model.load_state_dict(torch.load(args.state_dict_path, map_location=torch.device(args.device)))
            tail = args.state_dict_path[args.state_dict_path.find('epoch=') + 6 :]
            epoch_start_idx = int(tail[: tail.find('.')]) + 1
        except Exception:
            print('failed loading state_dicts, pls check file path: ', end="")
            print(args.state_dict_path)
            raise RuntimeError('failed loading state_dicts, pls check file path!')

    bce_criterion = torch.nn.BCEWithLogitsLoss(reduction='mean')
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98))

    best_val_ndcg, best_val_hr = 0.0, 0.0
    best_test_ndcg, best_test_hr = 0.0, 0.0
    T = 0.0
    t0 = time.time()
    global_step = 0
    print("Start training")
    for epoch in range(epoch_start_idx, args.num_epochs + 1):
        model.train()
        if args.inference_only:
            break
        for step, batch in tqdm(enumerate(train_loader), total=len(train_loader)):
            seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat = batch
            seq = seq.to(args.device)
            pos = pos.to(args.device)
            neg = neg.to(args.device)
            pos_logits, neg_logits = model(
                seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat
            )
            pos_labels = torch.ones(pos_logits.shape, device=args.device)
            neg_labels = torch.zeros(neg_logits.shape, device=args.device)
            optimizer.zero_grad()
            # only positions whose next token is an item contribute to the loss
            indices = np.where(next_token_type == 1)
            loss = bce_criterion(pos_logits[indices], pos_labels[indices])
            loss += bce_criterion(neg_logits[indices], neg_labels[indices])

            log_json = json.dumps(
                {'global_step': global_step, 'loss': loss.item(), 'epoch': epoch, 'time': time.time()}
            )
            log_file.write(log_json + '\n')
            log_file.flush()
            print(log_json)

            writer.add_scalar('Loss/train', loss.item(), global_step)
            global_step += 1

            for param in model.item_emb.parameters():
                loss += args.l2_emb * torch.norm(param)

            loss.backward()
            optimizer.step()

        model.eval()
        valid_loss_sum = 0
        for step, batch in tqdm(enumerate(valid_loader), total=len(valid_loader)):
            seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat = batch
            seq = seq.to(args.device)
            pos = pos.to(args.device)
            neg = neg.to(args.device)
            pos_logits, neg_logits = model(
                seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat
            )
            pos_labels = torch.ones(pos_logits.shape, device=args.device)
            neg_labels = torch.zeros(neg_logits.shape, device=args.device)
            indices = np.where(next_token_type == 1)
            loss = bce_criterion(pos_logits[indices], pos_labels[indices])
            loss += bce_criterion(neg_logits[indices], neg_labels[indices])
            valid_loss_sum += loss.item()
        valid_loss_sum /= len(valid_loader)
        writer.add_scalar('Loss/valid', valid_loss_sum, global_step)

        save_dir = Path(os.environ.get('TRAIN_CKPT_PATH'), f"global_step{global_step}.valid_loss={valid_loss_sum:.4f}")
        save_dir.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), save_dir / "model.pt")

    print("Done")
    writer.close()
    log_file.close()

2. model.py - core model implementation

BaselineModel - the main recommendation model

A Transformer-based sequential recommendation model with the following characteristics:

Model architecture

  • Uses FlashMultiHeadAttention for an efficient multi-head attention mechanism
  • Uses PointWiseFeedForward as the feed-forward network
  • Supports several feature types: sparse features, array features, continual features, and multimodal embedding features
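The fallback path inside FlashMultiHeadAttention (used when `F.scaled_dot_product_attention` is unavailable) reduces to masked scaled dot-product attention. A minimal single-head numpy sketch of that computation, with illustrative names that are not part of the baseline:

```python
import numpy as np

def causal_attention(q, k, v):
    """Single-head scaled dot-product attention with a causal (lower-triangular) mask."""
    seq_len, head_dim = q.shape
    scores = q @ k.T / np.sqrt(head_dim)            # [seq_len, seq_len]
    mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    scores = np.where(mask, scores, -np.inf)        # future positions get -inf
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ v                              # [seq_len, head_dim]

rng = np.random.default_rng(0)
q = k = v = rng.standard_normal((4, 8))
out = causal_attention(q, k, v)
# Position 0 can only attend to itself, so its output equals v[0]
assert np.allclose(out[0], v[0])
```

The model's real implementation adds a batch dimension, multiple heads, dropout, and a padding mask on top of this core.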

Feature processing

  • User features: sparse features (103, 104, 105, 109) and array features (106, 107, 108, 110)
  • Item features: sparse features (100, 117, 111, etc.) and multimodal embedding features (81-86)
  • The feat2emb method converts the different feature types into a unified embedding representation
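At its core, feat2emb looks up one embedding per feature, concatenates everything along the feature axis, and projects back to hidden_units through itemdnn/userdnn with a ReLU. A simplified numpy sketch of that fusion, with made-up tables and dimensions for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
hidden = 4

# Hypothetical embedding tables: an item-ID table plus one sparse feature table (row 0 = padding)
item_table = rng.standard_normal((10, hidden))
feat_table = rng.standard_normal((5, hidden))

item_ids = np.array([3, 7, 0])  # 0 is padding
feat_ids = np.array([1, 4, 0])

# Look up and concatenate along the feature axis -> [seq_len, 2 * hidden]
concat = np.concatenate([item_table[item_ids], feat_table[feat_ids]], axis=-1)

# "itemdnn": a single linear projection back to hidden_units, followed by ReLU
W = rng.standard_normal((2 * hidden, hidden))
fused = np.maximum(concat @ W, 0.0)

assert concat.shape == (3, 2 * hidden)
assert fused.shape == (3, hidden)
```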

Core methods

  • log2feats: converts a user sequence into feature representations
  • forward: computes the positive- and negative-sample logits during training
  • predict: produces the user representation at inference time
  • save_item_emb: saves item embeddings for retrieval
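The scoring inside forward is a per-position dot product between the sequence representation and the candidate embedding, masked to item positions. In numpy terms (shapes are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
batch, maxlen, hidden = 2, 5, 8

log_feats = rng.standard_normal((batch, maxlen, hidden))  # stand-in for log2feats output
pos_embs = rng.standard_normal((batch, maxlen, hidden))   # stand-in for next-item embeddings

# Per-position dot product -> one logit per sequence position
pos_logits = (log_feats * pos_embs).sum(axis=-1)

# Positions whose next token is not an item are masked out of the loss
next_mask = np.array([[0, 1, 1, 0, 1], [1, 1, 0, 1, 0]])
pos_logits = pos_logits * (next_mask == 1)

assert pos_logits.shape == (batch, maxlen)
assert pos_logits[0, 0] == 0.0  # masked position contributes nothing
```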
model.py code
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

from dataset import save_emb


class FlashMultiHeadAttention(torch.nn.Module):
    def __init__(self, hidden_units, num_heads, dropout_rate):
        super(FlashMultiHeadAttention, self).__init__()

        self.hidden_units = hidden_units
        self.num_heads = num_heads
        self.head_dim = hidden_units // num_heads
        self.dropout_rate = dropout_rate

        assert hidden_units % num_heads == 0, "hidden_units must be divisible by num_heads"

        self.q_linear = torch.nn.Linear(hidden_units, hidden_units)
        self.k_linear = torch.nn.Linear(hidden_units, hidden_units)
        self.v_linear = torch.nn.Linear(hidden_units, hidden_units)
        self.out_linear = torch.nn.Linear(hidden_units, hidden_units)

    def forward(self, query, key, value, attn_mask=None):
        batch_size, seq_len, _ = query.size()

        # Compute Q, K, V
        Q = self.q_linear(query)
        K = self.k_linear(key)
        V = self.v_linear(value)

        # Reshape to multi-head format
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        if hasattr(F, 'scaled_dot_product_attention'):
            # PyTorch 2.0+: use the built-in Flash Attention
            attn_output = F.scaled_dot_product_attention(
                Q, K, V, dropout_p=self.dropout_rate if self.training else 0.0, attn_mask=attn_mask.unsqueeze(1)
            )
        else:
            # Fall back to standard attention
            scale = (self.head_dim) ** -0.5
            scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
            if attn_mask is not None:
                scores.masked_fill_(attn_mask.unsqueeze(1).logical_not(), float('-inf'))
            attn_weights = F.softmax(scores, dim=-1)
            attn_weights = F.dropout(attn_weights, p=self.dropout_rate, training=self.training)
            attn_output = torch.matmul(attn_weights, V)

        # Reshape back to the original format
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_units)

        # Final linear projection
        output = self.out_linear(attn_output)

        return output, None


class PointWiseFeedForward(torch.nn.Module):
    def __init__(self, hidden_units, dropout_rate):
        super(PointWiseFeedForward, self).__init__()

        self.conv1 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout1 = torch.nn.Dropout(p=dropout_rate)
        self.relu = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout2 = torch.nn.Dropout(p=dropout_rate)

    def forward(self, inputs):
        # Conv1d requires (N, C, Length), hence the transposes
        outputs = self.dropout2(self.conv2(self.relu(self.dropout1(self.conv1(inputs.transpose(-1, -2))))))
        outputs = outputs.transpose(-1, -2)
        return outputs


class BaselineModel(torch.nn.Module):
    """
    Args:
        user_num: number of users
        item_num: number of items
        feat_statistics: feature statistics; key is the feature ID, value is the feature cardinality
        feat_types: feature type of each feature; key is the type name, value is the list of feature
            IDs it contains, covering the user and item sparse, array, emb, and continual types
        args: global arguments

    Attributes:
        user_num: number of users
        item_num: number of items
        dev: device
        norm_first: whether to normalize before attention (pre-norm)
        maxlen: maximum sequence length
        item_emb: item embedding table
        user_emb: user embedding table
        sparse_emb: sparse-feature embedding tables
        emb_transform: linear transforms for the multimodal features
        userdnn: fully connected layer applied to the concatenated user features
        itemdnn: fully connected layer applied to the concatenated item features
    """

    def __init__(self, user_num, item_num, feat_statistics, feat_types, args):
        super(BaselineModel, self).__init__()

        self.user_num = user_num
        self.item_num = item_num
        self.dev = args.device
        self.norm_first = args.norm_first
        self.maxlen = args.maxlen
        # TODO: loss += args.l2_emb for regularizing embedding vectors during training
        # https://stackoverflow.com/questions/42704283/adding-l1-l2-regularization-in-pytorch

        self.item_emb = torch.nn.Embedding(self.item_num + 1, args.hidden_units, padding_idx=0)
        self.user_emb = torch.nn.Embedding(self.user_num + 1, args.hidden_units, padding_idx=0)
        self.pos_emb = torch.nn.Embedding(2 * args.maxlen + 1, args.hidden_units, padding_idx=0)
        self.emb_dropout = torch.nn.Dropout(p=args.dropout_rate)
        self.sparse_emb = torch.nn.ModuleDict()
        self.emb_transform = torch.nn.ModuleDict()

        self.attention_layernorms = torch.nn.ModuleList()  # to be Q for self-attention
        self.attention_layers = torch.nn.ModuleList()
        self.forward_layernorms = torch.nn.ModuleList()
        self.forward_layers = torch.nn.ModuleList()

        self._init_feat_info(feat_statistics, feat_types)

        userdim = args.hidden_units * (len(self.USER_SPARSE_FEAT) + 1 + len(self.USER_ARRAY_FEAT)) + len(
            self.USER_CONTINUAL_FEAT
        )
        itemdim = (
            args.hidden_units * (len(self.ITEM_SPARSE_FEAT) + 1 + len(self.ITEM_ARRAY_FEAT))
            + len(self.ITEM_CONTINUAL_FEAT)
            + args.hidden_units * len(self.ITEM_EMB_FEAT)
        )

        self.userdnn = torch.nn.Linear(userdim, args.hidden_units)
        self.itemdnn = torch.nn.Linear(itemdim, args.hidden_units)

        self.last_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8)

        for _ in range(args.num_blocks):
            new_attn_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8)
            self.attention_layernorms.append(new_attn_layernorm)

            # Optimization: FlashAttention instead of standard attention
            new_attn_layer = FlashMultiHeadAttention(args.hidden_units, args.num_heads, args.dropout_rate)
            self.attention_layers.append(new_attn_layer)

            new_fwd_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8)
            self.forward_layernorms.append(new_fwd_layernorm)

            new_fwd_layer = PointWiseFeedForward(args.hidden_units, args.dropout_rate)
            self.forward_layers.append(new_fwd_layer)

        for k in self.USER_SPARSE_FEAT:
            self.sparse_emb[k] = torch.nn.Embedding(self.USER_SPARSE_FEAT[k] + 1, args.hidden_units, padding_idx=0)
        for k in self.ITEM_SPARSE_FEAT:
            self.sparse_emb[k] = torch.nn.Embedding(self.ITEM_SPARSE_FEAT[k] + 1, args.hidden_units, padding_idx=0)
        for k in self.ITEM_ARRAY_FEAT:
            self.sparse_emb[k] = torch.nn.Embedding(self.ITEM_ARRAY_FEAT[k] + 1, args.hidden_units, padding_idx=0)
        for k in self.USER_ARRAY_FEAT:
            self.sparse_emb[k] = torch.nn.Embedding(self.USER_ARRAY_FEAT[k] + 1, args.hidden_units, padding_idx=0)
        for k in self.ITEM_EMB_FEAT:
            self.emb_transform[k] = torch.nn.Linear(self.ITEM_EMB_FEAT[k], args.hidden_units)

    def _init_feat_info(self, feat_statistics, feat_types):
        """
        Group the feature statistics (cardinalities) by feature type into separate dicts, which makes
        it easier to declare the embedding tables for sparse features.

        Args:
            feat_statistics: feature statistics; key is the feature ID, value is the feature cardinality
            feat_types: feature type of each feature; key is the type name, value is the list of
                feature IDs it contains, covering the user and item sparse, array, emb, and continual types
        """
        self.USER_SPARSE_FEAT = {k: feat_statistics[k] for k in feat_types['user_sparse']}
        self.USER_CONTINUAL_FEAT = feat_types['user_continual']
        self.ITEM_SPARSE_FEAT = {k: feat_statistics[k] for k in feat_types['item_sparse']}
        self.ITEM_CONTINUAL_FEAT = feat_types['item_continual']
        self.USER_ARRAY_FEAT = {k: feat_statistics[k] for k in feat_types['user_array']}
        self.ITEM_ARRAY_FEAT = {k: feat_statistics[k] for k in feat_types['item_array']}
        EMB_SHAPE_DICT = {"81": 32, "82": 1024, "83": 3584, "84": 4096, "85": 3584, "86": 3584}
        # records the dimension of each multimodal feature
        self.ITEM_EMB_FEAT = {k: EMB_SHAPE_DICT[k] for k in feat_types['item_emb']}

    def feat2tensor(self, seq_feature, k):
        """
        Args:
            seq_feature: list of sequence features; each element is the feature dict at that step,
                shape [batch_size, maxlen]
            k: feature ID

        Returns:
            batch_data: tensor of feature values, shape [batch_size, maxlen, max_array_len (if array)]
        """
        batch_size = len(seq_feature)

        if k in self.ITEM_ARRAY_FEAT or k in self.USER_ARRAY_FEAT:
            # Array-type features need padding before converting to a tensor
            max_array_len = 0
            max_seq_len = 0

            for i in range(batch_size):
                seq_data = [item[k] for item in seq_feature[i]]
                max_seq_len = max(max_seq_len, len(seq_data))
                max_array_len = max(max_array_len, max(len(item_data) for item_data in seq_data))

            batch_data = np.zeros((batch_size, max_seq_len, max_array_len), dtype=np.int64)
            for i in range(batch_size):
                seq_data = [item[k] for item in seq_feature[i]]
                for j, item_data in enumerate(seq_data):
                    actual_len = min(len(item_data), max_array_len)
                    batch_data[i, j, :actual_len] = item_data[:actual_len]

            return torch.from_numpy(batch_data).to(self.dev)
        else:
            # Sparse-type features convert to a tensor directly
            max_seq_len = max(len(seq_feature[i]) for i in range(batch_size))
            batch_data = np.zeros((batch_size, max_seq_len), dtype=np.int64)

            for i in range(batch_size):
                seq_data = [item[k] for item in seq_feature[i]]
                batch_data[i] = seq_data

            return torch.from_numpy(batch_data).to(self.dev)

    def feat2emb(self, seq, feature_array, mask=None, include_user=False):
        """
        Args:
            seq: sequence IDs
            feature_array: list of features; each element is the feature dict at that step
            mask: mask, 1 for item, 2 for user
            include_user: whether to process user features; kept off in two cases:
                1) when converting the features of positive/negative samples during training
                   (both are items); 2) when generating the candidate item embeddings.

        Returns:
            seqs_emb: embedding of the sequence features
        """
        seq = seq.to(self.dev)
        # Pre-compute embeddings
        if include_user:
            user_mask = (mask == 2).to(self.dev)
            item_mask = (mask == 1).to(self.dev)
            user_embedding = self.user_emb(user_mask * seq)
            item_embedding = self.item_emb(item_mask * seq)
            item_feat_list = [item_embedding]
            user_feat_list = [user_embedding]
        else:
            item_embedding = self.item_emb(seq)
            item_feat_list = [item_embedding]

        # Batch-process all feature types
        all_feat_types = [
            (self.ITEM_SPARSE_FEAT, 'item_sparse', item_feat_list),
            (self.ITEM_ARRAY_FEAT, 'item_array', item_feat_list),
            (self.ITEM_CONTINUAL_FEAT, 'item_continual', item_feat_list),
        ]

        if include_user:
            all_feat_types.extend(
                [
                    (self.USER_SPARSE_FEAT, 'user_sparse', user_feat_list),
                    (self.USER_ARRAY_FEAT, 'user_array', user_feat_list),
                    (self.USER_CONTINUAL_FEAT, 'user_continual', user_feat_list),
                ]
            )

        # Batch-process each feature type
        for feat_dict, feat_type, feat_list in all_feat_types:
            if not feat_dict:
                continue
            for k in feat_dict:
                tensor_feature = self.feat2tensor(feature_array, k)
                if feat_type.endswith('sparse'):
                    feat_list.append(self.sparse_emb[k](tensor_feature))
                elif feat_type.endswith('array'):
                    feat_list.append(self.sparse_emb[k](tensor_feature).sum(2))
                elif feat_type.endswith('continual'):
                    feat_list.append(tensor_feature.unsqueeze(2))

        for k in self.ITEM_EMB_FEAT:
            # Collect all data into numpy first, then batch-convert
            batch_size = len(feature_array)
            emb_dim = self.ITEM_EMB_FEAT[k]
            seq_len = len(feature_array[0])

            # Pre-allocate the buffer
            batch_emb_data = np.zeros((batch_size, seq_len, emb_dim), dtype=np.float32)

            for i, seq in enumerate(feature_array):
                for j, item in enumerate(seq):
                    if k in item:
                        batch_emb_data[i, j] = item[k]

            # Batch-convert and transfer to the GPU
            tensor_feature = torch.from_numpy(batch_emb_data).to(self.dev)
            item_feat_list.append(self.emb_transform[k](tensor_feature))

        # Merge features
        all_item_emb = torch.cat(item_feat_list, dim=2)
        all_item_emb = torch.relu(self.itemdnn(all_item_emb))
        if include_user:
            all_user_emb = torch.cat(user_feat_list, dim=2)
            all_user_emb = torch.relu(self.userdnn(all_user_emb))
            seqs_emb = all_item_emb + all_user_emb
        else:
            seqs_emb = all_item_emb
        return seqs_emb

    def log2feats(self, log_seqs, mask, seq_feature):
        """
        Args:
            log_seqs: sequence IDs
            mask: token-type mask, 1 for item tokens, 2 for user tokens
            seq_feature: list of sequence features; each element is the feature dict at that step

        Returns:
            seqs_emb: embedding of the sequence, shape [batch_size, maxlen, hidden_units]
        """
        batch_size = log_seqs.shape[0]
        maxlen = log_seqs.shape[1]
        seqs = self.feat2emb(log_seqs, seq_feature, mask=mask, include_user=True)
        seqs *= self.item_emb.embedding_dim**0.5
        poss = torch.arange(1, maxlen + 1, device=self.dev).unsqueeze(0).expand(batch_size, -1).clone()
        poss *= log_seqs != 0
        seqs += self.pos_emb(poss)
        seqs = self.emb_dropout(seqs)

        maxlen = seqs.shape[1]
        ones_matrix = torch.ones((maxlen, maxlen), dtype=torch.bool, device=self.dev)
        attention_mask_tril = torch.tril(ones_matrix)
        attention_mask_pad = (mask != 0).to(self.dev)
        attention_mask = attention_mask_tril.unsqueeze(0) & attention_mask_pad.unsqueeze(1)

        for i in range(len(self.attention_layers)):
            if self.norm_first:
                x = self.attention_layernorms[i](seqs)
                mha_outputs, _ = self.attention_layers[i](x, x, x, attn_mask=attention_mask)
                seqs = seqs + mha_outputs
                seqs = seqs + self.forward_layers[i](self.forward_layernorms[i](seqs))
            else:
                mha_outputs, _ = self.attention_layers[i](seqs, seqs, seqs, attn_mask=attention_mask)
                seqs = self.attention_layernorms[i](seqs + mha_outputs)
                seqs = self.forward_layernorms[i](seqs + self.forward_layers[i](seqs))

        log_feats = self.last_layernorm(seqs)

        return log_feats

    def forward(
        self, user_item, pos_seqs, neg_seqs, mask, next_mask, next_action_type, seq_feature, pos_feature, neg_feature
    ):
        """
        Called during training; computes the logits of positive and negative samples.

        Args:
            user_item: user sequence IDs
            pos_seqs: positive-sample sequence IDs
            neg_seqs: negative-sample sequence IDs
            mask: token-type mask, 1 for item tokens, 2 for user tokens
            next_mask: next-token type mask, 1 for item tokens, 2 for user tokens
            next_action_type: next-token action type, 0 for impression, 1 for click
            seq_feature: list of sequence features; each element is the feature dict at that step
            pos_feature: list of positive-sample features; each element is the feature dict at that step
            neg_feature: list of negative-sample features; each element is the feature dict at that step

        Returns:
            pos_logits: positive-sample logits, shape [batch_size, maxlen]
            neg_logits: negative-sample logits, shape [batch_size, maxlen]
        """
        log_feats = self.log2feats(user_item, mask, seq_feature)
        loss_mask = (next_mask == 1).to(self.dev)

        pos_embs = self.feat2emb(pos_seqs, pos_feature, include_user=False)
        neg_embs = self.feat2emb(neg_seqs, neg_feature, include_user=False)

        pos_logits = (log_feats * pos_embs).sum(dim=-1)
        neg_logits = (log_feats * neg_embs).sum(dim=-1)
        pos_logits = pos_logits * loss_mask
        neg_logits = neg_logits * loss_mask

        return pos_logits, neg_logits

    def predict(self, log_seqs, seq_feature, mask):
        """
        Computes the representation of a user sequence at inference time.

        Args:
            log_seqs: user sequence IDs
            seq_feature: list of sequence features; each element is the feature dict at that step
            mask: token-type mask, 1 for item tokens, 2 for user tokens

        Returns:
            final_feat: representation of the user sequence, shape [batch_size, hidden_units]
        """
        log_feats = self.log2feats(log_seqs, mask, seq_feature)

        final_feat = log_feats[:, -1, :]

        return final_feat

    def save_item_emb(self, item_ids, retrieval_ids, feat_dict, save_path, batch_size=1024):
        """
        Generates the candidate item embeddings used for retrieval.

        Args:
            item_ids: candidate item IDs (re-id form)
            retrieval_ids: candidate item IDs (retrieval IDs, numbered from 0, used by the retrieval script)
            feat_dict: feature dict of all items in the training set; key is the feature ID, value is the value
            save_path: save path
            batch_size: batch size
        """
        all_embs = []

        for start_idx in tqdm(range(0, len(item_ids), batch_size), desc="Saving item embeddings"):
            end_idx = min(start_idx + batch_size, len(item_ids))

            item_seq = torch.tensor(item_ids[start_idx:end_idx], device=self.dev).unsqueeze(0)
            batch_feat = []
            for i in range(start_idx, end_idx):
                batch_feat.append(feat_dict[i])

            batch_feat = np.array(batch_feat, dtype=object)

            batch_emb = self.feat2emb(item_seq, [batch_feat], include_user=False).squeeze(0)

            all_embs.append(batch_emb.detach().cpu().numpy().astype(np.float32))

        # Concatenate the results of all batches and save
        final_ids = np.array(retrieval_ids, dtype=np.uint64).reshape(-1, 1)
        final_embs = np.concatenate(all_embs, axis=0)
        save_emb(final_embs, Path(save_path, 'embedding.fbin'))
        save_emb(final_ids, Path(save_path, 'id.u64bin'))

3. dataset.py - data processing

MyDataset - training dataset
  • Handles user-behavior sequences in which user and item tokens alternate
  • Loads data efficiently, using precomputed file offsets for random access
  • Supports padding and default-value filling for every feature type
  • Implements negative sampling for training
MyTestDataset - test dataset
  • Inherits from the training dataset and is used only at inference time
  • Handles the cold-start problem (feature values unseen during training)
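The offset trick behind _load_data_and_offsets/_load_user_data needs nothing beyond the standard library: record each line's starting byte once, then seek straight to any user's line instead of scanning the whole file. A self-contained sketch (the file name and records here are made up):

```python
import json
import os
import tempfile

# Write a small JSONL file standing in for seq.jsonl
records = [{"uid": i, "seq": list(range(i, i + 3))} for i in range(5)]
path = os.path.join(tempfile.mkdtemp(), "seq.jsonl")
with open(path, "wb") as f:
    offsets = []
    for rec in records:
        offsets.append(f.tell())  # byte offset where this line starts
        f.write((json.dumps(rec) + "\n").encode())

# Random access: seek to the stored offset and read a single line
with open(path, "rb") as f:
    f.seek(offsets[3])
    row = json.loads(f.readline())

assert row == {"uid": 3, "seq": [3, 4, 5]}
```

The baseline precomputes these offsets once and pickles them (seq_offsets.pkl), so every __getitem__ is a single seek plus one readline.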
dataset.py code
import json
import pickle
import struct
from pathlib import Path

import numpy as np
import torch
from tqdm import tqdm


class MyDataset(torch.utils.data.Dataset):
    """
    User-sequence dataset.

    Args:
        data_dir: data file directory
        args: global arguments

    Attributes:
        data_dir: data file directory
        maxlen: maximum sequence length
        item_feat_dict: item feature dict
        mm_emb_ids: IDs of the enabled mm_emb features
        mm_emb_dict: multimodal feature dict
        itemnum: number of items
        usernum: number of users
        indexer_i_rev: item index dict (reid -> item_id)
        indexer_u_rev: user index dict (reid -> user_id)
        indexer: index dict
        feature_default_value: feature default values
        feature_types: feature types, split into user and item sparse, array, emb, and continual types
        feat_statistics: feature statistics, i.e. the cardinality of the user and item features
    """

    def __init__(self, data_dir, args):
        """Initialize the dataset."""
        super().__init__()
        self.data_dir = Path(data_dir)
        self._load_data_and_offsets()
        self.maxlen = args.maxlen
        self.mm_emb_ids = args.mm_emb_id

        self.item_feat_dict = json.load(open(Path(data_dir, "item_feat_dict.json"), 'r'))
        self.mm_emb_dict = load_mm_emb(Path(data_dir, "creative_emb"), self.mm_emb_ids)
        with open(self.data_dir / 'indexer.pkl', 'rb') as ff:
            indexer = pickle.load(ff)
            self.itemnum = len(indexer['i'])
            self.usernum = len(indexer['u'])
        self.indexer_i_rev = {v: k for k, v in indexer['i'].items()}
        self.indexer_u_rev = {v: k for k, v in indexer['u'].items()}
        self.indexer = indexer

        self.feature_default_value, self.feature_types, self.feat_statistics = self._init_feat_info()

    def _load_data_and_offsets(self):
        """Load the user sequences plus the precomputed byte offset of each line, for fast random-access I/O."""
        self.data_file = open(self.data_dir / "seq.jsonl", 'rb')
        with open(Path(self.data_dir, 'seq_offsets.pkl'), 'rb') as f:
            self.seq_offsets = pickle.load(f)

    def _load_user_data(self, uid):
        """
        Load a single user's data from the data file.

        Args:
            uid: user ID (reid)

        Returns:
            data: user sequence data, in the form
                [(user_id, item_id, user_feat, item_feat, action_type, timestamp)]
        """
        self.data_file.seek(self.seq_offsets[uid])
        line = self.data_file.readline()
        data = json.loads(line)
        return data

    def _random_neq(self, l, r, s):
        """
        Draw a random integer not contained in sequence s; used for negative sampling during training.

        Args:
            l: minimum of the random integer
            r: maximum of the random integer
            s: the sequence to avoid

        Returns:
            t: a random integer not contained in s
        """
        t = np.random.randint(l, r)
        while t in s or str(t) not in self.item_feat_dict:
            t = np.random.randint(l, r)
        return t

    def __getitem__(self, uid):
        """
        Fetch a single user's data, apply padding, and produce the format the model expects.

        Args:
            uid: user ID (reid)

        Returns:
            seq: user sequence IDs
            pos: positive-sample IDs (the next item actually visited)
            neg: negative-sample IDs
            token_type: sequence token types, 1 for item, 2 for user
            next_token_type: next-token types, 1 for item, 2 for user
            seq_feat: sequence features; each element is a dict keyed by feature ID
            pos_feat: positive-sample features; each element is a dict keyed by feature ID
            neg_feat: negative-sample features; each element is a dict keyed by feature ID
        """
        user_sequence = self._load_user_data(uid)  # lazily load this user's data

        ext_user_sequence = []
        for record_tuple in user_sequence:
            u, i, user_feat, item_feat, action_type, _ = record_tuple
            if u and user_feat:
                ext_user_sequence.insert(0, (u, user_feat, 2, action_type))
            if i and item_feat:
                ext_user_sequence.append((i, item_feat, 1, action_type))

        seq = np.zeros([self.maxlen + 1], dtype=np.int32)
        pos = np.zeros([self.maxlen + 1], dtype=np.int32)
        neg = np.zeros([self.maxlen + 1], dtype=np.int32)
        token_type = np.zeros([self.maxlen + 1], dtype=np.int32)
        next_token_type = np.zeros([self.maxlen + 1], dtype=np.int32)
        next_action_type = np.zeros([self.maxlen + 1], dtype=np.int32)

        seq_feat = np.empty([self.maxlen + 1], dtype=object)
        pos_feat = np.empty([self.maxlen + 1], dtype=object)
        neg_feat = np.empty([self.maxlen + 1], dtype=object)

        nxt = ext_user_sequence[-1]
        idx = self.maxlen

        ts = set()
        for record_tuple in ext_user_sequence:
            if record_tuple[2] == 1 and record_tuple[0]:
                ts.add(record_tuple[0])

        # Left-padding: walk the user sequence back to front and fill it up to maxlen + 1
        for record_tuple in reversed(ext_user_sequence[:-1]):
            i, feat, type_, act_type = record_tuple
            next_i, next_feat, next_type, next_act_type = nxt
            feat = self.fill_missing_feat(feat, i)
            next_feat = self.fill_missing_feat(next_feat, next_i)
            seq[idx] = i
            token_type[idx] = type_
            next_token_type[idx] = next_type
            if next_act_type is not None:
                next_action_type[idx] = next_act_type
            seq_feat[idx] = feat
            if next_type == 1 and next_i != 0:
                pos[idx] = next_i
                pos_feat[idx] = next_feat
                neg_id = self._random_neq(1, self.itemnum + 1, ts)
                neg[idx] = neg_id
                neg_feat[idx] = self.fill_missing_feat(self.item_feat_dict[str(neg_id)], neg_id)
            nxt = record_tuple
            idx -= 1
            if idx == -1:
                break

        seq_feat = np.where(seq_feat == None, self.feature_default_value, seq_feat)
        pos_feat = np.where(pos_feat == None, self.feature_default_value, pos_feat)
        neg_feat = np.where(neg_feat == None, self.feature_default_value, neg_feat)

        return seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat

    def __len__(self):
        """
        Returns:
            usernum: number of users (the length of the dataset)
        """
        return len(self.seq_offsets)

    def _init_feat_info(self):
        """
        Initialize the feature info, including default values and feature types.

        Returns:
            feat_default_value: feature defaults; a dict keyed by feature ID
            feat_types: feature types; key is the type name, value is the list of feature IDs it contains
        """
        feat_default_value = {}
        feat_statistics = {}
        feat_types = {}
        feat_types['user_sparse'] = ['103', '104', '105', '109']
        feat_types['item_sparse'] = [
            '100', '117', '111', '118', '101', '102', '119',
            '120', '114', '112', '121', '115', '122', '116',
        ]
        feat_types['item_array'] = []
        feat_types['user_array'] = ['106', '107', '108', '110']
        feat_types['item_emb'] = self.mm_emb_ids
        feat_types['user_continual'] = []
        feat_types['item_continual'] = []

        for feat_id in feat_types['user_sparse']:
            feat_default_value[feat_id] = 0
            feat_statistics[feat_id] = len(self.indexer['f'][feat_id])
        for feat_id in feat_types['item_sparse']:
            feat_default_value[feat_id] = 0
            feat_statistics[feat_id] = len(self.indexer['f'][feat_id])
        for feat_id in feat_types['item_array']:
            feat_default_value[feat_id] = [0]
            feat_statistics[feat_id] = len(self.indexer['f'][feat_id])
        for feat_id in feat_types['user_array']:
            feat_default_value[feat_id] = [0]
            feat_statistics[feat_id] = len(self.indexer['f'][feat_id])
        for feat_id in feat_types['user_continual']:
            feat_default_value[feat_id] = 0
        for feat_id in feat_types['item_continual']:
            feat_default_value[feat_id] = 0
        for feat_id in feat_types['item_emb']:
            feat_default_value[feat_id] = np.zeros(
                list(self.mm_emb_dict[feat_id].values())[0].shape[0], dtype=np.float32
            )

        return feat_default_value, feat_types, feat_statistics

    def fill_missing_feat(self, feat, item_id):
        """
        Fill default values for features missing from the raw data.

        Args:
            feat: feature dict
            item_id: item ID

        Returns:
            filled_feat: filled feature dict
        """
        if feat is None:
            feat = {}
        filled_feat = {}
        for k in feat.keys():
            filled_feat[k] = feat[k]

        all_feat_ids = []
        for feat_type in self.feature_types.values():
            all_feat_ids.extend(feat_type)
        missing_fields = set(all_feat_ids) - set(feat.keys())
        for feat_id in missing_fields:
            filled_feat[feat_id] = self.feature_default_value[feat_id]
        for feat_id in self.feature_types['item_emb']:
            if item_id != 0 and self.indexer_i_rev[item_id] in self.mm_emb_dict[feat_id]:
                if type(self.mm_emb_dict[feat_id][self.indexer_i_rev[item_id]]) == np.ndarray:
                    filled_feat[feat_id] = self.mm_emb_dict[feat_id][self.indexer_i_rev[item_id]]

        return filled_feat

    @staticmethod
    def collate_fn(batch):
        """
        Args:
            batch: the data returned by multiple __getitem__ calls

        Returns:
            seq: user sequence IDs, as a torch.Tensor
            pos: positive-sample IDs, as a torch.Tensor
            neg: negative-sample IDs, as a torch.Tensor
            token_type: sequence token types, as a torch.Tensor
            next_token_type: next-token types, as a torch.Tensor
            seq_feat: sequence features, as a list
            pos_feat: positive-sample features, as a list
            neg_feat: negative-sample features, as a list
        """
        seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat = zip(*batch)
        seq = torch.from_numpy(np.array(seq))
        pos = torch.from_numpy(np.array(pos))
        neg = torch.from_numpy(np.array(neg))
        token_type = torch.from_numpy(np.array(token_type))
        next_token_type = torch.from_numpy(np.array(next_token_type))
        next_action_type = torch.from_numpy(np.array(next_action_type))
        seq_feat = list(seq_feat)
        pos_feat = list(pos_feat)
        neg_feat = list(neg_feat)

        return seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat


class MyTestDataset(MyDataset):
    """Test dataset."""

    def __init__(self, data_dir, args):
        super().__init__(data_dir, args)

    def _load_data_and_offsets(self):
        self.data_file = open(self.data_dir / "predict_seq.jsonl", 'rb')
        with open(Path(self.data_dir, 'predict_seq_offsets.pkl'), 'rb') as f:
            self.seq_offsets = pickle.load(f)

    def _process_cold_start_feat(self, feat):
        """
        Handle cold-start features. Feature values unseen during training arrive as strings and are
        mapped to 0 by default. A better scheme can be designed to replace this.
        """
        processed_feat = {}
        for feat_id, feat_value in feat.items():
            if type(feat_value) == list:
                value_list = []
                for v in feat_value:
                    if type(v) == str:
                        value_list.append(0)
                    else:
                        value_list.append(v)
                processed_feat[feat_id] = value_list
            elif type(feat_value) == str:
                processed_feat[feat_id] = 0
            else:
                processed_feat[feat_id] = feat_value
        return processed_feat

    def __getitem__(self, uid):
        """
        Fetch a single user's data, apply padding, and produce the format the model expects.

        Args:
            uid: the user's line number in self.data_file

        Returns:
            seq: user sequence IDs
            token_type: sequence token types, 1 for item, 2 for user
            seq_feat: sequence features; each element is a dict keyed by feature ID
            user_id: user_id, e.g. user_xxxxxx, for matching against the answers later
        """
        user_sequence = self._load_user_data(uid)  # lazily load this user's data

        ext_user_sequence = []
        for record_tuple in user_sequence:
            u, i, user_feat, item_feat, _, _ = record_tuple
            if u:
                if type(u) == str:  # a string means it is a user_id
                    user_id = u
                else:  # an int means it is a re_id
                    user_id = self.indexer_u_rev[u]
            if u and user_feat:
                if type(u) == str:
                    u = 0
                if user_feat:
                    user_feat = self._process_cold_start_feat(user_feat)
                ext_user_sequence.insert(0, (u, user_feat, 2))

            if i and item_feat:
                # Items unseen during training keep their creative_id instead of being zeroed out
                # directly; creative_id is far larger than the training-time itemnum
                if i > self.itemnum:
                    i = 0
                if item_feat:
                    item_feat = self._process_cold_start_feat(item_feat)
                ext_user_sequence.append((i, item_feat, 1))

        seq = np.zeros([self.maxlen + 1], dtype=np.int32)
        token_type = np.zeros([self.maxlen + 1], dtype=np.int32)
        seq_feat = np.empty([self.maxlen + 1], dtype=object)

        idx = self.maxlen

        ts = set()
        for record_tuple in ext_user_sequence:
            if record_tuple[2] == 1 and record_tuple[0]:
                ts.add(record_tuple[0])

        for record_tuple in reversed(ext_user_sequence[:-1]):
            i, feat, type_ = record_tuple
            feat = self.fill_missing_feat(feat, i)
            seq[idx] = i
            token_type[idx] = type_
            seq_feat[idx] = feat
            idx -= 1
            if idx == -1:
                break

        seq_feat = np.where(seq_feat == None, self.feature_default_value, seq_feat)

        return seq, token_type, seq_feat, user_id

    def __len__(self):
        """
        Returns:
            len(self.seq_offsets): number of users
        """
        with open(Path(self.data_dir, 'predict_seq_offsets.pkl'), 'rb') as f:
            temp = pickle.load(f)
        return len(temp)

    @staticmethod
    def collate_fn(batch):
        """
        Collate the data returned by multiple __getitem__ calls into one batch.

        Args:
            batch: the data returned by multiple __getitem__ calls

        Returns:
            seq: user sequence IDs, as a torch.Tensor
            token_type: sequence token types, as a torch.Tensor
            seq_feat: sequence features, as a list
            user_id: user_id, str
        """
        seq, token_type, seq_feat, user_id = zip(*batch)
        seq = torch.from_numpy(np.array(seq))
        token_type = torch.from_numpy(np.array(token_type))
        seq_feat = list(seq_feat)
        return seq, token_type, seq_feat, user_id


def save_emb(emb, save_path):
    """
    Save an embedding as a binary file.

    Args:
        emb: the embedding to save, shape [num_points, num_dimensions]
        save_path: save path
    """
    num_points = emb.shape[0]  # number of data points
    num_dimensions = emb.shape[1]  # dimension of the vectors
    print(f'saving {save_path}')
    with open(Path(save_path), 'wb') as f:
        f.write(struct.pack('II', num_points, num_dimensions))
        emb.tofile(f)


def load_mm_emb(mm_path, feat_ids):
    """
    Load the multimodal feature embeddings.

    Args:
        mm_path: path of the multimodal embeddings
        feat_ids: list of multimodal feature IDs to load

    Returns:
        mm_emb_dict: multimodal embedding dict; key is the feature ID, value is an embedding dict
            (keyed by item ID, value is the embedding)
    """
    SHAPE_DICT = {"81": 32, "82": 1024, "83": 3584, "84": 4096, "85": 3584, "86": 3584}
    mm_emb_dict = {}
    for feat_id in tqdm(feat_ids, desc='Loading mm_emb'):
        shape = SHAPE_DICT[feat_id]
        emb_dict = {}
        if feat_id != '81':
            try:
                base_path = Path(mm_path, f'emb_{feat_id}_{shape}')
                for json_file in base_path.glob('*.json'):
                    with open(json_file, 'r', encoding='utf-8') as file:
                        for line in file:
                            data_dict_origin = json.loads(line.strip())
                            insert_emb = data_dict_origin['emb']
                            if isinstance(insert_emb, list):
                                insert_emb = np.array(insert_emb, dtype=np.float32)
                            data_dict = {data_dict_origin['anonymous_cid']: insert_emb}
                            emb_dict.update(data_dict)
            except Exception as e:
                print(f"transfer error: {e}")
        if feat_id == '81':
            with open(Path(mm_path, f'emb_{feat_id}_{shape}.pkl'), 'rb') as f:
                emb_dict = pickle.load(f)
        mm_emb_dict[feat_id] = emb_dict
        print(f'Loaded #{feat_id} mm_emb')
    return mm_emb_dict

4. model_rqvae.py - Multimodal Feature Compression

Implements the RQ-VAE (Residual Quantized Variational AutoEncoder) framework, which converts high-dimensional multimodal embeddings into discrete semantic IDs:

Core components

  • RQEncoder/RQDecoder: the encoder and decoder
  • VQEmbedding: the vector-quantization module, with support for K-means initialization
  • RQ: the residual quantizer, implementing multi-level quantization
  • RQVAE: the complete RQ-VAE model

Quantization approach

  • Supports both standard K-means and balanced K-means clustering
  • Quantizes vectors using either cosine distance or L2 distance
  • Residual quantization yields a more precise feature representation
model_rqvae.py code
"""
选手可参考以下流程,使用提供的 RQ-VAE 框架代码将多模态emb数据转换为Semantic Id:
1. 使用 MmEmbDataset 读取不同特征 ID 的多模态emb数据.
2. 训练 RQ-VAE 模型, 训练完成后将数据转换为Semantic Id.
3. 参照 Item Sparse 特征格式处理Semantic Id,作为新特征加入Baseline模型训练.
"""import torch
import torch.nn.functional as F
from sklearn.cluster import KMeans# class MmEmbDataset(torch.utils.data.Dataset):
#     """
#     Build Dataset for RQ-VAE Training#     Args:
#         data_dir = os.environ.get('TRAIN_DATA_PATH')
#         feature_id = MM emb ID
#     """#     def __init__(self, data_dir, feature_id):
#         super().__init__()
#         self.data_dir = Path(data_dir)
#         self.mm_emb_id = [feature_id]
#         self.mm_emb_dict = load_mm_emb(Path(data_dir, "creative_emb"), self.mm_emb_id)#         self.mm_emb = self.mm_emb_dict[self.mm_emb_id[0]]
#         self.tid_list, self.emb_list = list(self.mm_emb.keys()), list(self.mm_emb.values())
#         self.emb_list = [torch.tensor(emb, dtype=torch.float32) for emb in self.emb_list]#         assert len(self.tid_list) == len(self.emb_list)
#         self.item_cnt = len(self.tid_list)#     def __getitem__(self, index):
#         tid = torch.tensor(self.tid_list[index], dtype=torch.long)
#         emb = self.emb_list[index]
#         return tid, emb#     def __len__(self):
#         return self.item_cnt#     @staticmethod
#     def collate_fn(batch):
#         tid, emb = zip(*batch)#         tid_batch, emb_batch = torch.stack(tid, dim=0), torch.stack(emb, dim=0)
#         return tid_batch, emb_batch## Kmeans
def kmeans(data, n_clusters, kmeans_iters):"""auto init: n_init = 10 if n_clusters <= 10 else 1"""km = KMeans(n_clusters=n_clusters, max_iter=kmeans_iters, n_init="auto")# sklearn only support cpudata_cpu = data.detach().cpu()np_data = data_cpu.numpy()km.fit(np_data)return torch.tensor(km.cluster_centers_), torch.tensor(km.labels_)## Balanced Kmeans
class BalancedKmeans(torch.nn.Module):
    def __init__(self, num_clusters: int, kmeans_iters: int, tolerance: float, device: str):
        super().__init__()
        self.num_clusters = num_clusters
        self.kmeans_iters = kmeans_iters
        self.tolerance = tolerance
        self.device = device
        self._codebook = None

    def _compute_distances(self, data):
        return torch.cdist(data, self._codebook)

    def _assign_clusters(self, dist):
        samples_cnt = dist.shape[0]
        samples_labels = torch.zeros(samples_cnt, dtype=torch.long, device=self.device)
        clusters_cnt = torch.zeros(self.num_clusters, dtype=torch.long, device=self.device)

        sorted_indices = torch.argsort(dist, dim=-1)
        for i in range(samples_cnt):
            for j in range(self.num_clusters):
                cluster_idx = sorted_indices[i, j]
                if clusters_cnt[cluster_idx] < samples_cnt // self.num_clusters:
                    samples_labels[i] = cluster_idx
                    clusters_cnt[cluster_idx] += 1
                    break
        return samples_labels

    def _update_codebook(self, data, samples_labels):
        _new_codebook = []
        for i in range(self.num_clusters):
            cluster_data = data[samples_labels == i]
            if len(cluster_data) > 0:
                _new_codebook.append(cluster_data.mean(dim=0))
            else:
                _new_codebook.append(self._codebook[i])
        return torch.stack(_new_codebook)

    def fit(self, data):
        num_emb, codebook_emb_dim = data.shape
        data = data.to(self.device)

        # initialize codebook
        indices = torch.randperm(num_emb)[: self.num_clusters]
        self._codebook = data[indices].clone()

        for _ in range(self.kmeans_iters):
            dist = self._compute_distances(data)
            samples_labels = self._assign_clusters(dist)
            _new_codebook = self._update_codebook(data, samples_labels)
            if torch.norm(_new_codebook - self._codebook) < self.tolerance:
                break

            self._codebook = _new_codebook

        return self._codebook, samples_labels

    def predict(self, data):
        data = data.to(self.device)
        dist = self._compute_distances(data)
        samples_labels = self._assign_clusters(dist)
        return samples_labels


## Base RQVAE
class RQEncoder(torch.nn.Module):
    def __init__(self, input_dim: int, hidden_channels: list, latent_dim: int):
        super().__init__()
        self.stages = torch.nn.ModuleList()
        in_dim = input_dim

        for out_dim in hidden_channels:
            stage = torch.nn.Sequential(torch.nn.Linear(in_dim, out_dim), torch.nn.ReLU())
            self.stages.append(stage)
            in_dim = out_dim

        self.stages.append(torch.nn.Sequential(torch.nn.Linear(in_dim, latent_dim), torch.nn.ReLU()))

    def forward(self, x):
        for stage in self.stages:
            x = stage(x)
        return x


class RQDecoder(torch.nn.Module):
    def __init__(self, latent_dim: int, hidden_channels: list, output_dim: int):
        super().__init__()
        self.stages = torch.nn.ModuleList()
        in_dim = latent_dim

        for out_dim in hidden_channels:
            stage = torch.nn.Sequential(torch.nn.Linear(in_dim, out_dim), torch.nn.ReLU())
            self.stages.append(stage)
            in_dim = out_dim

        self.stages.append(torch.nn.Sequential(torch.nn.Linear(in_dim, output_dim), torch.nn.ReLU()))

    def forward(self, x):
        for stage in self.stages:
            x = stage(x)
        return x


## Generate semantic id
class VQEmbedding(torch.nn.Embedding):
    def __init__(
        self,
        num_clusters,
        codebook_emb_dim: int,
        kmeans_method: str,
        kmeans_iters: int,
        distances_method: str,
        device: str,
    ):
        super(VQEmbedding, self).__init__(num_clusters, codebook_emb_dim)

        self.num_clusters = num_clusters
        self.codebook_emb_dim = codebook_emb_dim
        self.kmeans_method = kmeans_method
        self.kmeans_iters = kmeans_iters
        self.distances_method = distances_method
        self.device = device

    def _create_codebook(self, data):
        if self.kmeans_method == 'kmeans':
            _codebook, _ = kmeans(data, self.num_clusters, self.kmeans_iters)
        elif self.kmeans_method == 'bkmeans':
            BKmeans = BalancedKmeans(
                num_clusters=self.num_clusters, kmeans_iters=self.kmeans_iters, tolerance=1e-4, device=self.device
            )
            _codebook, _ = BKmeans.fit(data)
        else:
            _codebook = torch.randn(self.num_clusters, self.codebook_emb_dim)
        _codebook = _codebook.to(self.device)
        assert _codebook.shape == (self.num_clusters, self.codebook_emb_dim)
        self.codebook = torch.nn.Parameter(_codebook)

    @torch.no_grad()
    def _compute_distances(self, data):
        _codebook_t = self.codebook.t()
        assert _codebook_t.shape == (self.codebook_emb_dim, self.num_clusters)
        assert data.shape[-1] == self.codebook_emb_dim

        if self.distances_method == 'cosine':
            data_norm = F.normalize(data, p=2, dim=-1)
            _codebook_t_norm = F.normalize(_codebook_t, p=2, dim=0)
            distances = 1 - torch.mm(data_norm, _codebook_t_norm)
        # l2
        else:
            data_norm_sq = data.pow(2).sum(dim=-1, keepdim=True)
            _codebook_t_norm_sq = _codebook_t.pow(2).sum(dim=0, keepdim=True)
            distances = torch.addmm(data_norm_sq + _codebook_t_norm_sq, data, _codebook_t, beta=1.0, alpha=-2.0)
        return distances

    @torch.no_grad()
    def _create_semantic_id(self, data):
        distances = self._compute_distances(data)
        _semantic_id = torch.argmin(distances, dim=-1)
        return _semantic_id

    def _update_emb(self, _semantic_id):
        update_emb = super().forward(_semantic_id)
        return update_emb

    def forward(self, data):
        self._create_codebook(data)
        _semantic_id = self._create_semantic_id(data)
        update_emb = self._update_emb(_semantic_id)

        return update_emb, _semantic_id


## Residual Quantizer
class RQ(torch.nn.Module):
    """
    Args:
        num_codebooks, codebook_size, codebook_emb_dim -> Build codebook
        if_shared_codebook -> If use same codebook
        kmeans_method, kmeans_iters -> Initialize codebook
        distances_method -> Generate semantic_id
        loss_beta -> Calculate RQ-VAE loss
    """

    def __init__(
        self,
        num_codebooks: int,
        codebook_size: list,
        codebook_emb_dim,
        shared_codebook: bool,
        kmeans_method,
        kmeans_iters,
        distances_method,
        loss_beta: float,
        device: str,
    ):
        super().__init__()
        self.num_codebooks = num_codebooks
        self.codebook_size = codebook_size
        assert len(self.codebook_size) == self.num_codebooks
        self.codebook_emb_dim = codebook_emb_dim
        self.shared_codebook = shared_codebook
        self.kmeans_method = kmeans_method
        self.kmeans_iters = kmeans_iters
        self.distances_method = distances_method
        self.loss_beta = loss_beta
        self.device = device

        if self.shared_codebook:
            self.vqmodules = torch.nn.ModuleList(
                [
                    VQEmbedding(
                        self.codebook_size[0],
                        self.codebook_emb_dim,
                        self.kmeans_method,
                        self.kmeans_iters,
                        self.distances_method,
                        self.device,
                    )
                    for _ in range(self.num_codebooks)
                ]
            )
        else:
            self.vqmodules = torch.nn.ModuleList(
                [
                    VQEmbedding(
                        self.codebook_size[idx],
                        self.codebook_emb_dim,
                        self.kmeans_method,
                        self.kmeans_iters,
                        self.distances_method,
                        self.device,
                    )
                    for idx in range(self.num_codebooks)
                ]
            )

    def quantize(self, data):
        """
        Exa:
            i-th quantize: input[i] ( i.e. res[i-1] ) = VQ[i] + res[i]
            vq_emb_list: [vq1, vq1+vq2, ...]
            res_emb_list: [res1, res2, ...]
            semantic_id_list: [vq1_sid, vq2_sid, ...]

        Returns:
            vq_emb_list[0] -> [batch_size, codebook_emb_dim]
            semantic_id_list -> [batch_size, num_codebooks]
        """
        res_emb = data.detach().clone()

        vq_emb_list, res_emb_list = [], []
        semantic_id_list = []
        vq_emb_aggre = torch.zeros_like(data)

        for i in range(self.num_codebooks):
            vq_emb, _semantic_id = self.vqmodules[i](res_emb)

            res_emb -= vq_emb
            vq_emb_aggre += vq_emb

            res_emb_list.append(res_emb)
            vq_emb_list.append(vq_emb_aggre)
            semantic_id_list.append(_semantic_id.unsqueeze(dim=-1))

        semantic_id_list = torch.cat(semantic_id_list, dim=-1)
        return vq_emb_list, res_emb_list, semantic_id_list

    def _rqvae_loss(self, vq_emb_list, res_emb_list):
        rqvae_loss_list = []
        for idx, quant in enumerate(vq_emb_list):
            # stop gradient
            loss1 = (res_emb_list[idx].detach() - quant).pow(2.0).mean()
            loss2 = (res_emb_list[idx] - quant.detach()).pow(2.0).mean()
            partial_loss = loss1 + self.loss_beta * loss2
            rqvae_loss_list.append(partial_loss)

        rqvae_loss = torch.sum(torch.stack(rqvae_loss_list))
        return rqvae_loss

    def forward(self, data):
        vq_emb_list, res_emb_list, semantic_id_list = self.quantize(data)
        rqvae_loss = self._rqvae_loss(vq_emb_list, res_emb_list)

        return vq_emb_list, semantic_id_list, rqvae_loss


class RQVAE(torch.nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_channels: list,
        latent_dim: int,
        num_codebooks: int,
        codebook_size: list,
        shared_codebook: bool,
        kmeans_method,
        kmeans_iters,
        distances_method,
        loss_beta: float,
        device: str,
    ):
        super().__init__()
        self.encoder = RQEncoder(input_dim, hidden_channels, latent_dim).to(device)
        self.decoder = RQDecoder(latent_dim, hidden_channels[::-1], input_dim).to(device)
        self.rq = RQ(
            num_codebooks,
            codebook_size,
            latent_dim,
            shared_codebook,
            kmeans_method,
            kmeans_iters,
            distances_method,
            loss_beta,
            device,
        ).to(device)

    def encode(self, x):
        return self.encoder(x)

    def decode(self, z_vq):
        if isinstance(z_vq, list):
            z_vq = z_vq[-1]
        return self.decoder(z_vq)

    def compute_loss(self, x_hat, x_gt, rqvae_loss):
        recon_loss = F.mse_loss(x_hat, x_gt, reduction="mean")
        total_loss = recon_loss + rqvae_loss
        return recon_loss, rqvae_loss, total_loss

    def _get_codebook(self, x_gt):
        z_e = self.encode(x_gt)
        vq_emb_list, semantic_id_list, rqvae_loss = self.rq(z_e)
        return semantic_id_list

    def forward(self, x_gt):
        z_e = self.encode(x_gt)
        vq_emb_list, semantic_id_list, rqvae_loss = self.rq(z_e)
        x_hat = self.decode(vq_emb_list)
        recon_loss, rqvae_loss, total_loss = self.compute_loss(x_hat, x_gt, rqvae_loss)
        return x_hat, semantic_id_list, recon_loss, rqvae_loss, total_loss

5. run.sh - Run Script

A simple bash script that launches the training program.

run.sh code
#!/bin/bash

# show ${RUNTIME_SCRIPT_DIR}
echo ${RUNTIME_SCRIPT_DIR}

# enter train workspace
cd ${RUNTIME_SCRIPT_DIR}

# write your code below
python -u main.py

Technical Highlights

  1. Efficient attention: uses Flash Attention to cut the cost of attention computation
  2. Multimodal fusion: supports embedding features from multiple modalities such as text and images
  3. Feature engineering: handles sparse, dense, array, and other feature types
  4. Sequence modeling: models user and item interaction sequences jointly
  5. Scalability: supports saving and retrieving embeddings for a large-scale item corpus
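On point 1: Flash Attention is a memory-efficient kernel for the same scaled dot-product attention that the Transformer blocks compute. As a reference for what that function does, here is a minimal numpy sketch with a causal mask (illustrative only; the baseline relies on PyTorch's optimized kernels):

```python
import numpy as np


def causal_attention(q, k, v):
    """Scaled dot-product attention over [seq_len, d] inputs, with a causal mask."""
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)                          # [seq_len, seq_len] similarity
    mask = np.triu(np.ones_like(scores, dtype=bool), k=1)  # True above the diagonal
    scores = np.where(mask, -np.inf, scores)               # block attention to future positions
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)         # row-wise softmax
    return weights @ v


q = k = v = np.eye(3)
out = causal_attention(q, k, v)
```

The first position can only attend to itself, so its output equals its own value vector; Flash Attention produces the same result while avoiding materializing the full `[seq_len, seq_len]` score matrix.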

Data Flow

  1. Training: read user sequences → embed features → encode with the Transformer → compute the positive/negative sample loss
  2. Inference: generate user representations → save item embeddings → recommend via vector retrieval
  3. Multimodal processing: raw embeddings → RQ-VAE compression → semantic IDs → added to the model as new features
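Step 2's vector retrieval reduces to scoring every saved item embedding against the user representation and keeping the top k. A brute-force numpy sketch is below; the shapes and the `topk_items` helper are our assumptions, and a production system would typically swap in an ANN index for the exhaustive scan:

```python
import numpy as np

# Hypothetical shapes: 1000 candidate items with 32-dim embeddings, one user vector.
rng = np.random.default_rng(0)
item_emb = rng.standard_normal((1000, 32)).astype(np.float32)
user_emb = rng.standard_normal(32).astype(np.float32)


def topk_items(user_emb, item_emb, k=10):
    """Inner-product retrieval: score every item, return the k best item indices."""
    scores = item_emb @ user_emb
    topk = np.argpartition(-scores, k)[:k]      # unordered top-k in O(n)
    return topk[np.argsort(-scores[topk])]      # sort only the k winners by score


recs = topk_items(user_emb, item_emb, k=10)
```

Using `argpartition` before the final sort keeps the cost at O(n + k log k) instead of sorting the full score vector.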
