当前位置: 首页 > news >正文

【PyTorch】深度学习实践——第二章:线性模型

参考:刘二老师的《PyTorch深度学习实践》完结合集

本章实现了一个简单的线性回归模型,用于学习输入x和输出y之间的线性关系(y=w*x)。

一、代码细节

1.数据准备

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
  • 定义了训练数据,x和y之间显然是y=2x的关系,只是我们自己知道计算机不知道。

2.模型定义

def forward(x):return x * w
  • 非常简单的线性模型,只有一个权重参数w

3.损失函数

def loss(x, y):y_pred = forward(x)return (y_pred - y) * (y_pred - y)
  • 使用均方误差(MSE)作为损失函数

4.训练循环

for w in np.arange(0.0, 4.1, 0.1):print('w=',w)l_sum = 0for x_val, y_val in zip(x_data, y_data):y_pred_val = forward(x_val)loss_val = loss(x_val, y_val)l_sum += loss_valprint("MSE", l_sum / 3)w_list.append(w)mse_list.append(l_sum / 3)
  • 遍历w的可能值(0.0到4.0,步长0.1)
  • 对每个w值,计算在所有训练数据上的总损失
  • 计算并存储平均MSE

5.可视化

plt.plot(w_list, mse_list)
plt.xlabel('w')
plt.ylabel('Loss')
plt.show()

6.找最优解

min_mse = min(mse_list)
optimal_w = w_list[mse_list.index(min_mse)]
print(f"\nOptimal weight: {optimal_w:.1f} (MSE = {min_mse:.2f})")

二、完整代码

import numpy as np
import matplotlib.pyplot as plt# 训练数据
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]# 前向传播函数
def forward(x):return x * w# 损失函数
def loss(x, y):y_pred = forward(x)return (y_pred - y) * (y_pred - y)# 存储权重和对应的MSE值
w_list = []
mse_list = []# 遍历不同的权重值
for w in np.arange(0.0, 4.1, 0.1):print("w =", w)l_sum = 0  # 累计损失# 计算当前权重下的预测值和损失for x_val, y_val in zip(x_data, y_data):y_pred_val = forward(x_val)loss_val = loss(x_val, y_val)l_sum += loss_valprint("\t", x_val, y_val, y_pred_val, loss_val)# 计算并存储平均MSEprint("MSE:", l_sum / 3)w_list.append(w)mse_list.append(l_sum / 3)# 可视化结果
plt.plot(w_list, mse_list)
plt.title('Loss for different weights')
plt.xlabel('w')
plt.ylabel('Loss')
plt.show()# 找到最优权重
min_mse = min(mse_list)
optimal_w = w_list[mse_list.index(min_mse)]
print(f"\nOptimal weight: {optimal_w:.1f} (MSE = {min_mse:.2f})")

相关文章:

  • LVGL输入设备管理
  • Dinky 安装部署并配置提交 Flink Yarn 任务
  • 11. CSS从基础样式到盒模型与形状绘制
  • C++学习之路,从0到精通的征途:继承
  • 基于脑功能连接组和结构连接组的可解释特定模态及交互图卷积网络|文献速递-深度学习医疗AI最新文献
  • 在虚拟机Ubuntu18.04中安装NS2教程及应用
  • 大白话解释联邦学习
  • hadoop3.x单机部署
  • Mysql索引优化
  • Spring Boot之Web服务器的启动流程分析
  • 【android bluetooth 框架分析 02】【Module详解 7】【VendorSpecificEventManager 模块介绍】
  • 使用光标测量,使用 TDR 测量 pH 和 fF
  • AI 模型训练轻量化技术在军事领域的实战应用与技术解析
  • ​​华为云服务器:智能算力网格​
  • Vue 3.5 新特性深度解析:全面升级的开发体验
  • MQTT协议详解:物联网通信的轻量级解决方案
  • idea Maven 打包SpringBoot可执行的jar包
  • 【YOLO模型】参数全面解读
  • 微信小程序 密码框改为text后不可见,需要点击一下
  • uni-app学习笔记五-vue3响应式基础
  • 美国4月CPI同比上涨2.3%低于预期,为2021年2月来最小涨幅
  • 缺字危机:一本书背后有多少“不存在”的汉字?
  • 飙升至熔断,巴基斯坦股市两大股指收盘涨逾9%
  • 最高降价三成,苹果中国iPhone开启大促销,能拉动多少销量?
  • IPO周报|本周A股暂无新股网上申购,年内最低价股周二上市
  • 第一集丨《亲爱的仇敌》和《姜颂》,都有耐人寻味的“她”