当前位置: 首页 > news >正文

[torch] xor 分类问题训练

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt# 设置随机种子以确保结果的可重复性
torch.manual_seed(1)# 生成正确的XOR数据集x_data= np.float32(np.random.rand(500,2))
y_data = np.float32( np.sign(x_data[:,0]-0.5)* np.sign(x_data[:,1]-0.5) +1)/2# 定义模型,一个具有1个隐藏层的多层感知器[6,7](@ref)
class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.hidden1 = nn.Linear(in_features=2, out_features=6)  # 隐藏层4个神经元self.hidden2 = nn.Linear(in_features=6, out_features=4)  # 隐藏层4个神经元self.output = nn.Linear(in_features=4, out_features=1)   # 输出层def forward(self, x):x = torch.tanh(self.hidden1(x))  # 使用tanh激活函数x = torch.tanh(self.hidden2(x))  # 使用tanh激活函数x = self.output(x)return xmodel = MyModel()# 定义损失函数 - 使用二元交叉熵[4,8](@ref)
loss_fn = nn.BCEWithLogitsLoss()  # 包含sigmoid和交叉熵计算# 设置优化器
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # 学习率可以适当调高# 准备训练数据
x_tensor = torch.from_numpy(x_data)
y_tensor = torch.from_numpy(y_data)# 训练模型
epochs = 5000  # XOR问题需要更多迭代次数
losses = []predictions = torch.sigmoid(model(x_tensor))
pre_arr=predictions[:,0]
pre_arr=pre_arr.detach().numpy()plt.subplot(131)
plt.scatter(x_data[:,0],x_data[:,1],c=pre_arr,cmap='viridis',label='Untrained')
plt.legend()for epoch in range(1, epochs + 1):# 前向传播y_pred = model(x_tensor)loss = loss_fn(y_pred.reshape(500), y_tensor)# 反向传播和优化optimizer.zero_grad()  # 清空梯度loss.backward()        # 反向传播optimizer.step()       # 更新参数losses.append(loss.item())if epoch % 500 == 0:print(f'Epoch {epoch}: Loss = {loss.item()}')predictions = torch.sigmoid(model(x_tensor))
pre_arr=predictions[:,0]
pre_arr=pre_arr.detach().numpy()plt.subplot(132)
plt.scatter(x_data[:,0],x_data[:,1],c=pre_arr,cmap='viridis',label='Trained')
plt.legend()
plt.subplot(133)
plt.scatter(x_data[:,0],x_data[:,1],c=y_data,cmap='viridis',label='Real Data')
plt.legend()
plt.show()

在这里插入图片描述


文章转载自:

http://4zqBv82O.zsrdp.cn
http://D9qGZyaJ.zsrdp.cn
http://PgOAYYHm.zsrdp.cn
http://VAaFFpaL.zsrdp.cn
http://SScGgJ78.zsrdp.cn
http://qvE26ftY.zsrdp.cn
http://x0Bxymm2.zsrdp.cn
http://FDRrMyDw.zsrdp.cn
http://PMlWbe2N.zsrdp.cn
http://wwCLdRuD.zsrdp.cn
http://J361ksbX.zsrdp.cn
http://7s3EXt5o.zsrdp.cn
http://rYTNuWiC.zsrdp.cn
http://0ZTmpbMD.zsrdp.cn
http://MIcA8qD0.zsrdp.cn
http://9TBxoirG.zsrdp.cn
http://Woqv0eFm.zsrdp.cn
http://nCiDzWYJ.zsrdp.cn
http://xX31qAHT.zsrdp.cn
http://kv8CddkA.zsrdp.cn
http://BYmTIj1m.zsrdp.cn
http://j9CdG9OT.zsrdp.cn
http://6FCWqEPB.zsrdp.cn
http://I7UHa4UC.zsrdp.cn
http://DqhuRChe.zsrdp.cn
http://sqA13EgB.zsrdp.cn
http://haJKiZnG.zsrdp.cn
http://z0hRr6x9.zsrdp.cn
http://7WMb0THT.zsrdp.cn
http://zCiscvel.zsrdp.cn
http://www.dtcms.com/a/387702.html

相关文章:

  • React学习教程,从入门到精通,React 表单完整语法知识点与使用方法(22)
  • ref、reactive和computed的用法
  • Redis哈希类型:高效存储与操作指南
  • MySQL 日志:undo log、redo log、binlog以及MVCC的介绍
  • 棉花、玉米、枸杞、瓜类作物分类提取
  • Python测试框架之pytest详解
  • qt QHPieModelMapper详解
  • MAC Typora 1.8.10无法打开多个md档
  • 零碳园区的 “追光者”:三轴光伏太阳花的技术创新与应用逻辑
  • MAC-Java枚举工具类实现
  • 「数据获取」全国村级点状矢量数据
  • Chromium 138 编译指南 macOS 篇:源代码获取(四)
  • 人工智能概念:NLP任务的评估指标(BLEU、ROUGE、PPL、BERTScore、RAGAS)
  • 机器学习基础:从线性回归到多分类实战
  • 深度学习基础:线性回归与 Softmax 回归全解析,从回归到分类的桥梁
  • Scikit-learn Python机器学习 - 分类算法 - 决策树
  • 【人工智能agent】--dify实现文找图、图找文、图找图
  • 基于 Landsat-8 数据的甘肃省金塔县主要农作物分类
  • 社区补丁的回复及常用链接
  • Pyside6 + QML - 信号与槽01 - Button 触发 Python 类方法
  • 视频理解学习笔记
  • Android Studio 将SVG资源转换成生成xml图
  • 后台管理系统详解:通用的系统架构介绍与说明
  • r-DMT市场报告:深度解析全球研究现状与未来发展趋势
  • 企业网络里的API安全防护指南
  • 了解学习DNS服务管理
  • Pycharm安装步骤
  • 分布式k8s集群管理是如何用karmada进行注册的?
  • FreeRTOS 任务调度与管理
  • CI/CD 实战:GitHub Actions 自动化部署 Spring Boot 项目