当前位置: 首页 > news >正文

用Python手写一个能识花的感知器模型——Iris分类实战详解

在这里插入图片描述

描述

感知器是最简单的线性二分类模型。尽管它很基础,但在实际场景中仍有用处,尤其是在:

  1. 教学/入门:帮助学生理解线性分类、梯度思想与决策边界。
  2. 低算力嵌入式设备:当模型必须非常小且推断简单(例如设备上用一个线性规则快速判断两类状态)时,感知器可作为简单筛选器。
  3. 数据快速原型:在需要快速判断某两个易分离类别时(如质量控制中“合格/不合格”的某种材质特征),感知器能快速给出一个基线性能。
  4. 可解释性需求强的应用:感知器决策来自线性组合,容易解释权重对特征的影响。

在本文示例里,我们用萼片长度(sepal length)和花瓣长度(petal length)来区分 setosa 与 versicolor。这两维在前 100 个样本中本身就比较有区分力,所以非常适合作为教学与演示数据集。

题解答案

  1. 载入 sklearn.datasets.load_iris,取前 100 个样本(0–99),对应两类(0:setosa,1:versicolor)。

  2. 从每个样本提取第 1 列(萼片长度)与第 3 列(花瓣长度)作为特征矩阵 (X \in \mathbb{R}^{100 \times 2})。

  3. 将标签 y 中的 0 替换为 -1,1 替换为 +1,适应感知器的符号输出。

  4. 实现一个感知器类(从零写,包含 fitpredictnet_input),记录每一轮(epoch)错误分类的数量。

  5. 训练后:

    • 绘制样本散点图(两类不同标记)和决策边界(可视化)。
    • 绘制训练过程中的错误数随迭代变化图(用以查看是否收敛)。
  6. 给出测试/训练准确率与模型权重解读,并分析复杂度与空间占用。

题解代码

下面给出完整、清晰、注释充足的 Python 代码。运行前请确保已安装 numpy, scikit-learn, matplotlib

# perceptron_iris_demo.py
import numpy as np
from sklearn.datasets import load_iris
import matplotlib.pyplot as pltclass Perceptron:"""简单的感知器实现(批量感知器,逐样本更新)。参数:eta: 学习率 (float)n_iter: 训练轮数 (int)属性:w_: 权重向量(包含偏置),形状 (n_features + 1,)errors_: 每个 epoch 的错误分类数列表"""def __init__(self, eta=0.01, n_iter=10):self.eta = etaself.n_iter = n_iterdef fit(self, X, y):# X: (n_samples, n_features), y: { -1, 1 }n_samples, n_features = X.shape# 初始化权重(一个额外的偏置权重)self.w_ = np.zeros(n_features + 1)self.errors_ = []for epoch in range(self.n_iter):errors = 0for xi, target in zip(X, y):update = self.eta * (target - self.predict(xi))# 如果 update != 0,说明预测有误(或不完全等于目标),按感知器规则更新if update != 0.0:# w_j = w_j + update * x_jself.w_[1:] += update * xi# bias: w_0 = w_0 + updateself.w_[0] += updateerrors += 1self.errors_.append(errors)# 可选:打印每轮的错误数,便于调试# print(f"Epoch {epoch+1}/{self.n_iter}, errors: {errors}")return selfdef net_input(self, X):# 线性组合:w_0 * 1 + sum(w_j * x_j)return np.dot(X, self.w_[1:]) + self.w_[0]def predict(self, X):# 对单个样本或批量进行预测# 返回 +1 或 -1net = self.net_input(X)# 当 X 是单样本时 net 是标量;当是数组时是数组return np.where(net >= 0.0, 1, -1)def main():# 1. 加载数据并预处理iris = load_iris()data = iris.datatarget = iris.target# 只取前100个样本(setosa 和 versicolor),并取第1列和第3列作为特征(索引0和2)X = data[0:100, [0, 2]]  # (100, 2)y = target[0:100]        # 0或1# 将 0 -> -1,1 -> +1y = np.where(y == 0, -1, 1)# 2. 数据可视化(散点图)index_0 = np.where(y == -1)index_1 = np.where(y == 1)plt.figure(figsize=(6, 4))plt.scatter(X[index_0, 0], X[index_0, 1], marker='x', label='setosa (-1)')plt.scatter(X[index_1, 0], X[index_1, 1], marker='o', label='versicolor (+1)')plt.xlabel('萼片长度 (sepal length)')plt.ylabel('花瓣长度 (petal length)')plt.legend(loc='lower right')plt.title('Iris 子集(前100个样本)- 特征散点图')plt.show()# 3. 训练感知器并记录错误数ppn = Perceptron(eta=0.1, n_iter=10)ppn.fit(X, y)# 4. 绘制每轮的错误数量(学习曲线)plt.figure(figsize=(6, 4))plt.plot(range(1, len(ppn.errors_) + 1), ppn.errors_, marker='o')plt.xlabel('训练轮数 (epoch)')plt.ylabel('错误分类数')plt.title('感知器训练过程 - 错误数随迭代的变化')plt.grid(True)plt.show()# 5. 查看训练结果:权重和训练集准确率print("训练得到的权重(包含偏置w_0):", ppn.w_)y_pred = ppn.predict(X)accuracy = np.mean(y_pred == y)print(f"训练集准确率: {accuracy * 100:.2f}%")# 6. 可选:绘制决策边界(二维特征空间)# 创建网格来绘制决策边界x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1xx, yy = np.meshgrid(np.linspace(x_min, x_max, 200),np.linspace(y_min, y_max, 200))grid = np.c_[xx.ravel(), yy.ravel()]Z = ppn.predict(grid)Z = Z.reshape(xx.shape)plt.figure(figsize=(6, 4))plt.contourf(xx, yy, Z, alpha=0.2)plt.scatter(X[index_0, 0], X[index_0, 1], marker='x', label='setosa (-1)')plt.scatter(X[index_1, 0], X[index_1, 1], marker='o', label='versicolor (+1)')plt.xlabel('萼片长度 (sepal length)')plt.ylabel('花瓣长度 (petal length)')plt.legend(loc='lower right')plt.title('感知器决策边界(灰色区域表示分类)')plt.show()if __name__ == "__main__":main()

题解代码分析

数据加载与预处理

iris = load_iris()
data = iris.data
target = iris.target
X = data[0:100, [0, 2]]
y = target[0:100]
y = np.where(y == 0, -1, 1)
  • load_iris() 返回一个包含 data(150×4)和 target (150,) 的字典样对象。
  • 只取前 100 个样本 (0..99),这两个类别是线性可分性较好的组合:setosa 与 versicolor。
  • 选择第 1 列(萼片长度,index=0)与第 3 列(花瓣长度,index=2)作为演示用的二维特征,便于可视化。
  • 将标签从 {0,1} 映射到 {-1, +1},因为经典感知器输出是符号函数(sign)。

感知器类(Perceptron)

class Perceptron:def __init__(self, eta=0.01, n_iter=10):...def fit(self, X, y):...def net_input(self, X):...def predict(self, X):...
  • w_ 含偏置,长度为 n_features + 1w_[0] 是偏置(bias/intercept)。
  • fit 中采用逐样本(online)更新规则:w <- w + eta * (target - predict(x)) * x(对偏置 x0=1 用同样的更新)。
  • 这里 predict 使用 np.where(net >= 0.0, 1, -1),决定阈值为 0。
  • errors_ 保存每轮的错误数,用来绘制学习曲线与判断是否收敛。

可视化:散点图、错误曲线、决策边界

  • 散点图:直观查看两类在所选二维特征空间中的分布。
  • 错误曲线:如果感知器能线性分离数据,错误数通常会下降到 0 并维持;若不能完全线性分离,错误数会在某个水平徘徊。
  • 决策边界:我们用网格对整个特征空间进行预测并绘制等高面(实质是分类分隔面的可视化),这样就能看到模型如何把空间切分为两类。

示例测试及结果

运行方式:将上面的代码保存为 perceptron_iris_demo.py,在支持的 Python 环境下运行:

python perceptron_iris_demo.py

预期输出与解释

  1. 首先弹出一个散点图窗口:可以看到 setosa(x 标记)与 versicolor(o 标记)大体上能用一条线分开(setosa 的花瓣长度普遍更短)。
  2. 然后弹出训练错误数随 epoch 变化的折线图。对于前100个样本(两个类别),在合理的学习率和轮数下,通常会在若干 epoch 后收敛到 0 错误或极低错误数(因为这两个类别在这两个维度上几乎线性可分)。
  3. 控制台会输出类似:
训练得到的权重(包含偏置w_0): [ -3.2   0.8  1.5 ]
训练集准确率: 100.00%

(这里的数字会因随机初始化/实现差异与超参不同而不同,上面只是示意)。
4. 最后会弹出一个决策边界图,可以直观看到分类线把平面分为两块,与散点分布吻合。

示例说明:如果训练集准确率接近 100%,说明感知器在这两个类别与选定特征下很好地学习到了线性分隔边界。若准确率显著低于 100%,可能原因包括学习率过小/过大、epoch 太少或数据并非线性可分(在选取不同特征时尤为常见)。

时间复杂度

设 (n) 为样本数,(d) 为特征维度,(T) 为训练轮数(epochs)。

  • 训练时间复杂度:每个 epoch 中对每个样本做一次预测(内积 (O(d)))并在错误时更新权重(更新也是 (O(d)))。总体为 (O(T \cdot n \cdot d))。

    • 对于本文:(n=100, d=2),所以开销极小,适合交互式演示或嵌入式场景。
  • 预测时间复杂度:对单样本预测是一次内积,复杂度 (O(d))。批量预测 (O(n \cdot d))。

空间复杂度

  • 权重向量存储占用 (O(d))(包含一个偏置)。
  • 不考虑数据本身(若需要在内存中保存数据,则为 (O(n \cdot d)))。
  • 训练过程额外只保存 errors_ 长度为 (T) 的列表,空间开销 (O(T))(通常远小于 (n\cdot d))。

因此总体空间复杂度主要被数据存储支配:(O(n \cdot d))。

小结

  • 本文以通俗且实操的方式展示了如何用 Iris 数据集的前 100 个样本(setosa 与 versicolor)训练一个从零实现的感知器,过程包含数据提取、标签映射、训练、学习曲线和决策边界可视化。
  • 感知器适合线性可分的二分类问题:若数据线性可分,经过若干迭代错误数会降到 0;否则会震荡或停留在某个非零错误数。
  • 实际应用场景很多:教学、轻量级终端判别、快速原型等。若需要更高性能或非线性判别,应该考虑使用支持向量机(SVM)、逻辑回归或神经网络等更强的模型。
  • 最后,代码写得尽量清晰,适合直接运行并修改超参数(学习率 eta、迭代次数 n_iter)或特征选择来观察不同设置的效果。
http://www.dtcms.com/a/531850.html

相关文章:

  • MySQL笔记16
  • gRPC通信流程学习
  • 百度站长平台有哪些功能网站做权重的好处
  • 数据科学复习题2025
  • 牛客网 AI题​(二)机器学习 + 深度学习
  • 拆解AI深度研究:从竞品分析到出海扩张,这是GTM的超级捷径
  • HarmonyOS 环境光传感器自适应:构建智能光线感知应用
  • 护肤品 网站建设策划shopex网站经常出错
  • 机器人描述文件xacro(urdf扩展)
  • AI决策平台怎么选?
  • 当 AI 视觉遇上现代 Web:DeepSeek-OCR 全栈应用深度剖析
  • 紫外工业相机入门介绍和工业检测核心场景
  • 商业求解器和开源求解器哪个更适合企业?
  • 比尤果网做的好的网站深圳网站设计精选刻
  • WPF 控件速查 PDF 笔记(可直接落地版)
  • Selenium+Unittest自动化测试框架
  • 设计模式-命令模式(Command)
  • 设计模式-外观模式(Facade)
  • web自动化测试-selenium_01_元素定位
  • 苏州建设工程信息网站wordpress自动生成tag
  • 学习C#调用OpenXml操作word文档的基本用法(1:读取样式定义)
  • Java-Spring入门指南(二十八)Android界面设计基础
  • Go 语言类型转换
  • 【Windows】goland-2024版安装包
  • 快速入门elasticsearch
  • Linux 多用户服务器限制单用户最大内存使用(systemd user.slice)
  • 食品公司网站设计项目雨蝶直播免费直播
  • SQL 调试不再靠猜:Gudu SQL Omni 让血缘分析一键可视化
  • RV1126 NO.34:OPENCV的交叉编译和项目Makefile讲解
  • FreeRTOS---进阶知识4---通用链表