当前位置: 首页 > news >正文

逻辑回归的多分类实战:以鸢尾花数据集为例

文章目录

  • 引言:从二分类到多分类
  • 一、多分类问题无处不在
  • 二、One-vs-All策略揭秘
    • 1. 核心思想
    • 2. 数学表达
  • 三、鸢尾花分类完整实现
    • 1. 环境准备
    • 2. 数据加载与探索
    • 3. 数据预处理
    • 4. 模型训练与评估
    • 5. 决策边界可视化
  • 四、关键参数解析
  • 五、总结

引言:从二分类到多分类

  • 逻辑回归是机器学习中最基础也最重要的算法之一,但初学者常常困惑:逻辑回归明明是二分类算法,如何能处理多分类问题呢?本文将带你深入了解逻辑回归的多分类策略,并通过完整的鸢尾花分类代码实现。

一、多分类问题无处不在

  • 在我们的日常生活和工作中,多分类问题比比皆是:
  • 邮件分类:工作(y=1)、朋友(y=2)、家庭(y=3)、爱好(y=4)
  • 天气预测:晴天(y=1)、多云(y=2)、雨天(y=3)、雪天(y=4)
  • 医疗诊断:健康(y=1)、感冒(y=2)、流感(y=3)

这些场景都需要算法能够区分多个类别,而逻辑回归通过巧妙的扩展就能胜任这些任务。

二、One-vs-All策略揭秘

1. 核心思想

  • One-vs-All(一对多,也称为One-vs-Rest)策略将多分类问题转化为多个二分类问题:
  1. 对于N个类别,训练N个独立的二分类器
  2. 第i个分类器将第i类作为正类,其余所有类别作为负类
  3. 预测时,选择所有分类器中预测概率最高的类别
    在这里插入图片描述
    在这里插入图片描述

2. 数学表达

对于第i类,我们的假设函数为:
h θ ( i ) ( x ) = P ( y = i ∣ x ; θ ) h_\theta^{(i)}(x) = P(y = i|x;\theta) hθ(i)(x)=P(y=ix;θ)

预测时选择:
max ⁡ i h θ ( i ) ( x ) \max_i h_\theta^{(i)}(x) imaxhθ(i)(x)

三、鸢尾花分类完整实现

  • 使用Python和scikit-learn库完整实现鸢尾花的多分类任务。

1. 环境准备

# 导入必要的库
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import (accuracy_score,confusion_matrix,classification_report)
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import StandardScaler

2. 数据加载与探索

# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei']  # 或 'Microsoft YaHei'
plt.rcParams['axes.unicode_minus'] = False# 加载鸢尾花数据集
iris = load_iris()
# 特征名称: ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
# sepal length:花萼长度 sepal width:花萼宽度 petal length: 花瓣长度 petal width: 花瓣宽度
X = iris.data  # 特征矩阵 (150, 4)
# 目标类别: ['setosa' 'versicolor' 'virginica']
# setosa 山鸢尾 versicolor 变色鸢尾 virginica 维吉尼亚鸢尾
y = iris.target  # 标签 (150,)# 查看特征名称和目标类别
print("特征名称:", iris.feature_names)
print("目标类别:", iris.target_names)# 将数据转换为DataFrame便于可视化
iris_df = pd.DataFrame(X, columns=iris.feature_names)
iris_df['species'] = y
iris_df['species'] = iris_df['species'].map({0: 'setosa', 1: 'versicolor', 2: 'virginica'})# 绘制特征分布图
sns.pairplot(iris_df, hue='species', palette='husl')
plt.suptitle("鸢尾花特征分布", y=1.02)
plt.show()

在这里插入图片描述

特征名称: ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
目标类别: ['setosa' 'versicolor' 'virginica']

3. 数据预处理

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)print(f"训练集样本数: {len(X_train)}")
print(f"测试集样本数: {len(X_test)}")# 特征标准化(逻辑回归通常需要)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
训练集样本数: 120
测试集样本数: 30

4. 模型训练与评估

# 构建逻辑回归模型
log_reg = LogisticRegression(C=1000,                # 正则化强度的倒数solver='sag',          # 随机平均梯度下降max_iter=1000,         # 最大迭代次数random_state=42
)
# 使用 OneVsRestClassifier 包装
ovr_classifier = OneVsRestClassifier(log_reg)# 训练模型
ovr_classifier.fit(X_train, y_train)# 在训练集和测试集上评估
train_acc = ovr_classifier.score(X_train, y_train)
test_acc = ovr_classifier.score(X_test, y_test)print(f"训练集准确率: {train_acc:.2%}")
print(f"测试集准确率: {test_acc:.2%}")# 更详细的评估报告
y_pred = ovr_classifier.predict(X_test)
print("\n分类报告:")
print(classification_report(y_test, y_pred, target_names=iris.target_names))# 绘制混淆矩阵
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(6,6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',xticklabels=iris.target_names,yticklabels=iris.target_names)
plt.xlabel('预测标签')
plt.ylabel('真实标签')
plt.title('混淆矩阵')
plt.show()

在这里插入图片描述

训练集准确率: 96.67%
测试集准确率: 96.67%分类报告:precision    recall  f1-score   supportsetosa       1.00      1.00      1.00        10versicolor       1.00      0.90      0.95        10virginica       0.91      1.00      0.95        10accuracy                           0.97        30macro avg       0.97      0.97      0.97        30
weighted avg       0.97      0.97      0.97        30

5. 决策边界可视化

# 为可视化,只使用两个主要特征
X_train_2d = X_train[:, :2]
X_test_2d = X_test[:, :2]# 重新训练一个2D模型
log_reg_2d = LogisticRegression(C=1000,solver='sag',max_iter=2000,random_state=42
)ovr_classifier_2d = OneVsRestClassifier(log_reg_2d)
ovr_classifier_2d.fit(X_train_2d, y_train)  # 必须先训练模型# 创建网格点
x_min, x_max = X_train_2d[:, 0].min() - 1, X_train_2d[:, 0].max() + 1
y_min, y_max = X_train_2d[:, 1].min() - 1, X_train_2d[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02), np.arange(y_min, y_max, 0.02))# 预测每个网格点的类别
Z = ovr_classifier_2d.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)# 预测每个网格点的类别
Z = ovr_classifier_2d.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)# 绘制决策边界
plt.figure(figsize=(10, 6))
plt.contourf(xx, yy, Z, alpha=0.4, cmap='Pastel2')
scatter = plt.scatter(X_train_2d[:, 0], X_train_2d[:, 1], c=y_train,cmap='Dark2', edgecolor='black')# 添加图例和标签
legend_elements = scatter.legend_elements()[0]
plt.legend(legend_elements,iris.target_names,title="鸢尾花种类")
plt.xlabel(iris.feature_names[0])
plt.ylabel(iris.feature_names[1])
plt.title("逻辑回归多分类决策边界")
plt.show()

在这里插入图片描述

四、关键参数解析

在构建逻辑回归模型时的重要参数:

  1. C=1000:正则化强度的倒数,较小的值表示更强的正则化。这里设为较大的值,相当于减少正则化。
  2. multi_class=‘ovr’:指定使用One-vs-Rest策略处理多分类问题。scikit-learn还支持’multinomial’选项,使用softmax函数直接进行多分类。
  3. solver=‘sag’:优化算法选择随机平均梯度下降(Stochastic Average Gradient),适合大数据集。其他可选算法包括:‘liblinear’:适合小数据集;‘newton-cg’:牛顿法;‘lbfgs’:拟牛顿法。
  4. max_iter=1000:最大迭代次数,确保模型能够收敛。

五、总结

  1. 逻辑回归如何通过One-vs-All策略处理多分类问题
  2. 完整的鸢尾花分类实现流程
  3. 模型评估与可视化方法

相关文章:

  • 【源码+文档+调试讲解】儿童图书推荐系统81
  • 论文笔记(八十三)STACKGEN: Generating Stable Structures from Silhouettes via Diffusion
  • C++负载均衡远程调用学习之QPS性能测试
  • 个人健康中枢的多元化AI软件革新与精准健康路径探析
  • 同城跑腿小程序帮取帮送接单抢单预约取件智能派单同城配送全开源运营版源码优创
  • 2000-2022年上市公司数字经济专利申请数据
  • 组件通信-mitt
  • 【云备份】配置文件加载模块
  • 中小企业MES系统需求文档
  • 创新老年综合评估实训室建设方案,规范评估流程
  • JSON与字典的区别及示例
  • (六——下)RestAPI 毛子(Http resilience/Refit/游标分页)
  • Linux52 运行百度网盘 解决故障无法访问repo nosandbox 未解决:疑似libstdc++版本低导致无法运行baidu网盘
  • Arduino逻辑控制详细解答,一点自己的想法记录
  • Shell 脚本基础
  • 文献阅读篇#7:5月一区好文阅读,BFA-YOLO,用于建筑信息建模!(下)
  • 【记录】新Ubuntu20配置voxelmap的环境安装
  • w317汽车维修预约服务系统设计与实现
  • ThreadLocal理解
  • SALOME源码分析: 命令系统
  • 华尔兹转岗与鲁比奥集权:特朗普政府人事震荡背后的深层危机
  • 用小型核反应堆给数据中心供电,国内企业正在开展项目论证
  • 五大白酒去年净利超1500亿元:贵州茅台862亿领跑,洋河营收净利齐降
  • 苹果手机为何无法在美制造?全球供应链难迁移
  • 周劼已任中国航天科技集团有限公司董事、总经理、党组副书记
  • 为治理商家“卷款跑路”“退卡难”,预付式消费司法解释5月起实施