当前位置: 首页 > news >正文

sklearn自定义pipeline的数据处理

将自定义的频数编码处理整合到sklearn的pipeline流程里面:

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import make_pipeline, Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import PolynomialFeatures # 多项式
from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score
import lightgbm as lgbimport pandas as pddef load_data(path):data = pd.read_csv(path,usecols=lambda col: col != 'id')data['subscribe'] = data['subscribe'].apply(lambda x: 1 if x == 'yes' else 0,)return data# 自定义转换器1 将类别特征按频次编码
class Freqencode(BaseEstimator, TransformerMixin):def __init__(self, cat_cols=[]):self.cat_cols = cat_cols# 返回对象本身def fit(self, X, y=None):# 计算统计量return self# 转换数据def transform(self, X):# 数据转换逻辑for col in self.cat_cols:freq = X[col].value_counts(normalize=True).to_dict()X[col] = X[col].map(freq)return Xdef pipeline_model(cat_cols):pip_model = Pipeline(steps=[('freq_encode', Freqencode(cat_cols=cat_cols)),('imputer', SimpleImputer(strategy='mean')),('poly', PolynomialFeatures(degree=2, interaction_only=False, include_bias=False)),('model', lgb.LGBMClassifier(verbose=-1)),])return pip_modelif __name__ == '__main__':path = r"C:\Users\12048\Desktop\python_code\data\train.csv"data = load_data(path)# 类别特征cat_cols = list(data.select_dtypes(include=['object']).columns)x, y = data.drop(labels='subscribe', axis=1), data['subscribe']pip_model = pipeline_model(cat_cols)pip_model.fit(x, y)print('训练集表现:')prob = pip_model.predict_proba(x)[:,1]train_pred = [1 if i>0.5 else 0 for i in prob]print('混淆矩阵:\n',confusion_matrix(y, train_pred))print('模型报告:\n',classification_report(y, train_pred))print('auc:',roc_auc_score(y, prob))

相关文章:

  • stm32之USART
  • 【计算机主板架构】ATX架构
  • CN3791 锂电池充电芯片详解及电路设计要点-国产芯片
  • uniapp-商城-46-创建schema并新增到数据库
  • AI技术与园区运营的深度融合:未来生态型园区的建设路径
  • 镜头内常见的马达类型(私人笔记)
  • Python 数据分析与可视化:开启数据洞察之旅(5/10)
  • k8s之探针
  • MCP(Model Context Protocol)是专为LLM(大语言模型)应用设计的标准化协议
  • 解决 Ubuntu DNS 无法解析问题(适用于虚拟机 长期使用)
  • Spring MVC Session 属性 (@SessionAttributes) 是什么?如何使用它共享数据?
  • 信赖域策略优化TRPO算法详解:python从零实现
  • .net/C#进程间通信技术方案总结
  • 机器学习与深度学习的区别与联系:多角度详细分析
  • Linux基础(关于进程相关命令)
  • CSS opacity
  • 计算人声录音后电平的大小(dB SPL->dBFS)
  • 访问网站提示“不安全”“有风险”怎么办?
  • 3D桌面可视化开发平台HOOPS Native Platform,如何实现3D系统快速开发与部署?
  • 【网安播报】Meta 推出 LlamaFirewall开源框架以阻止 AI 越狱、注入和不安全代码
  • 伤员回归新援融入,海港逆转海牛重回争冠集团
  • 开局良好,我国第一季度广告业务收入保持较快增速
  • 保利42.41亿元竞得上海杨浦东外滩一地块,成交楼面单价超8万元
  • 铲屎官花5万带猫狗旅行,宠旅生意有多赚?
  • 国防部:奉劝有关国家不要引狼入室,甘当棋子
  • 夜读丨母亲的手擀面