
Binary Classification with the Bank Churn Dataset

I. Introduction

Task: predict whether a customer keeps their account or closes it (that is, churns).
Competition page: https://www.kaggle.com/competitions/playground-series-s4e1

II. Steps

(1) Data Import and Preview

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import StratifiedKFold
from catboost import CatBoostClassifier, Pool
from sklearn.metrics import roc_auc_score

# Load the competition data, using the id column as the index
train = pd.read_csv('train.csv', index_col='id')
test = pd.read_csv('test.csv', index_col='id')
train.head(5)
    CustomerId         Surname  CreditScore Geography Gender   Age  Tenure    Balance  NumOfProducts  HasCrCard  IsActiveMember  EstimatedSalary  Exited
id
0     15674932  Okwudilichukwu          668    France   Male  33.0       3       0.00              2        1.0             0.0        181449.97       0
1     15749177   Okwudiliolisa          627    France   Male  33.0       1       0.00              2        1.0             1.0         49503.50       0
2     15694510           Hsueh          678    France   Male  40.0      10       0.00              2        1.0             0.0        184866.69       0
3     15741417             Kao          581    France   Male  34.0       2  148882.54              1        1.0             1.0         84560.88       0
4     15766172       Chiemenam          716     Spain   Male  33.0       5       0.00              2        1.0             1.0         15068.83       0
train.info()
<class 'pandas.core.frame.DataFrame'>
Index: 165034 entries, 0 to 165033
Data columns (total 13 columns):
 #   Column           Non-Null Count   Dtype
---  ------           --------------   -----
 0   CustomerId       165034 non-null  int64
 1   Surname          165034 non-null  object
 2   CreditScore      165034 non-null  int64
 3   Geography        165034 non-null  object
 4   Gender           165034 non-null  object
 5   Age              165034 non-null  float64
 6   Tenure           165034 non-null  int64
 7   Balance          165034 non-null  float64
 8   NumOfProducts    165034 non-null  int64
 9   HasCrCard        165034 non-null  float64
 10  IsActiveMember   165034 non-null  float64
 11  EstimatedSalary  165034 non-null  float64
 12  Exited           165034 non-null  int64
dtypes: float64(5), int64(5), object(3)
memory usage: 17.6+ MB
train.drop('CustomerId', axis=1).nunique()
Surname             2797
CreditScore          457
Geography              3
Gender                 2
Age                   71
Tenure                11
Balance            30075
NumOfProducts          4
HasCrCard              2
IsActiveMember         2
EstimatedSalary    55298
Exited                 2
dtype: int64

Plot the distribution of each column. This gives an overall feel for the data and makes outliers easy to spot.

# Plot the distribution of every column except the two identifier-like ones
plt.figure(figsize=(18, 12))
for i, column in enumerate(train.drop(columns=['CustomerId', 'Surname']).columns, 1):
    plt.subplot(len(train.columns)//3 + 1, 3, i)
    sns.histplot(train[column])
    plt.title(column)
plt.tight_layout()
plt.show()

[Figure: histograms of each column in the training set]
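Histograms show outliers visually; a numeric check can back that up. The following is a minimal sketch, assuming the common 1.5 × IQR rule is an acceptable definition of "outlier" here. The helper name `iqr_outlier_counts` and the toy values are illustrative and do not come from the competition data.

```python
import pandas as pd

def iqr_outlier_counts(df: pd.DataFrame, k: float = 1.5) -> pd.Series:
    """Count values outside [Q1 - k*IQR, Q3 + k*IQR] for each numeric column."""
    numeric = df.select_dtypes(include="number")
    q1 = numeric.quantile(0.25)
    q3 = numeric.quantile(0.75)
    iqr = q3 - q1
    # Comparisons align the per-column bounds with the DataFrame columns
    outside = (numeric < q1 - k * iqr) | (numeric > q3 + k * iqr)
    return outside.sum()

# Toy frame with made-up values (not the competition data)
toy = pd.DataFrame({
    "Age": [33, 33, 40, 34, 33, 35, 36, 92],
    "CreditScore": [668, 627, 678, 600, 716, 650, 640, 660],
})
counts = iqr_outlier_counts(toy)
print(counts)
```

On the real data the same call would be `iqr_outlier_counts(train.drop(columns=['CustomerId']))`. Whether a flagged value is a genuine anomaly or just a long tail still needs a look at the plot.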

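(2) Feature Engineering

# Feature engineering
def create_feature(df):
    # Age groups; pd.cut uses right-closed intervals by default, e.g. (0, 20], (20, 40]
    df['age_bin'] = pd.cut(df['Age'], bins=[0, 20, 40, 60, 80, 110],
                           labels=['<20', '20-40', '40-60', '60-80', '>=80'])
    # Flag customers with a credit score below 500
    df['is_low_score'] = df['CreditScore'].apply(lambda x: 1 if x < 500 else 0)
    # Flag customers aged 60 or older
    df['is_senior'] = df['Age'].apply(lambda x: 1 if x >= 60 else 0)
    return df

train = create_feature(train)
test = create_feature(test)

# Separate the features from the target variable
X = train.drop(columns=['Exited'])
y = train['Exited']
# Columns CatBoost should treat as categorical
cat_features = X.select_dtypes(['object', 'category']).columns.tolist()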
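One detail of the age binning is easy to miss: `pd.cut` closes intervals on the right by default, so an age that falls exactly on a bin edge lands in the lower bin. The sketch below uses the same bins and labels as `create_feature` on a few hand-picked edge ages (the ages are illustrative) and compares the default with `right=False`.

```python
import pandas as pd

bins = [0, 20, 40, 60, 80, 110]
labels = ['<20', '20-40', '40-60', '60-80', '>=80']
ages = pd.Series([20, 40, 60, 80])  # values sitting exactly on bin edges

# Default right=True: intervals are (0, 20], (20, 40], (40, 60], ...
right_closed = pd.cut(ages, bins=bins, labels=labels)
# right=False: intervals are [0, 20), [20, 40), [40, 60), ...
left_closed = pd.cut(ages, bins=bins, labels=labels, right=False)

comparison = pd.DataFrame({
    'Age': ages,
    'right=True': right_closed.astype(str),
    'right=False': left_closed.astype(str),
})
print(comparison)
```

With the default, a 60-year-old is labelled `40-60` while `is_senior` (which tests `Age >= 60`) is 1 for the same customer. If the labels are meant to be read literally, `right=False` may be the closer match; that is a judgment call, not something the original post specifies.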

(3) Model Building and Evaluation

Method: GPU-accelerated training + 5-fold cross-validation
Hyperparameter tuning: no detailed tuning was done
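The 5-fold split in this walkthrough is stratified (`StratifiedKFold`), which keeps the share of churned customers roughly the same in every fold. A minimal sketch of why that matters, on a made-up label vector with 20% positives (the toy arrays and the helper `positive_rates` are illustrative only):

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

# Toy labels: 20 positives, 80 negatives (made-up, not the competition data)
y_toy = np.array([1] * 20 + [0] * 80)
X_toy = np.zeros((len(y_toy), 1))  # feature values do not affect the split

def positive_rates(splitter):
    """Share of positive labels in each validation fold."""
    return [float(y_toy[valid_idx].mean())
            for _, valid_idx in splitter.split(X_toy, y_toy)]

stratified_rates = positive_rates(
    StratifiedKFold(n_splits=5, shuffle=True, random_state=42))
plain_rates = positive_rates(
    KFold(n_splits=5, shuffle=True, random_state=42))

print('StratifiedKFold:', stratified_rates)
print('KFold:          ', plain_rates)
```

With stratification every validation fold has the same positive rate as the full label vector, so per-fold AUC values are computed on comparable class mixes. A plain `KFold` only matches that rate on average.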

# Modeling and evaluation
n_fold = 5
folds = StratifiedKFold(n_splits=n_fold, random_state=42, shuffle=True)    # stratified, because the classes are imbalanced
auc_valids = []    # AUC of each fold
y_pred = np.empty((n_fold, len(test)))    # test-set predictions from each fold

for fold, (train_index, valid_index) in enumerate(folds.split(X, y)):
    # Training split
    X_train, y_train = X.iloc[train_index], y.iloc[train_index]
    # Validation split
    X_valid, y_valid = X.iloc[valid_index], y.iloc[valid_index]

    # Train the model
    train_pool = Pool(X_train, y_train, cat_features=cat_features)
    valid_pool = Pool(X_valid, y_valid, cat_features=cat_features)
    clf = CatBoostClassifier(
        eval_metric='AUC',     # evaluation metric
        task_type='GPU',
        learning_rate=0.02,
        iterations=5000)
    clf.fit(train_pool, eval_set=valid_pool, verbose=500)

    # Score the validation split.
    # predict_proba returns one row per sample, e.g. [[0.2, 0.8]];
    # [:, 1] selects the positive-class probability, here 0.8
    y_pred_valid = clf.predict_proba(X_valid)[:, 1]
    auc_valid = roc_auc_score(y_valid, y_pred_valid)
    print(f'Fold {fold} AUC: {auc_valid}')
    auc_valids.append(auc_valid)

    # Predict the test set with the model trained on this fold
    y_pred_test = clf.predict_proba(test)[:, 1]
    y_pred[fold, :] = y_pred_test
    print('-'*60)

print(f'Mean AUC: {np.mean(auc_valids): .4f}')
Default metric period is 5 because AUC is/are not implemented for GPU
0:	test: 0.8739918	best: 0.8739918 (0)	total: 70.7ms	remaining: 5m 53s
500:	test: 0.8944888	best: 0.8944888 (500)	total: 10s	remaining: 1m 29s
1000:	test: 0.8952877	best: 0.8952985 (980)	total: 20.5s	remaining: 1m 21s
1500:	test: 0.8954829	best: 0.8954870 (1490)	total: 31.2s	remaining: 1m 12s
2000:	test: 0.8955086	best: 0.8955306 (1625)	total: 42s	remaining: 1m 2s
2500:	test: 0.8954988	best: 0.8955306 (1625)	total: 52.9s	remaining: 52.9s
3000:	test: 0.8954397	best: 0.8955306 (1625)	total: 1m 3s	remaining: 42.5s
3500:	test: 0.8953933	best: 0.8955306 (1625)	total: 1m 14s	remaining: 32s
4000:	test: 0.8952920	best: 0.8955306 (1625)	total: 1m 25s	remaining: 21.4s
4500:	test: 0.8952608	best: 0.8955306 (1625)	total: 1m 37s	remaining: 10.8s
4999:	test: 0.8951928	best: 0.8955306 (1625)	total: 1m 48s	remaining: 0us
bestTest = 0.8955305815
bestIteration = 1625
Shrink model to first 1626 iterations.
Fold 0 AUC: 0.8955305898663352
------------------------------------------------------------
Default metric period is 5 because AUC is/are not implemented for GPU
0:	test: 0.8743308	best: 0.8743308 (0)	total: 21.1ms	remaining: 1m 45s
500:	test: 0.8945833	best: 0.8945833 (500)	total: 9.9s	remaining: 1m 28s
1000:	test: 0.8954305	best: 0.8954305 (1000)	total: 20.3s	remaining: 1m 21s
1500:	test: 0.8957238	best: 0.8957241 (1485)	total: 31.1s	remaining: 1m 12s
2000:	test: 0.8957269	best: 0.8957676 (1760)	total: 41.8s	remaining: 1m 2s
2500:	test: 0.8956528	best: 0.8957676 (1760)	total: 52.7s	remaining: 52.7s
3000:	test: 0.8956784	best: 0.8957676 (1760)	total: 1m 3s	remaining: 42.5s
3500:	test: 0.8956079	best: 0.8957676 (1760)	total: 1m 14s	remaining: 32s
4000:	test: 0.8955335	best: 0.8957676 (1760)	total: 1m 25s	remaining: 21.4s
4500:	test: 0.8954549	best: 0.8957676 (1760)	total: 1m 37s	remaining: 10.8s
4999:	test: 0.8953199	best: 0.8957676 (1760)	total: 1m 48s	remaining: 0us
bestTest = 0.8957676291
bestIteration = 1760
Shrink model to first 1761 iterations.
Fold 1 AUC: 0.8957676119974756
------------------------------------------------------------
Default metric period is 5 because AUC is/are not implemented for GPU
0:	test: 0.8746601	best: 0.8746601 (0)	total: 24.1ms	remaining: 2m
500:	test: 0.8953991	best: 0.8953991 (500)	total: 9.89s	remaining: 1m 28s
1000:	test: 0.8964818	best: 0.8964818 (1000)	total: 20.5s	remaining: 1m 21s
1500:	test: 0.8968892	best: 0.8968893 (1495)	total: 31.2s	remaining: 1m 12s
2000:	test: 0.8971167	best: 0.8971181 (1965)	total: 42.1s	remaining: 1m 3s
2500:	test: 0.8972195	best: 0.8972385 (2350)	total: 53.2s	remaining: 53.1s
3000:	test: 0.8972575	best: 0.8972622 (2985)	total: 1m 4s	remaining: 42.8s
3500:	test: 0.8972616	best: 0.8972861 (3280)	total: 1m 15s	remaining: 32.3s
4000:	test: 0.8972484	best: 0.8972864 (3850)	total: 1m 26s	remaining: 21.6s
4500:	test: 0.8972290	best: 0.8972864 (3850)	total: 1m 37s	remaining: 10.8s
4999:	test: 0.8972137	best: 0.8972864 (3850)	total: 1m 49s	remaining: 0us
bestTest = 0.8972864151
bestIteration = 3850
Shrink model to first 3851 iterations.
Fold 2 AUC: 0.8972863638690577
------------------------------------------------------------
Default metric period is 5 because AUC is/are not implemented for GPU
0:	test: 0.8753492	best: 0.8753492 (0)	total: 22.1ms	remaining: 1m 50s
500:	test: 0.8957655	best: 0.8957655 (500)	total: 9.91s	remaining: 1m 28s
1000:	test: 0.8964766	best: 0.8964805 (990)	total: 20.5s	remaining: 1m 22s
1500:	test: 0.8967288	best: 0.8967288 (1500)	total: 31.3s	remaining: 1m 12s
2000:	test: 0.8967985	best: 0.8968331 (1860)	total: 42s	remaining: 1m 2s
2500:	test: 0.8968449	best: 0.8968576 (2205)	total: 52.9s	remaining: 52.8s
3000:	test: 0.8968754	best: 0.8968790 (2935)	total: 1m 3s	remaining: 42.5s
3500:	test: 0.8968456	best: 0.8968830 (3010)	total: 1m 14s	remaining: 32s
4000:	test: 0.8968126	best: 0.8968830 (3010)	total: 1m 25s	remaining: 21.5s
4500:	test: 0.8967692	best: 0.8968830 (3010)	total: 1m 37s	remaining: 10.8s
4999:	test: 0.8966937	best: 0.8968830 (3010)	total: 1m 48s	remaining: 0us
bestTest = 0.8968829513
bestIteration = 3010
Shrink model to first 3011 iterations.
Fold 3 AUC: 0.8968829799706398
------------------------------------------------------------
Default metric period is 5 because AUC is/are not implemented for GPU
0:	test: 0.8713666	best: 0.8713666 (0)	total: 32.9ms	remaining: 2m 44s
500:	test: 0.8930401	best: 0.8930401 (500)	total: 9.91s	remaining: 1m 28s
1000:	test: 0.8936611	best: 0.8936611 (1000)	total: 20.8s	remaining: 1m 23s
1500:	test: 0.8937590	best: 0.8937859 (1345)	total: 33.2s	remaining: 1m 17s
2000:	test: 0.8937768	best: 0.8938069 (1880)	total: 46.2s	remaining: 1m 9s
2500:	test: 0.8936644	best: 0.8938069 (1880)	total: 58.7s	remaining: 58.7s
3000:	test: 0.8936239	best: 0.8938069 (1880)	total: 1m 10s	remaining: 47.1s
3500:	test: 0.8935375	best: 0.8938069 (1880)	total: 1m 21s	remaining: 35.1s
4000:	test: 0.8934290	best: 0.8938069 (1880)	total: 1m 33s	remaining: 23.3s
4500:	test: 0.8933018	best: 0.8938069 (1880)	total: 1m 44s	remaining: 11.6s
4999:	test: 0.8932022	best: 0.8938069 (1880)	total: 1m 56s	remaining: 0us
bestTest = 0.8938069344
bestIteration = 1880
Shrink model to first 1881 iterations.
Fold 4 AUC: 0.8938069232633625
------------------------------------------------------------
Mean AUC:  0.8959
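The per-fold AUC above is computed from `predict_proba(...)[:, 1]`, the probability of the positive class, rather than from hard 0/1 labels. The sketch below illustrates that indexing on synthetic data. It swaps in scikit-learn's `LogisticRegression` for CatBoost purely to keep the example light; the data and variable names are assumptions, not part of the original run.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

# Synthetic binary problem (stand-in for the churn data)
X_toy, y_toy = make_classification(n_samples=500, n_features=8, random_state=0)
X_tr, X_va, y_tr, y_va = train_test_split(
    X_toy, y_toy, test_size=0.2, stratify=y_toy, random_state=0)

model = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)

proba = model.predict_proba(X_va)              # shape (n_samples, 2), rows sum to 1
positive_col = list(model.classes_).index(1)   # column that holds P(class == 1)
p_pos = proba[:, positive_col]

auc = roc_auc_score(y_va, p_pos)               # AUC ranks samples by probability
print(proba[:3])
print(f'positive-class column: {positive_col}, AUC: {auc:.4f}')
```

Looking the column up through `classes_` instead of hard-coding `1` is a small safeguard for cases where the label encoding is not 0/1.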

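(4) Feature Importance Visualization

import shap

shap.initjs()
# clf, train_pool and X_train are the ones left over from the last fold
explainer = shap.TreeExplainer(clf)
shap_values = explainer.shap_values(train_pool)

# Bar chart of feature importance
shap.summary_plot(shap_values, X_train, plot_type="bar")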

[Figure: SHAP feature importance bar chart]

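Note: judging by the feature importance plot, two of the new features built in the feature engineering step contribute little and can be dropped, which avoids making the model more complex or adding noise. Newly engineered features are best validated before they are kept. That said, some models perform feature selection implicitly, so this check is not always needed.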
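One way to validate an engineered feature is to compare the cross-validated AUC with and without it. The sketch below is a rough illustration under several assumptions: the data are synthetic, the model is a scaled logistic regression rather than CatBoost (to keep the run short), and the helper `cv_auc` is a hypothetical name.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

def cv_auc(X, y, n_splits=5, seed=42):
    """Mean ROC AUC over stratified folds for a simple baseline model."""
    folds = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    model = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))
    return float(cross_val_score(model, X, y, cv=folds, scoring='roc_auc').mean())

# Synthetic stand-in data: the target depends on Age only
rng = np.random.default_rng(0)
n = 1000
toy = pd.DataFrame({
    'Age': rng.integers(18, 90, size=n).astype(float),
    'CreditScore': rng.integers(350, 850, size=n).astype(float),
})
logit = 0.05 * (toy['Age'] - 45)
y_toy = (rng.random(n) < 1 / (1 + np.exp(-logit))).astype(int)

# Candidate engineered feature, defined as in create_feature
toy['is_senior'] = (toy['Age'] >= 60).astype(int)

base_cols = ['Age', 'CreditScore']
auc_base = cv_auc(toy[base_cols], y_toy)
auc_with = cv_auc(toy[base_cols + ['is_senior']], y_toy)
print(f'baseline AUC: {auc_base:.4f}, with is_senior: {auc_with:.4f}')
```

If the score with the candidate feature is not clearly higher across folds, dropping it is the simpler choice. For the actual pipeline, the same comparison would use the CatBoost loop from section (3).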

(5) Saving the Results

# Save the results: average the five folds' test predictions
y_pred_mean = y_pred.mean(axis=0)
submission = pd.DataFrame({
    'id': test.index.tolist(),
    'Exited': y_pred_mean
})
submission.to_csv('submission.csv', index=False)
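The submission is the mean of the per-fold test predictions, taken along `axis=0` (across folds). A small sketch with made-up numbers shows the averaging and a few sanity checks that could be run before writing the CSV. The ids and probabilities are illustrative.

```python
import numpy as np
import pandas as pd

# Two folds, three test rows (made-up probabilities)
fold_preds = np.array([
    [0.2, 0.8, 0.5],
    [0.4, 0.6, 0.5],
])
mean_pred = fold_preds.mean(axis=0)   # one averaged probability per test row

toy_submission = pd.DataFrame({
    'id': [100, 101, 102],            # hypothetical ids
    'Exited': mean_pred,
})

# Sanity checks before saving
assert toy_submission['Exited'].notna().all()
assert toy_submission['Exited'].between(0, 1).all()
assert toy_submission['id'].is_unique
print(toy_submission)
```

Because `y_pred` is allocated with `np.empty`, any fold row that is never filled holds arbitrary values; a range check like the one above can catch some of those cases, although it is not a guarantee.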

Result of submitting the test-set predictions to Kaggle:
[Figure: screenshot of the Kaggle submission score]
