当前位置: 首页 > news >正文

Pytorch实现感知器并实现分类动画

这个实现包含以下关键部分:

  1. 数据生成:使用用户提供的函数生成两类可线性分离的数据点。

  2. 感知机模型

    • 一个线性层接收二维输入并输出一个值
    • 不使用激活函数(原始感知机形式)
    • 使用均方误差损失函数(MSE)和随机梯度下降优化器
  3. 动态可视化

    • 使用 matplotlib 的 FuncAnimation 创建动画
    • 每帧更新显示当前决策边界和损失值
    • 数据点根据真实标签着色(蓝色为 - 1,红色为 1)
    • 绿色线表示当前感知机的决策边界

运行代码后,你将看到一个动画展示感知机如何逐步学习区分两类数据的决策边界。随着训练的进行,决策边界会不断调整,直到能够正确分离两个类别。

 

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation# 数据生成函数(保持与用户提供的一致)
def generate_data():np.random.seed(0)class_1 = np.random.randn(100, 2) + np.array([2, 2])class_2 = np.random.randn(100, 2) + np.array([-2, -2])labels_1 = np.ones((100, 1))labels_2 = -np.ones((100, 1))data = np.vstack((class_1, class_2))labels = np.vstack((labels_1, labels_2))return torch.Tensor(data), torch.Tensor(labels)# 感知机模型
class Perceptron(nn.Module):def __init__(self):super(Perceptron, self).__init__()self.linear = nn.Linear(2, 1)  # 二维输入,一维输出def forward(self, x):return self.linear(x)# 训练和可视化函数
def train_and_visualize():# 生成数据X, y = generate_data()# 创建模型、损失函数和优化器model = Perceptron()criterion = nn.MSELoss()optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# 设置图形fig, ax = plt.subplots(figsize=(10, 8))scatter = ax.scatter(X[:, 0], X[:, 1], c=y.numpy().flatten(), cmap='coolwarm', alpha=0.7)line, = ax.plot([], [], 'g-', lw=2)ax.set_xlim(-6, 6)ax.set_ylim(-6, 6)ax.set_title('Perceptron Classification')# 初始化线def init():line.set_data([], [])return line,# 更新函数def update(frame):# 训练一步optimizer.zero_grad()outputs = model(X)loss = criterion(outputs, y)loss.backward()optimizer.step()# 获取当前权重和偏置w1, w2 = model.linear.weight.data[0]b = model.linear.bias.data[0]# 计算决策边界x_vals = np.linspace(-6, 6, 100)y_vals = -(w1 * x_vals + b) / w2# 更新线line.set_data(x_vals, y_vals)ax.set_title(f'Perceptron Classification (Epoch {frame + 1}, Loss: {loss.item():.4f})')return line,# 创建动画ani = FuncAnimation(fig, update, frames=100, init_func=init, blit=True, interval=200)plt.show()return ani# 运行训练和可视化
if __name__ == "__main__":animation = train_and_visualize()

http://www.dtcms.com/a/277125.html

相关文章:

  • Vivado ILA抓DDR信号(各种IO信号:差分、ISERDES、IOBUFDS等)
  • MacOS使用Multipass快速搭建轻量级k3s集群
  • 在Intel Mac的PyCharm中设置‘add bin folder to the path‘的解决方案
  • COZE token刷新
  • mac上BRPC的CMakeLists.txt优化:解决Protobuf路径问题
  • composer如何安装以及举例在PHP项目中使用Composer安装TCPDF库-优雅草卓伊凡
  • 数据结构1:线性表的顺序存储的定义以及基本操作
  • [Linux 入门] Linux 引导过程、系统管理与故障处理全解析
  • Python 数据建模与分析项目实战预备 Day 4 - EDA(探索性数据分析)与可视化
  • ansible自动化部署考试系统前后端分离项目
  • 09.获取 Python 列表的首尾元素与切片技巧
  • 论文Review 3DGSSLAM GauS-SLAM: Dense RGB-D SLAM with Gaussian Surfels
  • OkHttp SSE 完整总结(最终版)
  • JAVA学习笔记 首个HelloWorld程序-002
  • javaweb-day10案例
  • Linux 系统——管理 MySQL
  • 入职华为od一个月的感受
  • 2025年渗透测试面试题总结-2025年HW(护网面试) 44(题目+回答)
  • 鸿蒙项目构建配置
  • TDengine 使用最佳实践(2)
  • SpringBoot-23-企业云端开发实践之Vue框架组件化开发和第三方组件element-ui
  • 谷歌推出Vertex AI Memory Bank:为AI智能体带来持久记忆,支持连续对话
  • 【源力觉醒 创作者计划】文心开源大模型ERNIE-4.5私有化部署保姆级教程与多功能界面窗口部署
  • zotero自由编辑参考文献格式(2)
  • Dubbo + Spring Boot + Zookeeper 快速搭建分布式服务
  • spring--xml注入时bean的property属性
  • 20250713-`Seaborn.pairplot` 的使用注意事项
  • jenkins部署前端vue项目使用Docker+Jenkinsfile方式
  • 【PTA数据结构 | C语言版】字符串插入操作
  • java.net.InetAddress