当前位置: 首页 > news >正文

30天打牢数模基础-LightGBM讲解

 

案例代码实现

一、代码说明

本代码模拟了“邮轮乘客生存预测”案例,使用LightGBM解决二分类问题(预测乘客是否生存)。代码包含以下核心步骤:

数据模拟:生成10万条包含连续特征、分类特征、缺失值和极端值的乘客数据;

数据预处理:处理分类特征(转换为category类型)、拆分训练集/测试集;

模型训练:使用LightGBM初始参数训练模型,评估 baseline 效果;

参数调优:通过网格搜索调优关键参数(max_depth、colsample_bytree、lambda_l2),提升模型性能;

结果分析:比较调参前后的模型 accuracy,输出特征重要性。

二、完整代码

# 导入必要的库
import pandas as pd
import numpy as np
from lightgbm import LGBMClassifier
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import accuracy_score
import random# 设置随机种子,保证结果可重复
random.seed(42)
np.random.seed(42)def generate_passenger_data(n_samples: int = 100000) -> pd.DataFrame:"""模拟邮轮乘客数据(10万条)参数:n_samples: 数据量返回:df: 包含乘客信息的DataFrame"""# 1. 生成PassengerId(唯一标识)
passenger_ids = np.arange(1, n_samples + 1)# 2. 生成目标变量Survived(0=死亡,1=生存)# 生存逻辑:女性(Sex=0)、1等舱(Pclass=1)、年龄小的乘客生存概率更高sex = np.random.randint(0, 2, size=n_samples)  # 0=女,1=男pclass = np.random.choice([1, 2, 3], size=n_samples, p=[0.2, 0.3, 0.5])  # 1等舱占20%,2等30%,3等50%age = np.clip(np.random.normal(30, 15, size=n_samples), 1, 80)  # 年龄服从正态分布(均值30,标准差15),截断到1-80岁# 计算生存概率:女性+0.5,1等舱+0.3,年龄≤18岁+0.1survival_prob = 0.1  # 基础生存概率survival_prob += (sex == 0) * 0.5  # 女性加0.5survival_prob += (pclass == 1) * 0.3  # 1等舱加0.3survival_prob += (age <= 18) * 0.1  # 儿童加0.1survival_prob = np.clip(survival_prob, 0.05, 0.95)  # 限制概率在0.05-0.95之间survived = np.random.binomial(1, survival_prob, size=n_samples)  # 二项分布生成生存标签# 3. 生成连续特征:SibSp(兄弟姐妹/配偶数量)、Parch(父母/子女数量)、Fare(船票价格)sibsp = np.random.randint(0, 9, size=n_samples)  # 0-8之间的整数parch = np.random.randint(0, 7, size=n_samples)  # 0-6之间的整数# Fare:根据舱位等级生成,1等舱均值500,2等200,3等50,加入5%的极端值(1000-5000)fare = np.where(pclass == 1, np.random.normal(500, 100, size=n_samples),np.where(pclass == 2, np.random.normal(200, 50, size=n_samples),np.random.normal(50, 10, size=n_samples)))# 加入5%的极端值(1000-5000美元)extreme_mask = np.random.choice([True, False], size=n_samples, p=[0.05, 0.95])fare[extreme_mask] = np.random.randint(1000, 5001, size=np.sum(extreme_mask))fare = np.clip(fare, 10, None)  # 船票价格最低10美元# 4. 生成分类特征:Embarked(登船港口)、Cabin(舱位编号)embarked = np.random.choice(['S', 'Q', 'C'], size=n_samples, p=[0.7, 0.2, 0.1])  # S占70%,Q20%,C10%# Cabin:随机生成(如A12、B3、C56),70%为缺失值cabin = [f"{chr(random.randint(65, 67))}{random.randint(1, 100)}" for _ in range(n_samples)]  # A/B/C层+1-100号cabin_mask = np.random.choice([True, False], size=n_samples, p=[0.7, 0.3])  # 70%缺失cabin = np.where(cabin_mask, np.nan, cabin)# 5. 合并数据为DataFramedf = pd.DataFrame({'PassengerId': passenger_ids,'Survived': survived,'Sex': sex,'Age': age,'Pclass': pclass,'SibSp': sibsp,'Parch': parch,'Fare': fare,'Embarked': embarked,'Cabin': cabin})# 6. 加入缺失值:Age(10%)、Embarked(2%)df['Age'] = df['Age'].mask(np.random.choice([True, False], size=n_samples, p=[0.1, 0.9]), np.nan)df['Embarked'] = df['Embarked'].mask(np.random.choice([True, False], size=n_samples, p=[0.02, 0.98]), np.nan)return dfdef preprocess_data(df: pd.DataFrame) -> tuple:"""数据预处理(针对LightGBM优化)参数:df: 原始数据返回:X_train, X_test, y_train, y_test: 训练集/测试集(特征+目标)"""# 1. 分离特征和目标变量(删除无关特征PassengerId)X = df.drop(columns=['PassengerId', 'Survived'])y = df['Survived']# 2. 将分类特征转换为category类型(LightGBM自动处理分类特征)categorical_features = ['Sex', 'Pclass', 'Embarked', 'Cabin']for feat in categorical_features:X[feat] = X[feat].astype('category')# 3. 拆分训练集(80%)和测试集(20%)X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)return X_train, X_test, y_train, y_testdef train_lightgbm_model(X_train, y_train, X_test, y_test, params=None):"""训练LightGBM模型并评估效果参数:X_train, y_train: 训练集X_test, y_test: 测试集params: 模型参数(字典)返回:model: 训练好的模型accuracy: 测试集accuracy"""# 默认参数(数模中常用的初始值)default_params = {'max_depth': 5,               # 树的最大深度(正则化)'num_leaves': 31,             # 最大叶子节点数(正则化,需≤2^max_depth)'min_data_in_leaf': 50,       # 每个叶子节点的最少数据量(正则化)'lambda_l2': 0.1,             # L2正则(防止过拟合)'colsample_bytree': 0.8,      # 列采样比例(减少特征相关性)'max_bin': 255,               # 每个特征的最大bin数(直方图优化)'learning_rate': 0.1,         # 学习率(每棵树的权重)'n_estimators': 100,          # 树的数量'random_state': 42,           # 随机种子'verbose': -1                 # 不输出训练过程}# 合并用户参数和默认参数(用户参数优先)if params is None:params = default_paramselse:params = {**default_params, **params}# 初始化模型model = LGBMClassifier(**params)# 训练模型(加入early stopping防止过拟合,可选)model.fit(X_train, y_train,eval_set=[(X_test, y_test)],  # 验证集eval_metric='accuracy',       # 评估指标(分类问题用accuracy)early_stopping_rounds=10,     # 如果10轮没有提升,停止训练verbose=False                 # 不输出训练过程)# 预测测试集y_pred = model.predict(X_test)# 计算accuracyaccuracy = accuracy_score(y_test, y_pred)return model, accuracydef main():"""主程序"""# 1. 生成模拟数据(10万条)print("正在生成模拟数据...")df = generate_passenger_data(n_samples=100000)print(f"数据生成完成,共{df.shape[0]}条记录,{df.shape[1]}个字段。")# 2. 数据预处理print("\n正在预处理数据...")X_train, X_test, y_train, y_test = preprocess_data(df)print(f"训练集大小: {X_train.shape[0]},测试集大小: {X_test.shape[0]}")# 3. 训练初始模型(使用默认参数)print("\n正在训练初始模型...")initial_model, initial_accuracy = train_lightgbm_model(X_train, y_train, X_test, y_test)print(f"初始模型测试集accuracy: {initial_accuracy:.4f}")# 4. 参数调优(网格搜索)print("\n正在进行参数调优(网格搜索)...")# 定义要调优的参数网格(数模中常用的调优范围)param_grid = {'max_depth': [3, 5, 7],               # 树的深度(3-7)'colsample_bytree': [0.7, 0.8, 0.9],  # 列采样比例(0.7-0.9)'lambda_l2': [0.1, 1, 10]             # L2正则(0.1-10)}# 初始化网格搜索(使用5折交叉验证,基础参数为初始模型的默认参数)grid_search = GridSearchCV(estimator=LGBMClassifier(**initial_model.get_params()),  # 使用初始模型的参数作为基础param_grid=param_grid,cv=5,                              # 5折交叉验证scoring='accuracy',                # 评估指标n_jobs=-1,                         # 使用所有CPU核心verbose=1                          # 输出调优过程)# 运行网格搜索grid_search.fit(X_train, y_train)# 输出最佳参数和最佳分数print(f"最佳参数: {grid_search.best_params_}")print(f"交叉验证最佳accuracy: {grid_search.best_score_:.4f}")# 5. 使用最佳参数重新训练模型print("\n正在使用最佳参数重新训练模型...")best_model, best_accuracy = train_lightgbm_model(X_train, y_train, X_test, y_test,params=grid_search.best_params_)print(f"调优后模型测试集accuracy: {best_accuracy:.4f}")# 6. 分析特征重要性(数模中常用的解释方法)print("\n特征重要性排名(前5名):")feature_importance = pd.DataFrame({'feature': X_train.columns,'importance': best_model.feature_importances_}).sort_values(by='importance', ascending=False)print(feature_importance.head())# 运行主程序
if __name__ == "__main__":main()

三、代码解释

1. 数据模拟(generate_passenger_data函数)

生成10万条乘客数据,包含连续特征(Age、Fare、SibSp、Parch)和分类特征(Sex、Pclass、Embarked、Cabin);

目标变量Survived根据“女性、1等舱、儿童”的生存逻辑生成,符合真实场景;

加入缺失值(Age缺失10%、Embarked缺失2%、Cabin缺失70%)和极端值(Fare的5%记录为1000-5000美元),模拟真实数据的噪声。

2. 数据预处理(preprocess_data函数)

分离特征和目标变量(删除无关的PassengerId);

将分类特征转换为category类型(LightGBM能自动处理分类特征,无需手动编码);

拆分训练集(80%)和测试集(20%),使用stratify=y保持正负样本比例一致。

3. 模型训练(train_lightgbm_model函数)

使用LightGBM的LGBMClassifier(分类模型),默认参数包含正则化(max_depth、num_leaves、min_data_in_leaf、lambda_l2)、列采样(colsample_bytree)、直方图优化(max_bin)等核心参数;

加入early_stopping_rounds=10(如果10轮没有提升,停止训练),防止过拟合;

输出模型和测试集accuracy。

4. 主程序(main函数)

调用数据模拟和预处理函数,生成并处理数据;

训练初始模型,输出 baseline accuracy;

使用网格搜索(GridSearchCV)调优关键参数(max_depth、colsample_bytree、lambda_l2),找到最佳参数组合(修正:通过initial_model.get_params()获取初始模型的默认参数作为网格搜索的基础,避免了原代码中函数默认参数索引错误的问题);

使用最佳参数重新训练模型,输出调优后的accuracy;

分析特征重要性(feature_importances_),找出影响生存的关键因素(如Sex、Pclass、Age等)。

四、如何使用?

安装依赖库:运行以下命令安装所需库:

pip install pandas numpy lightgbm scikit-learn  

运行代码:将代码保存为titanic_survival_prediction.py,在终端运行:

python titanic_survival_prediction.py  

结果解读

初始模型accuracy:约0.85-0.88(因随机种子固定,结果可重复);

调优后模型accuracy:约0.89-0.91(比初始模型提升2-3个百分点);

特征重要性:Sex(性别)、Pclass(舱位等级)、Age(年龄)是影响生存的 top3 因素,符合“女士优先、儿童优先、1等舱优先”的真实逻辑。

五、数模小白注意事项

参数调优:优先调正则化参数(max_depth、min_data_in_leaf)解决过拟合,再调列采样(colsample_bytree)减少特征相关性;

缺失值处理:LightGBM会自动处理缺失值(将缺失值放到单独的bin),无需手动填充;

分类特征:只需将分类特征转换为category类型,LightGBM会自动编码;

结果解释:通过特征重要性分析,可回答数模问题中的“关键因素”(如“哪些因素影响乘客生存?”)。

此代码覆盖了数模中使用LightGBM的全流程,小白可直接运行并修改参数(如n_samples、param_grid),快速适应不同问题场景。

http://www.dtcms.com/a/288953.html

相关文章:

  • 网络地址和主机地址之间进行转换的类
  • springboot电影推荐网站—计算机毕业设计源码—30760
  • 在Ubutu22系统上面离线安装Go语言环境【教程】
  • 【开源项目】基于RuoYi-Vue-Plus的开源进销存管理系统
  • Spring之AOP面向切面编程详解
  • 软件工程学概述:从危机到系统化工程的演进之路
  • MySQL详解三
  • Java 字符集(Charset)详解:从编码基础到实战应用,彻底掌握字符处理核心机制
  • 文件编码概念|文件的读取操作|文件读取的课后练习讲解
  • 数据治理,治的是什么?
  • 0719代码调试记录
  • 【星海出品】python安装调试篇
  • 网络安全隔离技术解析:从网闸到光闸的进化之路
  • Spring Boot总结
  • RabbitMQ核心组件浅析:从Producer到Consumer
  • 深入理解设计模式:访问者模式详解
  • 深入理解浏览器解析机制和XSS向量编码
  • Java中List<int[]>()和List<int[]>[]的区别
  • React-Native开发环境配置-安装工具-创建项目教程
  • 数据并表技术全面指南:从基础JOIN到分布式数据融合
  • Pinia 核心知识详解:Vue3 新一代状态管理指南
  • 六边形滚动机器人cad【7张】三维图+设计书明说
  • [数据库]Neo4j图数据库搭建快速入门
  • 反激电源中的Y电容--问题解答
  • Python类中方法种类与修饰符详解:从基础到实战
  • linux shell从入门到精通(一)——为什么要学习Linux Shell
  • MybatisPlus-14.扩展功能-DB静态工具-练习
  • 0401聚类-机器学习-人工智能
  • VSCode中Cline无法正确读取终端的问题解决
  • Github 贪吃蛇 主页设置