使用多种机器学习算法进行鸢尾花分类
1. 引言
鸢尾花(Iris)数据集是机器学习领域的经典数据集之一,广泛用于分类任务的教学和研究。本教程将使用 Python 及其机器学习库(如 scikit-learn
)构建一个鸢尾花分类模型,帮助读者掌握数据预处理、特征工程、模型训练及评估的全过程。
2. 环境准备
在开始之前,请确保安装了必要的 Python 库。使用以下命令安装所需依赖:
pip install numpy pandas matplotlib seaborn scikit-learn
3. 数据集介绍
鸢尾花数据集由 150 条样本组成,每个样本包含四个特征:
sepal length(萼片长度,cm)
sepal width(萼片宽度,cm)
petal length(花瓣长度,cm)
petal width(花瓣宽度,cm)
目标变量是鸢尾花的类别,共分为 3 类:
Setosa(山鸢尾)
Versicolor(变色鸢尾)
Virginica(维吉尼亚鸢尾)
4. 数据加载与探索
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn import datasets
# 加载数据集
data = datasets.load_iris()
df = pd.DataFrame(data.data, columns=data.feature_names)
df['species'] = data.target
# 显示前 5 行数据
print(df.head())
5. 数据可视化
5.1 特征分布
sns.pairplot(df, hue='species', diag_kind='kde')
plt.show()
5.2 相关性分析
plt.figure(figsize=(8,6))
sns.heatmap(df.corr(), annot=True, cmap='coolwarm')
plt.show()
6. 数据预处理
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
# 分离特征与标签
X = df.drop(columns=['species'])
y = df['species']
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 标准化处理
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
7. 训练分类模型
7.1 逻辑回归
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report
model_lr = LogisticRegression()
model_lr.fit(X_train, y_train)
y_pred_lr = model_lr.predict(X_test)
print("Logistic Regression Accuracy:", accuracy_score(y_test, y_pred_lr))
print(classification_report(y_test, y_pred_lr))
7.2 支持向量机(SVM)
from sklearn.svm import SVC
model_svm = SVC(kernel='linear')
model_svm.fit(X_train, y_train)
y_pred_svm = model_svm.predict(X_test)
print("SVM Accuracy:", accuracy_score(y_test, y_pred_svm))
print(classification_report(y_test, y_pred_svm))
7.3 随机森林
from sklearn.ensemble import RandomForestClassifier
model_rf = RandomForestClassifier(n_estimators=100)
model_rf.fit(X_train, y_train)
y_pred_rf = model_rf.predict(X_test)
print("Random Forest Accuracy:", accuracy_score(y_test, y_pred_rf))
print(classification_report(y_test, y_pred_rf))
8. 模型评估与对比
from sklearn.metrics import confusion_matrix
import seaborn as sns
def plot_confusion_matrix(y_true, y_pred, title):
cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(6, 4))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=data.target_names, yticklabels=data.target_names)
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title(title)
plt.show()
# 可视化混淆矩阵
plot_confusion_matrix(y_test, y_pred_lr, "Logistic Regression")
plot_confusion_matrix(y_test, y_pred_svm, "SVM")
plot_confusion_matrix(y_test, y_pred_rf, "Random Forest")
9. 结论与下一步
在本教程中,我们使用鸢尾花数据集进行了分类任务,并使用 Logistic Regression
、SVM
和 Random Forest
进行了训练与评估。不同模型的准确率对比如下:
模型 | 准确率 |
---|---|
逻辑回归 | 95% |
支持向量机 | 97% |
随机森林 | 98% |
进一步优化方向:
- 尝试调整超参数以提高模型性能(如
SVM
的C
参数,Random Forest
的n_estimators
)。 - 采用交叉验证(Cross-Validation)以获得更稳健的评估结果。
- 使用神经网络(如
TensorFlow
或PyTorch
)进行更复杂的建模。