当前位置: 首页 > news >正文

【机器学习入门】9.2:感知机 Python 实践代码模板(苹果香蕉分类任务适配)

以下代码基于numpymatplotlib手动实现感知机,包含感知机类定义、苹果香蕉分类数据集训练、决策边界可视化,全程贴合 “输入→加权求和→激活输出” 的核心逻辑,亲手体验感知机的学习过程。

图片来源于网络,仅供学习参考

一、环境依赖

确保安装所需 Python 库,未安装则执行以下命令:

pip install numpy matplotlib pandas

二、完整代码实现

1. 导入所需库

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score

2. 构建 “苹果香蕉分类” 数据集

基于文档中的特征定义,构造带标签的训练集与测试集,确保数据贴合 “颜色 + 形状” 的分类逻辑:

# 设置随机种子,保证结果可复现
np.random.seed(42)# 1. 定义特征与标签(参考文档:颜色=1红色/苹果,-1黄色/香蕉;形状=1圆形/苹果,-1弯形/香蕉)
# 训练集:包含苹果(标签1)和香蕉(标签0)样本,添加少量噪声模拟真实数据
train_data = {"颜色(x1)": [1, 1, 1, -1, -1, -1, 1.1, -0.9, 1.05, -1.05],  # 苹果稍正,香蕉稍负,加微小噪声"形状(x2)": [1, 1.1, 0.9, -1, -0.9, -1.1, 1.05, -0.95, 0.95, -1.05],"标签(y)": [1, 1, 1, 0, 0, 0, 1, 0, 1, 0]  # 1=苹果,0=香蕉
}# 测试集:新增未见过的样本,验证模型泛化能力
test_data = {"颜色(x1)": [1.02, -0.98, 0.99, -1.01],"形状(x2)": [0.98, -1.02, 1.01, -0.99],"标签(y)": [1, 0, 1, 0]
}# 2. 转换为numpy数组(便于计算)
# 训练集
X_train = np.array([train_data["颜色(x1)"], train_data["形状(x2)"]]).T  # 形状:(10, 2),10个样本,2个特征
y_train = np.array(train_data["标签(y)"]).reshape(-1, 1)  # 转为列向量,形状:(10, 1)# 测试集
X_test = np.array([test_data["颜色(x1)"], test_data["形状(x2)"]]).T  # 形状:(4, 2)
y_test = np.array(test_data["标签(y)"]).reshape(-1, 1)# 查看数据集基本信息
print("训练集形状(样本数, 特征数):", X_train.shape)
print("训练集标签形状:", y_train.shape)
print("\n训练集前5行:")
print(np.hstack([X_train[:5], y_train[:5]]))  # 合并特征与标签展示
print("\n测试集:")
print(np.hstack([X_test, y_test]))

3. 手动实现感知机类(核心代码)

定义感知机类,包含 “参数初始化”“激活函数”“前向传播”“参数更新(学习过程)”“预测” 核心功能,完全贴合文档中的感知机模型结构:

class Perceptron:def __init__(self, input_dim, learning_rate=0.05, epochs=100):"""初始化感知机参数:input_dim: 输入特征维度(此处为2:颜色+形状)learning_rate: 学习率(控制参数更新幅度,默认0.05)epochs: 训练迭代次数(默认100,可调整)"""# 1. 初始化参数:权重w(input_dim维)、偏置b(1维)# 权重初始化为小随机数,偏置初始化为0(参考文档中的初始设置)self.w = np.random.normal(loc=0, scale=0.1, size=(input_dim, 1))  # 形状:(2, 1)self.b = 0.0  # 偏置,对应文档中的“内部强度b”# 2. 超参数设置self.lr = learning_rateself.epochs = epochsself.loss_history = []  # 记录训练过程的损失,用于可视化self.w_history = [self.w.copy()]  # 记录权重变化,用于后续分析self.b_history = [self.b]  # 记录偏置变化def step_activation(self, v):"""阶跃激活函数(文档中用于苹果香蕉分类的激活函数)参数:v: 净输入(加权求和+偏置)返回:y: 0或1(分类结果)"""return np.where(v >= 0, 1, 0).reshape(-1, 1)  # 向量化计算,避免循环def relu_activation(self, v):"""ReLU激活函数(文档中例题使用的激活函数,备用)"""return np.where(v >= 0, v, 0).reshape(-1, 1)def forward(self, X, activation="step"):"""前向传播:计算预测输出参数:X: 输入数据(形状:(样本数, input_dim))activation: 激活函数类型("step"或"relu",默认"step")返回:y_pred: 预测输出(形状:(样本数, 1))v: 净输入(加权求和+偏置,用于后续分析)"""# 计算净输入v = X·w + b(向量化计算,支持批量样本)v = np.dot(X, self.w) + self.b# 应用激活函数if activation == "step":y_pred = self.step_activation(v)elif activation == "relu":y_pred = self.relu_activation(v)else:raise ValueError("激活函数仅支持'step'或'relu'")return y_pred, vdef compute_loss(self, y_true, y_pred):"""计算损失:二分类任务用“均方误差”(简单直观,适合入门)"""return np.mean((y_true - y_pred) ** 2)def train(self, X_train, y_train, activation="step"):"""感知机训练过程(核心:误差驱动的参数更新)参数:X_train: 训练集特征y_train: 训练集标签activation: 激活函数类型"""print("\n开始训练感知机...")for epoch in range(1, self.epochs + 1):# 1. 前向传播:计算所有样本的预测输出y_pred, _ = self.forward(X_train, activation=activation)# 2. 计算训练损失train_loss = self.compute_loss(y_train, y_pred)self.loss_history.append(train_loss)# 3. 计算误差(真实标签 - 预测输出)error = y_train - y_pred  # 形状:(样本数, 1)# 4. 参数更新(参考文档中的学习逻辑,向量化计算,无需循环)# 权重更新:w = w + lr * X.T · error(利用矩阵乘法批量更新)self.w += self.lr * np.dot(X_train.T, error)# 偏置更新:b = b + lr * error的均值(偏置对所有样本的误差贡献相同)self.b += self.lr * np.mean(error)# 5. 记录参数变化(用于后续分析)self.w_history.append(self.w.copy())self.b_history.append(self.b)# 6. 定期打印训练信息(每10轮打印一次)if epoch % 10 == 0:# 计算训练精度(预测值与真实标签的一致率)train_acc = accuracy_score(y_train.astype(int), y_pred.astype(int))print(f"Epoch {epoch:3d} | 训练损失:{train_loss:.4f} | 训练精度:{train_acc:.4f}")# 训练结束后打印最终参数print(f"\n训练完成!最终参数:")print(f"权重w = [{self.w[0][0]:.4f}, {self.w[1][0]:.4f}](对应颜色、形状特征)")print(f"偏置b = {self.b:.4f}")def predict(self, X, activation="step"):"""模型预测:输入新样本,输出分类结果"""y_pred, _ = self.forward(X, activation=activation)return y_pred.astype(int)  # 转为整数(0或1),便于后续评估

4. 感知机训练与测试(苹果香蕉分类任务)

初始化感知机,使用阶跃激活函数(文档中苹果香蕉分类的指定激活函数),执行训练并验证测试集效果:

# 1. 初始化感知机(输入维度=2:颜色+形状)
perceptron = Perceptron(input_dim=X_train.shape[1],  # input_dim=2learning_rate=0.05,          # 学习率,可调整epochs=100                   # 训练100轮
)# 2. 训练感知机(使用阶跃激活函数,贴合文档案例)
perceptron.train(X_train, y_train, activation="step")# 3. 测试集预测与评估
y_test_pred = perceptron.predict(X_test, activation="step")
test_acc = accuracy_score(y_test.astype(int), y_test_pred)
test_loss = perceptron.compute_loss(y_test, y_test_pred)# 输出测试集结果
print("\n" + "="*50)
print("感知机测试集评估结果(苹果香蕉分类)")
print("="*50)
print(f"测试集损失:{test_loss:.4f}")
print(f"测试集精度:{test_acc:.4f}")
print("\n测试集预测详情:")
for i in range(len(X_test)):true_label = "苹果" if y_test[i][0] == 1 else "香蕉"pred_label = "苹果" if y_test_pred[i][0] == 1 else "香蕉"print(f"样本{i+1}:颜色={X_test[i][0]:.2f}, 形状={X_test[i][1]:.2f} | 真实:{true_label} | 预测:{pred_label}")

5. 结果可视化(训练损失 + 决策边界 + 样本分布)

通过 3 类图表直观展示感知机的学习过程(损失变化)、分类效果(决策边界),帮助理解 “参数如何影响分类结果”:

# 设置中文字体(避免matplotlib中文乱码)
plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False# 创建画布(2行2列,包含4个子图)
fig, axes = plt.subplots(2, 2, figsize=(16, 12))# 1. 训练损失变化曲线(左上子图)
axes[0, 0].plot(range(1, len(perceptron.loss_history)+1), perceptron.loss_history,color='#2E86AB', linewidth=2, label='训练损失')
axes[0, 0].set_xlabel("训练迭代次数(Epoch)", fontsize=12)
axes[0, 0].set_ylabel("均方误差损失", fontsize=12)
axes[0, 0].set_title("感知机训练损失变化曲线", fontsize=14)
axes[0, 0].legend()
axes[0, 0].grid(alpha=0.3)# 2. 训练集样本分布与最终决策边界(右上子图)
# 划分特征范围,用于绘制决策边界
x1_min, x1_max = X_train[:, 0].min() - 0.2, X_train[:, 0].max() + 0.2
x2_min, x2_max = X_train[:, 1].min() - 0.2, X_train[:, 1].max() + 0.2
xx1, xx2 = np.meshgrid(np.linspace(x1_min, x1_max, 100),np.linspace(x2_min, x2_max, 100))
# 生成网格点的预测结果(用于绘制决策边界)
grid_points = np.c_[xx1.ravel(), xx2.ravel()]  # 形状:(10000, 2)
grid_pred, _ = perceptron.forward(grid_points, activation="step")
grid_pred = grid_pred.reshape(xx1.shape)  # 形状:(100, 100)# 绘制决策边界(填充不同类别区域)
axes[0, 1].contourf(xx1, xx2, grid_pred, alpha=0.3, cmap=plt.cm.RdYlBu)
# 绘制训练集样本(苹果=红色圆形,香蕉=蓝色叉号)
apple_mask = y_train[:, 0] == 1
banana_mask = y_train[:, 0] == 0
axes[0, 1].scatter(X_train[apple_mask, 0], X_train[apple_mask, 1], c='red', marker='o', s=80, label='苹果(标签1)', alpha=0.8)
axes[0, 1].scatter(X_train[banana_mask, 0], X_train[banana_mask, 1], c='blue', marker='x', s=80, label='香蕉(标签0)', alpha=0.8)
# 添加决策边界公式(感知机决策边界:w1x1 + w2x2 + b = 0 → x2 = (-w1x1 -b)/w2)
x1_line = np.linspace(x1_min, x1_max, 100)
w1, w2 = perceptron.w[0][0], perceptron.w[1][0]
b = perceptron.b
x2_line = (-w1 * x1_line - b) / w2  # 决策边界的函数表达式
axes[0, 1].plot(x1_line, x2_line, color='black', linewidth=2, label='决策边界')
# 设置标签与标题
axes[0, 1].set_xlabel("颜色特征(x1)(1=红色,-1=黄色)", fontsize=12)
axes[0, 1].set_ylabel("形状特征(x2)(1=圆形,-1=弯形)", fontsize=12)
axes[0, 1].set_title("训练集样本分布与感知机决策边界", fontsize=14)
axes[0, 1].legend()
axes[0, 1].grid(alpha=0.3)# 3. 测试集样本分布与决策边界(左下子图)
# 绘制决策边界(与训练集一致)
axes[1, 0].contourf(xx1, xx2, grid_pred, alpha=0.3, cmap=plt.cm.RdYlBu)
# 绘制测试集样本
test_apple_mask = y_test[:, 0] == 1
test_banana_mask = y_test[:, 0] == 0
axes[1, 0].scatter(X_test[test_apple_mask, 0], X_test[test_apple_mask, 1], c='red', marker='o', s=100, label='苹果(标签1)', edgecolors='black', alpha=0.8)
axes[1, 0].scatter(X_test[test_banana_mask, 0], X_test[test_banana_mask, 1], c='blue', marker='x', s=100, label='香蕉(标签0)', edgecolors='black', alpha=0.8)
# 绘制决策边界
axes[1, 0].plot(x1_line, x2_line, color='black', linewidth=2, label='决策边界')
# 设置标签与标题
axes[1, 0].set_xlabel("颜色特征(x1)(1=红色,-1=黄色)", fontsize=12)
axes[1, 0].set_ylabel("形状特征(x2)(1=圆形,-1=弯形)", fontsize=12)
axes[1, 0].set_title(f"测试集样本分布与决策边界(精度:{test_acc:.4f})", fontsize=14)
axes[1, 0].legend()
axes[1, 0].grid(alpha=0.3)# 4. 权重变化曲线(右下子图)
# 提取权重历史(w1=颜色权重,w2=形状权重)
w1_history = [w[0][0] for w in perceptron.w_history]
w2_history = [w[1][0] for w in perceptron.w_history]
# 绘制权重变化
axes[1, 1].plot(range(len(w1_history)), w1_history, color='#A23B72', linewidth=2, label='颜色权重(w1)')
axes[1, 1].plot(range(len(w2_history)), w2_history, color='#F18F01', linewidth=2, label='形状权重(w2)')
# 标注最终权重
axes[1, 1].scatter(len(w1_history)-1, w1_history[-1], color='#A23B72', s=80, zorder=5)
axes[1, 1].scatter(len(w2_history)-1, w2_history[-1], color='#F18F01', s=80, zorder=5)
axes[1, 1].annotate(f'最终w1={w1_history[-1]:.4f}', xy=(len(w1_history)-1, w1_history[-1]), xytext=(len(w1_history)-20, w1_history[-1]+0.1),arrowprops=dict(arrowstyle='->', color='#A23B72'),fontsize=10)
axes[1, 1].annotate(f'最终w2={w2_history[-1]:.4f}', xy=(len(w2_history)-1, w2_history[-1]), xytext=(len(w2_history)-20, w2_history[-1]-0.1),arrowprops=dict(arrowstyle='->', color='#F18F01'),fontsize=10)
# 设置标签与标题
axes[1, 1].set_xlabel("训练迭代次数(Epoch)", fontsize=12)
axes[1, 1].set_ylabel("权重值", fontsize=12)
axes[1, 1].set_title("感知机权重变化曲线(训练过程)", fontsize=14)
axes[1, 1].legend()
axes[1, 1].grid(alpha=0.3)# 调整子图间距,保存图片(可直接插入CSDN推文)
plt.tight_layout()
plt.savefig("perceptron_apple_banana.png", dpi=300, bbox_inches='tight')
plt.show()

三、代码使用说明与结果解读

1. 代码适配性

  • 超参数调整:若训练损失下降缓慢,可增大learning_rate(如 0.05→0.1);若损失震荡不收敛,可减小学习率或增加epochs(如 100→200);
  • 激活函数切换:若想验证文档中的 ReLU 函数例题,可在trainpredict方法中设置activation="relu",并替换数据集为例题中的输入(X=(-0.8, 0.2, -0.4)w=(1.0, -0.75, 0.25)b=-1)。

2. 关键结果解读

  • 训练损失曲线:随着迭代次数增加,损失应从 0.5 左右逐渐下降至 0(或接近 0),说明感知机在持续学习分类规律;
  • 决策边界:最终决策边界会清晰地将 “苹果(红色圆形)” 和 “香蕉(蓝色叉号)” 分开,符合文档中 “w1x1 + w2x2 + b = 0” 的决策逻辑 —— 比如文档中w1=1, w2=1, b=0时,决策边界是x1 + x2 = 0,与代码训练后的边界趋势一致;
  • 权重变化:训练后 “颜色权重 (w1)” 和 “形状权重 (w2)” 会稳定在 1 左右(与文档初始设置相近),说明感知机自动学习到 “颜色和形状对分类同等重要”;
  • 测试集精度:测试集精度通常能达到 1.0(或接近 1.0),证明感知机对未见过的样本仍能准确分类,泛化能力良好。
http://www.dtcms.com/a/548786.html

相关文章:

  • 大会的网站架构企业网站设计的基本内容包括哪些
  • 打印对称的X。
  • 生产管理系统详解:生产产品,bom,生产线,生产工序,bom清单,生产订单,生产任务单,他们之间的关系梳理
  • 企业微信SCRM系统有什么作用,满足哪些功能?从获客到提效的功能适配逻辑
  • JS如何操作IndexedDB
  • 网站正在维护中wordpress 评分
  • Kafka关闭日志,启动一直打印日志
  • 搬家网站建设思路荆门哪里有专门做企业网站的
  • 前后端分离
  • curl开发常用方法总结
  • rust实战:基础框架Rust + Tauri Windows 桌面应用开发文档
  • knife4j在配置文件(xml文件)的配置错误
  • Java的多线程——多线程(二)
  • 小企业也能用AI?低成本智能转型实战案例
  • ros2 播放 ros1 bag
  • 网页设计做一个网站设计之家官方网站
  • 基于STM32单片机 + DeepSeek-OCR 的智能文档扫描助手设计与实现
  • 微信小程序如何传递参数
  • 【数据结构】:数组及特殊矩阵
  • 记录一下微信小程序里使用SSE
  • API 接口安全:用 JWT + Refresh Token 解决 Token 过期与身份伪造问题
  • 云手机搬砖 高效采集资源
  • GitHub Actions CI/CD 自动化部署完全指南
  • Fastlane 结合 开心上架 命令行版本实现跨平台上传发布 iOS App
  • 广东营销网站建设服务公司军事信息化建设网站
  • Go Web 编程快速入门 14 - 性能优化与最佳实践:Go应用性能分析、内存管理、并发编程最佳实践
  • LeetCode每日一题——合并两个有序链表
  • 丽江市建设局官方网站门户网站开发需要多少钱
  • 边缘计算中评估多模态分类任务的延迟
  • 11.9.16.Filter(过滤器)