当前位置: 首页 > news >正文

2.搭建Pytorch神经网络进行气温预测

回归任务:输入数据—网络—预测值

1. 数据处理

在这里插入图片描述
这个任务就是输入一个天气的特征,预测实际的温度。
在这里插入图片描述
将星期几这种字符串转换成01:
在这里插入图片描述
在这里插入图片描述
348个样本,每个样本14个特征,但是当前特征有的数据大,有的很小。不能越大的数就越重要,需要进行标准化。

  1. 去均值–尽可能让数据按原点对称
    在这里插入图片描述
  2. 比标准差–取值范围大的除大的,小的除小的
    在这里插入图片描述
    这样就比较均衡。
    fit就是在计算并做转换
    在这里插入图片描述
    如果我们的第一个year全是2016,那这个特征就没用了,所以是0。

2. 构建网络模型

把输入输出数据都得转化成tensor
在这里插入图片描述
randn 做一个随机初始化,都要算梯度
在这里插入图片描述
指定学习率(相对小一点),损失
在这里插入图片描述
总共迭代1000次,每次都更新梯度
在这里插入图片描述
x.mm(乘)w1+b1 得到隐层结果,然后引入一个非线性映射relu;然后算输出 h.mm(w2) + b2 ,然后评估一下预测的好不好,也就是计算损失。

因为式子中的数据都是tensor,所以计算出的 loss 的格式是tensor,但是matplotlib画图用的数据是ndarray格式,所以需要转换一下。
在这里插入图片描述
迭代100次打印一下损失。反向传播虽然原理复杂,但是代码就这么一行就搞定了。

反向完取梯度,沿这个方向去做更新,当然是根据学习率去做更新,负号就是因为这个梯度是向上的,我们想要沿着梯度的反方向去更新,要去找能使损失最小的权重参数。
在这里插入图片描述
这个就是上一节中也讲过,每次迭代都是做同样的事,和之前的迭代一点关系没有,所以需要清空,torch中不清零会累加。

以上这些就是流程,实际中有封装好的接口直接调用。
在这里插入图片描述
在这里插入图片描述
348个预测值,每个样本一个预测结果。

3. 更简单的构建网络模型

调包需要一些参数:
在这里插入图片描述
input_size:输入特征个数
batch_size:一次迭代这么多个样本,一批一批去迭代

nn指定一个 Sequential 序列模块,就是下面的按顺序执行。
Linear:wx+b,不用我们自己写那些过程了
Sigmoid:也是一种激活函数,用哪个都行
Linear:做输出

刚才手打公式,现在就是用人家的损失函数:MSELoss(均值)计算出cost
优化器optimizer:用学习率更新参数

简单讲讲Adam:它的引用量非常之高,已经成为深度学习社区的“常青树”级方法。
在这里插入图片描述
蓝色曲线 就是典型的“狭长山谷”型损失面,红色点为最优损失点,目标就是最快的到达那个点。

这样的场景下,不同维度的梯度尺度差异 会导致普通 SGD 在陡峭方向上来回震荡,而在平缓方向上前进缓慢。蓝色折线就是每一步只按当前梯度(竖直向下)更新后的位移。它每次都在「竖直方向」来回猛跳,没什么推进效果;在「横轴」方向上,更新幅度又很小,进展极其缓慢。

红色虚线箭头就是 累加历史梯度的动量(惯性),在陡峭方向上,上一次的梯度会部分抵消本次的梯度抑制来回震荡。在平缓方向上,累积的小梯度则能「滚雪球」式地加速前进。

绿色轨迹比蓝色「上下颠簸」平滑了许多,整体朝最深处(红点)更快滑过去。

总的来说,Adam和SGD的基本思想一样,但是效率更高。想都不用想直接调Adam就完事了。
在这里插入图片描述
还是迭代1000次,每次取的xx,yy是一批一批的取,也就是一个batch一个batch的取。

xx按照我们的网络一点一点去走,计算loss,opt.zero_grad() 上轮梯度清零,然后反向传播,step对所有参数进行更新,存一下损失就完事了。

只有调用 loss.backward() 时才会往各个参数的 .grad 里写入(其实是“累加”)梯度。只要你能保证每次调用 backward() 前,那些 .grad 都是 0,就不会把「新梯度」和「上一次残留的旧梯度」搞混了。

4. 预测训练结果

上面是训练网络的过程,现在就是用这个训练好的去进行预测,看看结果
在这里插入图片描述
predict就是最终预测值
在这里插入图片描述
在这里插入图片描述
我们这个中间只有一个隐层,这个可以加着玩,越加越过拟合。

相关文章:

  • 数据湖 vs 数据仓库:数据界的“自来水厂”与“瓶装水厂”?
  • 表达式求值
  • Launcher3中的CellLayout 和ShortcutAndWidgetContainer 的联系和各自职责
  • 华为云镜像仓库下载 selenium/standalone-chrome 镜像
  • SQL关键字三分钟入门:ROW_NUMBER() —— 窗口函数为每一行编号
  • 深度学习-分类
  • Sensodrive SensoJoint机器人力控关节模组抗振动+Sensodrive力反馈系统精准对接
  • web3 docs
  • 力扣第73题-矩阵置零
  • Java面向对象(一)
  • 对话式数据分析与Text2SQL Agent产品可行性分析思考
  • Python 数据分析:numpy,抽提,整数数组索引
  • 从单体架构到微服务:微服务架构演进与实践
  • 如何解决电脑windows蓝屏问题
  • 叉车考试真题(含答案)pdf下载
  • Rust宏和普通函数的区别
  • 心理测评app在线预约系统框架设计
  • 【HarmonyOS Next之旅】DevEco Studio使用指南(三十八) -> 构建HAR
  • ByteMD+CozeAPI+Coze平台Agent+Next搭建AI辅助博客撰写平台(逻辑清楚,推荐!)
  • 如何修改discuz文章标题字数限制 修改成255