Iris Classification (KNN)
1. Load the dataset
iris = load_iris()
X = iris.data                        # feature matrix (150 samples x 4 features)
y = iris.target                      # target labels (3 classes)
feature_names = iris.feature_names   # feature names
target_names = iris.target_names     # class names

# Convert to a DataFrame for easier analysis
df = pd.DataFrame(X, columns=feature_names)
df['species'] = y
df['species'] = df['species'].map({0: 'setosa', 1: 'versicolor', 2: 'virginica'})
2. Data exploration
print("数据集维度:", X.shape)
print("\n特征示例:")
print(df.head())
print("\n统计摘要:")
print(df.describe())
print("\n类别分布:")
print(df['species'].value_counts())
Output:
Dataset shape: (150, 4)

Sample rows:
   sepal length (cm)  sepal width (cm)  ...  petal width (cm)  species
0                5.1               3.5  ...               0.2   setosa
1                4.9               3.0  ...               0.2   setosa
2                4.7               3.2  ...               0.2   setosa
3                4.6               3.1  ...               0.2   setosa
4                5.0               3.6  ...               0.2   setosa

[5 rows x 5 columns]

Summary statistics:
       sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)
count         150.000000        150.000000         150.000000        150.000000
mean            5.843333          3.057333           3.758000          1.199333
std             0.828066          0.435866           1.765298          0.762238
min             4.300000          2.000000           1.000000          0.100000
25%             5.100000          2.800000           1.600000          0.300000
50%             5.800000          3.000000           4.350000          1.300000
75%             6.400000          3.300000           5.100000          1.800000
max             7.900000          4.400000           6.900000          2.500000

Class distribution:
species
setosa        50
versicolor    50
virginica     50
Name: count, dtype: int64
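The classes are perfectly balanced (50 samples per species). To see how the features differ between species, a small optional sketch using pandas groupby (not part of the original walkthrough):

# Mean of each feature per species; the petal measurements separate the classes most clearly
print(df.groupby('species')[feature_names].mean().round(2))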
3. Data visualization
# Visualization setup
plt.figure(figsize=(12, 8))

# Optional: configure a CJK-capable font via rcParams so non-ASCII text in figures renders correctly
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
3.1 Feature distribution histograms
# Histogram of each feature, split by species
for i, feature in enumerate(feature_names):
    plt.subplot(2, 2, i + 1)
    sns.histplot(data=df, x=feature, hue='species', kde=True)
    plt.title(f'{feature} distribution')
plt.tight_layout()
plt.show()
Output: a 2x2 grid of histograms, one per feature, colored by species (figure).
3.2 Pairwise feature scatter plots
# Pairwise scatter plot matrix of the features
sns.pairplot(df, hue='species', palette='viridis')
plt.suptitle('Pairwise feature relationships', y=1.02)
plt.show()
Output: pairwise scatter plot matrix colored by species (figure).
4. Data preprocessing
# Split into training and test sets (80% train, 20% test), stratified by class
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Standardize the features (fit the scaler on the training set only)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
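KNN is distance-based, so features on larger scales would otherwise dominate the distance computation. A quick sanity check of the scaling (a minimal sketch using numpy, not part of the original code):

import numpy as np

# After StandardScaler, each training feature has mean ~0 and std ~1; the test set
# reuses the training set's statistics, so it is only approximately standardized.
print("train means:", X_train.mean(axis=0).round(3))   # ~ [0, 0, 0, 0]
print("train stds: ", X_train.std(axis=0).round(3))    # ~ [1, 1, 1, 1]
print("test means: ", X_test.mean(axis=0).round(3))    # close to, but not exactly, 0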
5. Model training
Using the K-nearest-neighbors (KNN) algorithm: the class (or value) of a new sample is decided by the classes (or values) of its most similar "neighbors".
# Train a K-nearest-neighbors classifier
knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(X_train, y_train)
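To make the "neighbors decide" idea concrete, here is a minimal sketch of what knn.predict does for a single sample, assuming Euclidean distance and simple majority voting (the defaults used above); it is not part of the original code and reuses the standardized arrays from section 4:

import numpy as np
from collections import Counter

def knn_predict_one(x, X_train, y_train, k=3):
    # Euclidean distance from x to every training sample
    dists = np.linalg.norm(X_train - x, axis=1)
    # Indices of the k closest training samples
    nearest = np.argsort(dists)[:k]
    # Majority vote among the neighbors' labels
    return Counter(y_train[nearest]).most_common(1)[0][0]

# Should agree with knn.predict on the same (already standardized) sample, up to tie-breaking
print(knn_predict_one(X_test[0], X_train, y_train, k=3), knn.predict(X_test[:1])[0])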
6. Model evaluation
y_pred = knn.predict(X_test)
print("\nTest-set accuracy: {:.2f}%".format(accuracy_score(y_test, y_pred) * 100))
print("\nClassification report:")
print(classification_report(y_test, y_pred, target_names=target_names))

# Confusion matrix
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=target_names, yticklabels=target_names)
plt.xlabel('Predicted label')
plt.ylabel('True label')
plt.title('Confusion matrix')
plt.show()
Output:
Test-set accuracy: 93.33%

Classification report:
              precision    recall  f1-score   support

      setosa       1.00      1.00      1.00        10
  versicolor       0.83      1.00      0.91        10
   virginica       1.00      0.80      0.89        10

    accuracy                           0.93        30
   macro avg       0.94      0.93      0.93        30
weighted avg       0.94      0.93      0.93        30

(figure: confusion matrix heatmap)
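The value k=3 was chosen by hand. A common way to check it is a small cross-validated sweep over k; a sketch assuming sklearn's cross_val_score (this sweep is not part of the original walkthrough):

from sklearn.model_selection import cross_val_score

# 5-fold cross-validation accuracy on the standardized training set for several k
for k in [1, 3, 5, 7, 9, 11]:
    scores = cross_val_score(KNeighborsClassifier(n_neighbors=k),
                             X_train, y_train, cv=5, scoring='accuracy')
    print(f"k={k:2d}  mean CV accuracy = {scores.mean():.3f}")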
7. Comparison with a decision tree
dt = DecisionTreeClassifier(max_depth=3, random_state=42)
dt.fit(X_train, y_train)
dt_pred = dt.predict(X_test)
print("\nDecision tree accuracy: {:.2f}%".format(accuracy_score(y_test, dt_pred) * 100))

# Visualize the decision tree
plt.figure(figsize=(15, 10))
plot_tree(dt, feature_names=feature_names, class_names=target_names, filled=True)
plt.title('Iris classification decision tree')
plt.show()
Output:

Decision tree accuracy: 96.67%

(figure: the fitted decision tree)
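For reference, the same tree can also be printed as plain-text rules; a small sketch assuming sklearn.tree.export_text (not part of the original code):

from sklearn.tree import export_text

# Text form of the fitted tree; thresholds are on the standardized feature scale
print(export_text(dt, feature_names=feature_names))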
8. Predicting a new sample
new_sample = np.array([[5.1, 3.5, 1.4, 0.2]])     # new sample (raw feature values)
new_sample_scaled = scaler.transform(new_sample)  # apply the same standardization
prediction = knn.predict(new_sample_scaled)
print("\nPredicted class for the new sample:", target_names[prediction][0])
Output:

Predicted class for the new sample: setosa
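KNeighborsClassifier can also report the fraction of the k neighbors voting for each class, which serves as a rough confidence score; a minimal sketch (not part of the original code):

# Per-class vote fractions among the 3 nearest neighbors (each row sums to 1)
proba = knn.predict_proba(new_sample_scaled)
for name, p in zip(target_names, proba[0]):
    print(f"{name}: {p:.2f}")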
9. Complete code
# Import the required libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.tree import DecisionTreeClassifier, plot_tree

# Load the dataset
iris = load_iris()
X = iris.data                        # feature matrix (150 samples x 4 features)
y = iris.target                      # target labels (3 classes)
feature_names = iris.feature_names   # feature names
target_names = iris.target_names     # class names

# Convert to a DataFrame for easier analysis
df = pd.DataFrame(X, columns=feature_names)
df['species'] = y
df['species'] = df['species'].map({0: 'setosa', 1: 'versicolor', 2: 'virginica'})

# Data exploration
print("Dataset shape:", X.shape)
print("\nSample rows:")
print(df.head())
print("\nSummary statistics:")
print(df.describe())
print("\nClass distribution:")
print(df['species'].value_counts())

# Visualization setup
plt.figure(figsize=(12, 8))
# Optional: configure a CJK-capable font so non-ASCII text in figures renders correctly
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

# Histogram of each feature, split by species
for i, feature in enumerate(feature_names):
    plt.subplot(2, 2, i + 1)
    sns.histplot(data=df, x=feature, hue='species', kde=True)
    plt.title(f'{feature} distribution')
plt.tight_layout()
plt.show()

# Pairwise scatter plot matrix of the features
sns.pairplot(df, hue='species', palette='viridis')
plt.suptitle('Pairwise feature relationships', y=1.02)
plt.show()

# Data preprocessing
# Split into training and test sets (80% train, 20% test), stratified by class
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Standardize the features (fit the scaler on the training set only)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# Model training (K-nearest-neighbors)
knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(X_train, y_train)

# Model evaluation
y_pred = knn.predict(X_test)
print("\nTest-set accuracy: {:.2f}%".format(accuracy_score(y_test, y_pred) * 100))
print("\nClassification report:")
print(classification_report(y_test, y_pred, target_names=target_names))

# Confusion matrix
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=target_names, yticklabels=target_names)
plt.xlabel('Predicted label')
plt.ylabel('True label')
plt.title('Confusion matrix')
plt.show()

# Decision tree model for comparison (optional)
dt = DecisionTreeClassifier(max_depth=3, random_state=42)
dt.fit(X_train, y_train)
dt_pred = dt.predict(X_test)
print("\nDecision tree accuracy: {:.2f}%".format(accuracy_score(y_test, dt_pred) * 100))

# Visualize the decision tree
plt.figure(figsize=(15, 10))
plot_tree(dt, feature_names=feature_names, class_names=target_names, filled=True)
plt.title('Iris classification decision tree')
plt.show()

# Predict a new sample (example)
new_sample = np.array([[5.1, 3.5, 1.4, 0.2]])     # new sample (raw feature values)
new_sample_scaled = scaler.transform(new_sample)  # apply the same standardization
prediction = knn.predict(new_sample_scaled)
print("\nPredicted class for the new sample:", target_names[prediction][0])