当前位置: 首页 > news >正文

自定义简单线性回归模型

自定义实现

import randomimport torchdef synthetic_data(w, b, num_examples):"""生成y=Xw+b+噪声 数据X: 代表原始的样本y: 代表特征值或者结果值"""# 定义X的输入区间 0 ~ 1, 列数是w的维数X = torch.normal(0, 1, (num_examples, len(w)))# 定义 yy = torch.matmul(X, w) + b# 增加额外噪声y += torch.normal(0, 0.01, y.shape)# 重新改变y的形状 2 * 1000, 1 * 1000return X, y.reshape((-1, 1))def data_iter(batch_size, features, labels):"""定义迭代器, 支持训练过程每个epochs"""num_exaples = len(features)indices = list(range(num_exaples))random.shuffle(indices)for i in range(0, num_exaples, batch_size):batch_indices = torch.tensor(indices[i: min(i + batch_size, num_exaples)])yield features[batch_indices], labels[batch_indices]def linreg(X, w, b):"""定义线性模型"""return torch.matmul(X, w) + bdef squared_loss(y_hat, y):"""定义损失函数, 均方差损失函数"""return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2def sgd(params, lr, batch_size):"""定义优化方法, 小批量随机梯度下降"""with torch.no_grad():for param in params:param -= lr * param.grad / batch_sizeparam.grad.zero_()if __name__ == '__main__':true_w = torch.tensor([2, -3.4])true_b = 4.2batch_size = 10features, labels = synthetic_data(true_w, true_b, 1000)# print(features)# print(labels)# 生成随机的两行一列的, 0~0.01的值w = torch.normal(0, 0.01, (2, 1), requires_grad=True)b = torch.zeros(1, requires_grad=True)# 步长lr = 0.0001num_epochs = 3000# 定义网络, 通过什么样的网络进行数据的拟合net = linreg# loss 是为了计算预测值和真实值之间的关系loss = squared_loss# 定义优化方法和参数更新方式sgd = sgdfor epoch in range(num_epochs):for X, y in data_iter(batch_size, features, labels):# X 和 y 的小批量损失l = loss(net(X, w, b), y)# 因为l形状是(batch_size, 1), 不是一个标量. l中所有元素被加到一起, 计算关于[w,b]的梯度l.sum().backward()# 使用参数的梯度更新参数sgd([w, b], lr, batch_size)# 不计算梯度, 查看和真实值的差异with torch.no_grad():train_l = loss(net(features, w, b), labels)print('epoch %d, loss %f' % (epoch + 1, train_l.mean().item()))print("w的估计误差: ", true_w - w.reshape(true_w.shape))print("b的估计误差: ", true_b - b)

简洁实现

import randomimport torch
from torch import nndef synthetic_data(w, b, num_examples):"""生成y=Xw+b+噪声 数据X: 代表原始的样本y: 代表特征值或者结果值"""# 定义X的输入区间 0 ~ 1, 列数是w的维数X = torch.normal(0, 1, (num_examples, len(w)))# 定义 yy = torch.matmul(X, w) + b# 增加额外噪声y += torch.normal(0, 0.01, y.shape)# 重新改变y的形状 2 * 1000, 1 * 1000return X, y.reshape((-1, 1))def load_array(data_arrays, batch_size, is_train=True):"""构造一个数据迭代器"""dataset = torch.utils.data.TensorDataset(*data_arrays)return torch.utils.data.DataLoader(dataset, batch_size, shuffle=is_train)if __name__ == '__main__':true_w = torch.tensor([2, -3.4])true_b = 4.2batch_size = 10features, labels = synthetic_data(true_w, true_b, 1000)# print(features)# print(labels)# 生成随机的两行一列的, 0~0.01的值w = torch.normal(0, 0.01, (2, 1), requires_grad=True)b = torch.zeros(1, requires_grad=True)data_iter = load_array((features, labels), batch_size)# 步长lr = 0.003# 定义网络, 通过什么样的网络进行数据的拟合net = nn.Sequential(nn.Linear(2, 1))# 初始化模型参数net[0].weight.data.normal_(0, 0.01)net[0].bias.data.fill_(0)# loss 是为了计算预测值和真实值之间的关系loss = nn.MSELoss()# 定义优化方法和参数更新方式sgd = torch.optim.SGD(net.parameters(), lr=0.03)num_epochs = 3for epoch in range(num_epochs):for X, y in data_iter:l = loss(net(X), y)sgd.zero_grad()l.backward()sgd.step()l = loss(net(features), labels)print(f'epoch {epoch + 1}, loss {l:f}')w = net[0].weight.dataprint('w的估计误差:', true_w - w.reshape(true_w.shape))b = net[0].bias.dataprint('b的估计误差:', true_b - b)
http://www.dtcms.com/a/268604.html

相关文章:

  • 【AI大模型】神经网络反向传播:核心原理与完整实现
  • 电脑电压过高的影响与风险分析
  • 轨迹优化 | 基于激光雷达的欧氏距离场ESDF地图构建(附ROS C++仿真)
  • 回溯题解——子集【LeetCode】二进制枚举法
  • ssh: Could not resolve hostname d: Temporary failure in name resolution
  • 从依赖地狱到依赖天堂PNPM
  • 01、通过内网穿透工具把家中闲置电脑变成在线服务器
  • C盘瘦身 -- 虚拟内存文件 pagefile.sys
  • (六)PS识别:源数据分析- 挖掘图像的 “元语言”技术实现
  • python list去重
  • 【Behavior Tree】-- 行为树AI逻辑实现- Unity 游戏引擎实现
  • Docker 将镜像打成压缩包将压缩包传到服务器运行
  • 物联网技术的关键技术与区块链发展趋势的深度融合分析
  • Java SE与Java EE使用方法及组件封装指南
  • 安卓10.0系统修改定制化_____安卓9与安卓10系统文件差异 有关定制选项修改差异
  • Java 并发编程中的同步工具类全面解析
  • qiankun隔离机制
  • [附源码+数据库+毕业论文]基于Spring+MyBatis+MySQL+Maven+jsp实现的高校实验室资源综合管理系统,推荐!
  • 按键开关:新型防水按键开关的特点!
  • 音频流媒体技术选型指南:从PCM到Opus的实战经验
  • 【Java面试】Https和Http的区别?以及分别的原理是什么?
  • 02 除了前面常见图表,还有许多更细分或专业的可视化类型,尤其是在特定领域(如金融、工程、生物信息等)。
  • GaussDB应用场景全景解析:从金融核心到物联网的分布式数据库实践
  • OpenCV 人脸分析----人脸识别的一个经典类cv::face::EigenFaceRecognizer
  • Oracle PL/SQL 编程基础详解(从块结构到游标操作)
  • idea 使用vscode 快捷键
  • UE 材质 变体 概念
  • ClickHouse 入门详解:它到底是什么、优缺点、和主流数据库对比、适合哪些场景?
  • 1.1_5_2 计算机网络的性能指标(下)
  • 【Vben3全解】【组件库开发】解决组件库开发中css的命名难题,保证代码质量,构建useNamespace函数