当前位置: 首页 > news >正文

视频推荐模型代码解析(马栏山芒果TV算法大赛)

整体架构概述

这段代码实现了一个基于LightGBM的视频推荐系统,主要用于预测用户对视频的点击概率。整体流程分为五个核心阶段:

  1. 数据加载与预处理:读取多源数据并进行内存优化
  2. 特征工程:构建用户、视频及交互特征
  3. 样本准备:生成训练所需的标签数据
  4. 模型训练:使用LightGBM训练二分类模型
  5. 预测部署:对新数据进行预测并生成推荐结果

核心库依赖说明

import pandas as pd # 数据处理核心库 import numpy as np # 数值计算库 import lightgbm as lgb # 梯度提升树模型库 from sklearn.model_selection import train_test_split # 数据集拆分 from sklearn.metrics import roc_auc_score # 模型评估指标 import chardet # 文件编码检测 import os # 文件系统操作 import gc # 内存管理 import joblib # 模型保存与加载 from tqdm import tqdm # 进度条显示 import warnings # 警告处理

关键函数解析

1. 内存优化函数 reduce_mem_usage

def reduce_mem_usage(df, use_float16=False):  """迭代降低DataFrame的内存占用"""  start_mem = df.memory_usage().sum() / 1024**2  print(f"内存优化前: {start_mem:.2f} MB")   for col in df.columns:  col_type = df[col].dtype  if col_type != object: # 不处理字符串类型  c_min = df[col].min()  c_max = df[col].max()   # 整数类型优化  if str(col_type)[:3] == "int":  if c_min > np.iinfo(np.int8).min and c_max < np.iinfo(np.int8).max:  df[col] = df[col].astype(np.int8)  elif c_min > np.iinfo(np.int16).min and c_max < np.iinfo(np.int16).max:  df[col] = df[col].astype(np.int16)  # ... 其他整数类型判断   # 浮点数类型优化  else:  if use_float16 and c_min > np.finfo(np.float16).min and c_max < np.finfo(np.float16).max:  df[col] = df[col].astype(np.float16)  elif c_min > np.finfo(np.float32).min and c_max < np.finfo(np.float32).max:  df[col] = df[col].astype(np.float32)   end_mem = df.memory_usage().sum() / 1024**2  print(f"内存优化后: {end_mem:.2f} MB ({100*(start_mem-end_mem)/start_mem:.1f}% 减少)")  return df

功能说明
通过将数据类型从高精度(如int64、float64)转换为适合数据范围的低精度类型(如int8、float32),显著减少内存占用。这对处理大型数据集至关重要,能避免内存溢出并提高运算速度。

2. 数据加载函数 load_data_for_day

def load_data_for_day(day):  """逐天加载数据并进行基本处理,返回优化后的DataFrame"""  dtypes = {'did': 'category', 'vid': 'category'} # 将ID类特征设为category类型  day_str = f"{day:02d}"   # 加载曝光数据  see_path = f'see_{day_str}.csv'  see = pd.read_csv(see_path, encoding='latin1', dtype=dtypes)   # 加载点击数据(如存在)  click_path = f'click_{day_str}.csv'  if os.path.exists(click_path):  click = pd.read_csv(click_path, encoding='ISO-8859-1', on_bad_lines='skip', dtype=dtypes)  click = click[['did', 'vid']] # 只保留需要的列  click['clicked'] = 1 # 添加点击标记   # 加载播放数据(如存在)  play_path = f'playplus_{day_str}.csv'  if os.path.exists(play_path):  play = pd.read_csv(... , dtype=dtypes)  play = play[['did', 'vid', 'play_time']] # 保留播放时长   return see, click, play

核心处理逻辑

  • 按天加载三种类型数据:曝光数据(see)、点击数据(click)、播放数据(play)
  • 使用category类型存储用户ID(did)和视频ID(video),减少内存占用
  • 对缺失文件和字段进行容错处理,提高代码健壮性

3. 特征工程函数 process_data_in_chunks

def process_data_in_chunks(days, feature_builder=None):  """分块处理数据,避免内存溢出"""  # 加载视频基础 information  video_info = pd.read_csv('vid_info_table.csv', encoding='gbk', dtype={'vid': 'category'})   # 初始化用户和视频统计字典  user_stats = {} # 存储用户行为统计  video_stats = {} # 存储视频特征统计   # 逐天处理数据  for day in tqdm(range(1, days + 1), desc="处理每日数据"):  see, click, play = load_data_for_day(day)  if see is None: continue   # 合并多源数据  see = pd.merge(see, play, on=['did', 'vid'], how='left') # 合并播放数据  see = pd.merge(see, click, on=['did', 'vid'], how='left') # 合并点击数据  see = pd.merge(see, video_info[['vid', 'item_duration']], on='vid', how='left') # 合并视频时长   # 计算完成率特征  see['completion_rate'] = (see['play_time'] / see['item_duration']).clip(0, 1)   # 创建标签(核心业务逻辑)  see['label'] = np.select(  [(see['completion_rate'] > 0.4), (see['clicked'] == 1)],  [2, 1], # 2=完成观看, 1=点击未完成, 0=曝光未点击  default=0  )  see['binary_label'] = see['label'].apply(lambda x: 1 if x >= 1 else 0) # 二分类标签   # 更新用户统计(构建用户特征)  for _, row in see.iterrows():  did = row['did']  if did not in user_stats:  user_stats[did] = {'exposure_count': 0, 'click_count': 0, 'active_days': set()}  user_stats[did]['exposure_count'] += 1 # 曝光次数  if row['clicked'] == 1:  user_stats[did]['click_count'] += 1 # 点击次数  user_stats[did]['active_days'].add(day) # 活跃天数   # 更新视频统计(构建视频特征)  # ... 类似用户统计逻辑   # 生成用户特征DataFrame  user_features = []  for did, stats in user_stats.items():  user_features.append({  'did': did,  'user_click_rate': stats['click_count'] / stats['exposure_count'], # 点击率  'user_active_days': len(stats['active_days']) # 活跃天数  })   # 生成视频特征DataFrame  # ... 类似用户特征生成逻辑   return user_df, video_df

核心特征说明

  • 用户特征:点击率(user_click_rate)、活跃天数(user_active_days)
  • 视频特征:视频流行度(video_popularity)
  • 交互特征:用户-视频交互(user_video_interaction)、用户-视频亲和力(user_video_affinity)

4. 模型训练函数 train_model

def train_model(samples):  """训练LightGBM模型"""  # 选择特征列  features = ['user_click_rate', 'video_popularity', 'user_active_days',  'user_video_interaction', 'user_video_affinity']  X = samples[features]  y = samples['binary_label'] # 二分类标签   # 划分训练集和验证集  X_train, X_test, y_train, y_test = train_test_split(  X, y, test_size=0.2, random_state=42, stratify=y  )   # 配置LightGBM参数  params = {  'boosting_type': 'gbdt', # 梯度提升决策树  'objective': 'binary', # 二分类任务  'metric': 'auc', # 评估指标:AUC  'num_leaves': 31, # 叶子节点数  'max_depth': 7, # 树深度  'learning_rate': 0.05, # 学习率  'feature_fraction': 0.7, # 特征采样比例  'bagging_fraction': 0.8, # 样本采样比例  'bagging_freq': 5, # 采样频率  'seed': 42 # 随机种子,保证结果可复现  }   # 训练模型  lgb_train = lgb.Dataset(X_train, y_train)  lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)   model = lgb.train(  params,  lgb_train,  num_boost_round=500, # 最大迭代次数  valid_sets=[lgb_train, lgb_eval],  callbacks=[  early_stopping(stopping_rounds=50), # 早停策略,防止过拟合  log_evaluation(period=100) # 日志输出频率  ]  )   # 评估模型  y_pred = model.predict(X_test)  auc_score = roc_auc_score(y_test, y_pred)  print(f"验证集AUC: {auc_score:.4f}")   # 保存模型  joblib.dump(model, 'lightgbm_model.pkl')  return model, features, auc_score

模型训练关键技术

  • 早停策略(early_stopping):当验证集指标不再提升时停止训练,防止过拟合
  • 分层抽样(stratify=y):保持训练集和验证集的标签分布一致
  • 特征采样与样本采样:提高模型泛化能力,减少过拟合风险

5. 预测函数 predict_new_data

def predict_new_data(model, feature_columns, test_file):  """对新数据进行预测"""  # 加载测试数据  test_data = pd.read_csv(test_file, dtype={'did': 'category', 'vid': 'category'})   # 加载用户和视频特征映射表  user_df = pd.read_csv('user_click_rate.csv')  video_df = pd.read_csv('video_popularity.csv')   # 创建特征映射字典(优化查询速度)  user_click_map = user_df.set_index('did')['user_click_rate'].to_dict()  video_pop_map = video_df.set_index('vid')['video_popularity'].to_dict()   # 为测试数据添加特征  test_data['user_click_rate'] = test_data['did'].map(user_click_map).fillna(global_user_rate)  test_data['video_popularity'] = test_data['vid'].map(video_pop_map).fillna(global_video_pop)  # ... 其他特征处理   # 分批预测(避免内存溢出)  batch_size = 100000  predictions = []  for i in tqdm(range(0, len(test_features), batch_size), desc="预测批次"):  batch = test_features.iloc[i:i+batch_size]  preds = model.predict(batch)  predictions.extend(preds.tolist())   # 生成推荐结果(每个用户推荐一个视频)  test_data['click_prob'] = predictions  top_predictions = test_data.sort_values('click_prob', ascending=False).groupby('did').head(1)  result = top_predictions[['did', 'vid', 'click_prob']]  result.to_csv('prediction_result.csv', index=False)   return result

预测阶段优化

  • 使用字典映射快速查找用户/视频特征
  • 分批预测处理大规模测试数据
  • 按用户分组取Top1推荐结果

主程序执行流程 

if __name__ == '__main__':  try:  # 1. 准备训练样本(使用7天数据)  samples, _, _ = prepare_samples(days=7)   # 2. 训练模型  model, features, auc_score = train_model(samples)   # 3. 预测新数据  result = predict_new_data(model, features, 'testA_did_show.csv')   print("✅ 流程成功完成!")  except Exception as e:  print(f"❌ 流程出错: {str(e)}")  traceback.print_exc()

关键技术点总结

1. 内存优化策略

  • 使用category类型存储ID类特征
  • 数值类型降精度转换(int64→int8, float64→float32)
  • 分块处理数据,避免一次性加载全部数据
  • 显式调用gc.collect()释放内存

2. 特征工程思路

  • 用户维度:点击率、活跃天数
  • 视频维度:流行度(点击用户数)
  • 交互维度:用户活跃度×视频流行度、用户点击率×视频流行度
  • 行为标签:基于播放完成率(>40%)和点击行为构建多分类标签

3. 模型优化技巧

  • 早停策略防止过拟合
  • 特征和样本采样提高泛化能力
  • 控制树深度和叶子节点数降低模型复杂度
  • 使用AUC作为评估指标(适合二分类问题)

代码改进建议

  1. 特征扩展:可添加时间特征(如周末/工作日、时段)、用户兴趣偏好等
  2. 交叉验证:使用K-fold交叉验证替代简单的train-test split
  3. 参数调优:结合GridSearchCV或Optuna进行超参数优化
  4. 特征重要性:分析模型特征重要性,移除冗余特征
  5. 异常值处理:增加对异常播放时长、极端点击率的处理逻辑
http://www.dtcms.com/a/278339.html

相关文章:

  • 从代码学习深度学习 - 自然语言推断:微调BERT PyTorch版
  • Cesium 9 ,Cesium 离线地图本地实现与服务器部署( Vue + Cesium 多项目共享离线地图切片部署实践 )
  • H264的帧内编码和帧间编码
  • 2025年睿抗机器人开发者大赛CAIP-编程技能赛本科组(省赛)解题报告 | 珂学家
  • Python 变量与简单输入输出:从零开始写你的第一个交互程序
  • 【Java入门到精通】(四)Java语法进阶
  • 动手学深度学习——线性回归的从零开始实现
  • 【记录】BLE|百度的旧蓝牙随身音箱手机能配对不能连接、电脑能连接不能使用的解决思路(Wireshark捕获并分析手机蓝牙报文)
  • 1.2.2 高级特性详解——AI教你学Django
  • 【图片识别改名】水印相机拍的照片如何将照片的名字批量改为水印内容?图片识别改名的详细步骤和注意事项
  • 【WPF】WPF 自定义控件 实战详解,含命令实现
  • 【零基础入门unity游戏开发——unity3D篇】3D光源之——unity6的新功能Adaptive Probe Volumes(APV)(自适应探针体积)
  • ACL流量控制实验
  • 深入了解linux系统—— 进程信号的产生
  • 客户端主机宕机,服务端如何处理 TCP 连接?详解
  • EasyExcel实现Excel文件导入导出
  • VScode链接服务器一直卡在下载vscode服务器,无法连接成功
  • C++之哈希表的基本介绍以及其自我实现(开放定址法版本)
  • 多客户端 - 服务器结构-实操
  • 史上最清楚!读者,写者问题(操作系统os)
  • 基于 Gitlab、Jenkins与Jenkins分布式、SonarQube 、Nexus 的 CiCd 全流程打造
  • SQL创建三个表
  • 从 JSON 到 Python 对象:一次通透的序列化与反序列化之旅
  • Dubbo高阶难题:异步转同步调用链上全局透传参数的丢失问题
  • Selenium动态网页爬虫编写与解释
  • 【微信小程序】
  • 当你在 Git 本地提交后,因权限不足无法推送到服务端,若想撤销本次提交,可以根据不同的需求选择合适的方法,下面为你介绍两种常见方式。
  • 清除 Android 手机 SIM 卡数据的4 种简单方法
  • 云手机常见问题解析:解决延迟、掉线等困扰
  • 云手机的多重用途:从游戏挂机到办公自动化