当前位置: 首页 > news >正文

机器学习——K 折交叉验证(K-Fold Cross Validation),实战案例:寻找逻辑回归最佳惩罚因子C

目录

什么是交叉验证?

最常用的:K 折交叉验证(K-Fold Cross Validation)

为什么要使用交叉验证?

什么时候该用交叉验证?

代码使用:

参数详解:

实战案例:用交叉验证寻找最优惩罚因子 C

K-Fold Cross Validation 背后的原理(做了什么)

常见扩展:StratifiedKFold(保持类别比例)


什么是交叉验证?

交叉验证是一种将原始数据集划分为若干个子集,反复训练和验证模型的策略。

交叉验证(Cross-Validation)适用于你在模型调参(如逻辑回归中的 C

最常用的:K 折交叉验证(K-Fold Cross Validation)

将数据集平均分成 K 份,每次取其中 1 份做验证,剩下的 K-1 份做训练,重复 K 次,最终将 K 次的结果取平均。

图示流程(以 K=4 举例)

轮次训练集验证集
1[2,3,4][1]
2[1,3,4][2]
3[1,2,4][3]
4[1,2,3][4]

最后将 4 次的验证结果平均,得到模型在未见数据上的稳定表现。


为什么要使用交叉验证?

作用说明
✅ 稳定评估模型表现解决只依赖单一测试集带来的评估波动问题
✅ 防止过拟合多次训练验证,有助于检测模型是否泛化能力不足
✅ 用于超参数选择常用于网格搜索、正则化参数调优(如逻辑回归中的 C)

什么时候该用交叉验证?

场景是否推荐使用交叉验证
数据量较小✅ 强烈建议
不平衡分类问题✅ 建议配合 StratifiedKFold
模型调参(如 C、k、深度)✅ 必用
数据量极大(上百万)❌ 考虑分批验证或子集评估

代码使用:

from sklearn.model_selection import cross_val_scorecross_val_score(estimator, X, y=None, *, scoring=None, cv=None,n_jobs=None, verbose=0, fit_params=None, pre_dispatch='2*n_jobs',error_score=np.nan)

参数详解:

参数名类型说明
estimator模型对象

要评估的模型,例如 LogisticRegression()RandomForestClassifier()

‘model = LogisticRegression()’后直接传入‘model’即可

Xndarray / DataFrame特征数据集
yarray-like目标标签(监督学习必须)
scoringstr 或 callable指定评估指标(如 accuracy, recall, f1, roc_auc 等)
cvint 或 交叉验证对象交叉验证折数,如 cv=5;或 StratifiedKFold, KFold 等对象
n_jobsint并行执行的任务数:-1 使用所有核心,1 表示不并行
verboseint控制打印的详细程度(0为不输出,越大越详细)
fit_paramsdict要传递给 estimator.fit() 的额外参数(少用)
pre_dispatchstr控制预分发任务数,默认 '2*n_jobs',通常无需改动
error_score‘raise’ 或 float出错时返回分数,或抛异常。默认返回 NaN


实战案例:用交叉验证寻找最优惩罚因子 C

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression  # 导入逻辑回归模型
from sklearn.model_selection import train_test_split, cross_val_score  # 用于数据拆分和交叉验证
from sklearn.preprocessing import StandardScaler  # 用于数据标准化处理
from sklearn import metrics  # 用于模型评估指标计算data = pd.read_csv('creditcard.csv')# 初始化标准化器,对交易金额(Amount)进行标准化处理
scaler = StandardScaler()
data['Amount'] = scaler.fit_transform(data[['Amount']])     # 准备特征数据X(排除时间和目标变量)和目标变量y(欺诈标签,1表示欺诈,0表示正常)
X = data.drop(["Time","Class"], axis=1)
y = data.Class# 将数据拆分为训练集(70%)和测试集(30%),设置随机种子保证结果可复现
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=100)#----------------------------------------------------------------------------------------
# 以下部分用于寻找最优的正则化参数C
c_range = [0.01, 0.1, 1, 10, 100]   # 定义要尝试的正则化参数C的取值范围(C越小,正则化强度越大)
scores = []   # 存储不同C值对应的交叉验证平均召回率
for c in c_range:model = LogisticRegression(C=c, penalty='l2', solver='lbfgs', max_iter=1000)   # 初始化逻辑回归模型,指定正则化参数C、L2正则化、求解器和最大迭代次数score = cross_val_score(model, X_train, y_train, cv=8, scoring='recall')   # 使用8折交叉验证,计算模型在训练集上的召回率recallscore_mean = sum(score) / len(score)   # 计算交叉验证召回率的平均值scores.append(score_mean)   # 将平均召回率添加到列表中print(score_mean)
# 找到最大平均召回率对应的C值,作为最优惩罚因子
best_c = c_range[np.argmax(scores)]   #argmax返回数组中最大值所在的索引位置
print(f'最优惩罚因子为:{best_c}')
#----------------------------------------------------------------------------------------# 使用最优惩罚因子训练最终的逻辑回归模型
model = LogisticRegression(C=best_c, penalty='l2', solver='lbfgs')
model.fit(X_train, y_train)

K-Fold Cross Validation 背后的原理(做了什么)

cross_val_score(model, X, y, cv=8) 等价于以下操作:

  1. 将数据按 8 等份分割

  2. 第一次拿前 7 份训练,第 8 份验证 → 计算指标

  3. 第二次拿 1,2,3,4,5,6,8 训练,第 7 份验证 → 计算指标

  4. ...

  5. 得到 8 个指标结果,返回组成数组

自动完成了分割、训练、预测和评分


常见扩展:StratifiedKFold(保持类别比例)

对于不平衡数据(如欺诈检测),StratifiedKFold 是更合适的选择,它能在每一折中保持正负样本比例一致。

from sklearn.model_selection import StratifiedKFoldskf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
scores = cross_val_score(model, X_train, y_train, cv=skf, scoring='recall')

http://www.dtcms.com/a/311574.html

相关文章:

  • 深入理解C++中的vector容器
  • VS2019安装HoloLens 没有设备选项
  • 大模型(五)MOSS-TTSD学习
  • 二叉树的层次遍历 II
  • 算法: 字符串part02: 151.翻转字符串里的单词 + 右旋字符串 + KMP算法28. 实现 strStr()
  • Redis数据库存储键值对的底层原理
  • 信创应用服务器TongWeb安装教程、前后端分离应用部署全流程
  • Web API安全防护全攻略:防刷、防爬与防泄漏实战方案
  • Dispersive Loss:为生成模型引入表示学习 | 如何分析kaiming新提出的dispersive loss,对扩散模型和aigc会带来什么影响?
  • 二、无摩擦刚体捉取——抗力旋量捉取
  • uniapp 数组的用法
  • 【c#窗体荔枝计算乘法,两数相乘】2022-10-6
  • Python Pandas.from_dummies函数解析与实战教程
  • 【语音技术】什么是动态实体
  • 【解决错误】IDEA启动SpringBoot项目 出现:Command line is too long
  • 5734 孤星
  • process_vm_readv/process_vm_writev 接口详解
  • 如何在 Ubuntu 24.04 或 22.04 LTS Linux 上安装 Guake 终端应用程序
  • Next.js 怎么使用 Chakra UI
  • LINUX82 shell脚本变量分类;系统变量;变量赋值;四则运算;shell
  • 落霞归雁·思维框架
  • 队列的使用【C++】
  • 【王阳明代数讲义】基本名词解释
  • InfluxDB 与 Node.js 框架:Express 集成方案(一)
  • 【RK3568 RTC 驱动开发详解】
  • 操作系统-lecture5(线程)
  • Terraria 服务端部署(Docker)
  • Trae + Notion MCP:将你的Notion数据库升级为智能对话机器人
  • 自动驾驶中的传感器技术14——Camera(5)
  • C#开发入门指南_学习笔记