当前位置: 首页 > news >正文

[torch] 非线性拟合问题的训练

利用 torch 建立神经网络,模拟有限个数据的非线性拟合

本文仍然考虑 f(x)=sin⁡(x)xf(x)=\frac{\sin(x)}{x}f(x)=xsin(x) 函数在区间 [-10,10] 上固定数据的拟合。

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt# 设置随机种子以确保结果的可重复性
torch.manual_seed(1)# 生成数据集
x_data = (np.random.rand(500) * 20 - 10).astype('float32')  # 生成500个随机x值,范围在-10到10之间
y_data = np.sin(x_data) / x_data  # 生成y值
y_data = y_data.reshape(-1, 1)  # 将y_data转换为二维数组# 定义模型,一个具有2个隐藏层的多层感知器
class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.hidden1 = nn.Linear(in_features=1, out_features=50)self.bn = nn.BatchNorm1d(num_features=50)self.hidden2 = nn.Linear(in_features=50, out_features=1)def forward(self, x):x = torch.tanh(self.hidden1(x))x = self.bn(x)x = self.hidden2(x)return xmodel = MyModel()# 定义损失函数
loss_fn = nn.MSELoss()# 设置优化器
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)# 准备训练数据 (PyTorch通常使用DataLoader, 但简单回归也可以直接使用Tensor)
x_tensor = torch.from_numpy(x_data).unsqueeze(-1) # 转换为Tensor并增加特征维度
y_tensor = torch.from_numpy(y_data)# 训练模型
epochs = 1000
for epoch in range(1, epochs + 1):# 前向传播y_pred = model(x_tensor)loss = loss_fn(y_pred, y_tensor)# 反向传播和优化optimizer.zero_grad() # 清空梯度loss.backward()       # 反向传播optimizer.step()      # 更新参数if epoch % 100 == 0:print(f'Epoch {epoch}: Loss = {loss.item()}')# 使用训练好的模型进行预测
with torch.no_grad(): # 关闭梯度计算y_pred_np = model(x_tensor).numpy()# 可视化结果
plt.scatter(x_data, y_data, label='True', alpha=0.5)
x,index=torch.sort(torch.as_tensor(x_data))
plt.plot(x, y_pred_np[index],'r', label='Predicted')
plt.legend()
plt.title('Fitting of sinc function')
plt.xlabel('x')
plt.ylabel('y')
plt.show()

在这里插入图片描述


文章转载自:

http://Gq2BRijS.qkqpy.cn
http://L9pOUsNO.qkqpy.cn
http://FhDTTxnM.qkqpy.cn
http://gORn8aMY.qkqpy.cn
http://Im4HNROY.qkqpy.cn
http://owCS9nso.qkqpy.cn
http://qNvTA3TP.qkqpy.cn
http://Uj6d7dGf.qkqpy.cn
http://hsLJh9vv.qkqpy.cn
http://opV2FcSM.qkqpy.cn
http://668uH1mV.qkqpy.cn
http://Dv3vNEPE.qkqpy.cn
http://90KZPb2x.qkqpy.cn
http://GsrV1CO2.qkqpy.cn
http://hSKVzubl.qkqpy.cn
http://85M1SS07.qkqpy.cn
http://PRWI8pRV.qkqpy.cn
http://QDK3H2eO.qkqpy.cn
http://G87s9eTS.qkqpy.cn
http://YUEDkFlW.qkqpy.cn
http://kTpVEKXw.qkqpy.cn
http://kJJsAciB.qkqpy.cn
http://iYngYRfO.qkqpy.cn
http://tRKxjkKI.qkqpy.cn
http://fTt3o29T.qkqpy.cn
http://gaIgXiV7.qkqpy.cn
http://TGMhlL10.qkqpy.cn
http://e57yXcy7.qkqpy.cn
http://5SYDmmbV.qkqpy.cn
http://gzQYFKE1.qkqpy.cn
http://www.dtcms.com/a/388067.html

相关文章:

  • ubuntu设置ip流程
  • 【论文阅读】谷歌:生成式数据优化,只需请求更好的数据
  • 【深度学习】什么是过拟合,什么是欠拟合?遇到的时候该如何解决该问题?
  • CSA AICM 国际标准:安全、负责任地开发、部署、管理和使用AI技术
  • AI 赋能教育:个性化学习路径设计、教师角色转型与教育公平新机遇
  • 科技为老,服务至心——七彩喜智慧养老的温情答卷
  • ​​[硬件电路-237]:电阻、电容、电感虽均能阻碍电流流动,但它们在阻碍机制、能量转换、相位特性及频率响应方面存在显著差异
  • 内网Windows系统离线安装Git详细步骤
  • @Component 与 @Bean 核心区别
  • Rsync 详解:从入门到实战,掌握 Linux 数据同步与备份的核心工具
  • ffmpeg解复用aac
  • 数据结构--3:LinkedList与链表
  • linx 系统 ffmpeg 推流 rtsp
  • 防水淹厂房监测报警系统的设计原则及主要构成
  • RFID技术赋能工业教学设备教学应用经典!
  • Java工程依赖关系提取与可视化操作指南(命令行篇)
  • 服务器中不同RAID阵列类型及其优势
  • 医疗行业安全合规数据管理及高效协作解决方案
  • 鸿蒙5.0应用开发——V2装饰器@Event的使用
  • logstash同步mysql流水表到es
  • Ground Control-卫星通信 (SATCOM) 和基于蜂窝的无人机和机器人物联网解决方案
  • 计算机视觉技术深度解析:从图像处理到深度学习的完整实战指南
  • 互联网大厂Java面试:从Spring Boot到微服务的实战考验
  • k8s NodePort 30000 - 32767 被用完后该如何处理
  • 高级系统架构师笔记——软件工程基础知识(2)RAD/敏捷模型/CMM/CBSE
  • 【C++】C++类和对象—(中)
  • React 记忆缓存使用
  • 图观 流渲染场景服务编辑器
  • WALL-OSS——点燃QwenVL 2.5在具身空间中的潜力:VL FFN可预测子任务及离散动作token,Action FNN则预测连续动作
  • 设备中断绑定于香港服务器高性能容器的优化方法