当前位置: 首页 > news >正文

Class57 代码实现

import torch
from torch import nn
from d2l import torch as d2l
# 设置训练样本数量
n_train = 50
# 在[0,5]之间随机生成50个浮点数
x_train,_ = torch.sort(torch.rand(n_train) * 5)# 定义目标函数
def f(x):# 定义非线性函数return 2 * torch.sin(x) + x ** 0.8# 生成长度为50的高斯噪声,均值为0,标准差为0.5
y_train = f(x_train) + torch.normal(0.0,0.5,(n_train,))
# 在[0,5]区间上每隔0.1取一个点
x_test = torch.arange(0,5,0.1)
# 测试数据在真实函数上的输出
y_truth = f(x_test)
# 计算测试集样本数量
n_test = len(x_test)
# 打印结果
n_test
# 定义绘图函数
def plot_kernel_reg(y_hat):# 定义横坐标,纵坐标,坐标标签,图例,限制坐标范围d2l.plot(x_test,[y_truth,y_hat],'x','y',legend=['Truth','pred'],xlim=[0,5],ylim=[-1,5])# 将训练样本用原点表示画出来d2l.plt.plot(x_train,y_train,'o',alpha=0.5);# 把均值复制n_test次,得到长度等于测试集的张量
y_hat = torch.repeat_interleave(y_train.mean(),n_test)
# 绘制结果
plot_kernel_reg(y_hat)
# 非参数注意力池化
# 把每个点重复n_train次,得到形状(n_test,n_train)
X_repeat = x_test.repeat_interleave(n_train).reshape((-1,n_train))
# 计算注意力权重
attention_weights = nn.functional.softmax(-(X_repeat - x_train) ** 2 / 2,dim=1)
# 预测y值
y_hat = torch.matmul(attention_weights,y_train)
# 进行可视化
plot_kernel_reg(y_hat)
# 在指定维度上插入一个大小为1的维度
d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),# 横坐标标记训练输入xlabel='Sorted training inputs',# 纵坐标标记测试输入ylabel='Sorted testing inputs')
# 2个batch,每个batch存1*4的矩阵
X = torch.ones((2,1,4))
# 2个batch,每个batch存4*6的矩阵
Y = torch.ones((2,4,6))
# 计算批量矩阵乘法
torch.bmm(X,Y).shape
# 定义2个batch,每个batch有10个权重
weights = torch.ones((2,10)) * 0.1
# 定义2个batch,每个batch有10个数
values = torch.arange(20.0).reshape((2,10))
# 批量矩阵相乘
torch.bmm(weights.unsqueeze(1),values.unsqueeze(-1))
class NWKernelRegression(nn.Module):# 初始化函数def __init__(self,**kwargs):super().__init__(**kwargs)# 定义可学习参数w,形状为(1,),初始值是[0,1)之间的随机数self.w = nn.Parameter(torch.rand((1,),requires_grad=True))# 前向传播方法   # queries:要预测的位置# keys:训练数据位置# values:训练数据对应的值def forward(self,queries,keys,values):# 对每个query和keys做一对一差值计算queries = queries.repeat_interleave(keys.shape[1]).reshape((-1,keys.shape[1]))# 计算注意力权重self.attention_weights = nn.functional.softmax(-((queries - keys) * self.w) ** 2/2 ,dim=1)# 对值做加权和return torch.bmm(self.attention_weights.unsqueeze(1),values.unsqueeze(-1)).reshape(-1)
# 重复训练数据
X_title = x_train.repeat((n_train,1))
Y_title = y_train.repeat((n_train,1))
# 构造mask排除自身
keys = X_title[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train,-1))
values = Y_title[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train,-1))
# 初始化模型
net = NWKernelRegression()
# 初始化损失
loss = nn.MSELoss(reduction='none')
# 初始化优化器
trainer = torch.optim.SGD(net.parameters(),lr=0.5)
# 绘制损失曲线
animator = d2l.Animator(xlabel='epoch',ylabel='loss',xlim=[1,5])# 循环5次
for epoch in range(5):# 梯度清零trainer.zero_grad()# 前向传播计算预测值l = loss(net(x_train,keys,values),y_train)# 反向传播l.sum().backward()# 参数更新trainer.step()# 打印输出print(f'epoch{epoch + 1},loss{float(l.sum()):.6f}')# 绘制训练曲线animator.add(epoch + 1,float(l.sum()))
# 构造keys
keys = x_train.repeat((n_test,1))
# 构造values
values = y_train.repeat((n_test,1))
# 进行模型预测
y_hat = net(x_test,keys,values).unsqueeze(1).detach()
# 绘制预测曲线
plot_kernel_reg(y_hat)
# 绘制热力图,增加batch和channel两个维度
d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),# x轴对应训练样本xlabel='Sorted training inputs',# y轴对应测试样本ylabel='Sorted testing inputs')

文章转载自:

http://6aRim0Wq.LhxkL.cn
http://EDrpBMQ9.LhxkL.cn
http://Ntn8mD9R.LhxkL.cn
http://wLwLbDfu.LhxkL.cn
http://J8qmdP5A.LhxkL.cn
http://3B7645cK.LhxkL.cn
http://TB9khBrh.LhxkL.cn
http://03KlolZN.LhxkL.cn
http://w1z0J0L2.LhxkL.cn
http://ByeitCbr.LhxkL.cn
http://N6NOfC1A.LhxkL.cn
http://cfilF9au.LhxkL.cn
http://TB51OzW8.LhxkL.cn
http://XAiHjfkN.LhxkL.cn
http://CF33Ugzn.LhxkL.cn
http://U19ulHly.LhxkL.cn
http://QUZpV3A0.LhxkL.cn
http://eoD6ed54.LhxkL.cn
http://ZMTykTzb.LhxkL.cn
http://nyQV5N7w.LhxkL.cn
http://SzloqP6K.LhxkL.cn
http://HD8VxSNM.LhxkL.cn
http://ddXTATtY.LhxkL.cn
http://sUIWaTSA.LhxkL.cn
http://jn1b1c1y.LhxkL.cn
http://oQXwboCK.LhxkL.cn
http://CRudubRP.LhxkL.cn
http://TxejdRx5.LhxkL.cn
http://CcodJl1N.LhxkL.cn
http://H9RFF8hu.LhxkL.cn
http://www.dtcms.com/a/386212.html

相关文章:

  • torch.gather
  • 自学嵌入式第四十二天:单片机-定时器和UART串口
  • 大数据毕业设计选题推荐-基于大数据的旅游网站用户行为数据分析系统-Hadoop-Spark-数据可视化-BigData
  • 深入浅出数据结构:队列(Queue)—— 生活中的排队艺术
  • spring通过Spring Integration实现udp通信
  • Linux内存管理章节十八:内核开发者的武器库:内存分配API实战指南
  • CAD如何输出PDF多页文件
  • 我对 WPF 动摇时的选择:.NET Framework 4.6.2+WPF+Islands+UWP+CompostionApi
  • 1.整流-滤波电路的缺点和PFC的引入
  • QT 项目 线程信号切换 举例
  • 构网型5MW中压储能变流升压一体机技术方案
  • 【数据工程】8. SQL 入门教程
  • C++---前向声明
  • 在Qt项目中使用QtConcurrent::run,实现异步等待和同步调用
  • 经验分享只靠口头传递会带来哪些问题
  • Linux底层-内核数据接口:/proc
  • PEFT+DeepSpeed 1 (微调 分布式 显存优化)
  • Spring Boot 下 Druid 连接池:多维度优化打造卓越性能
  • 提升学术研究能力:从开题构思难题到AI辅助提纲生成
  • spring-kafka的消息拦截器RecordInterceptor
  • VSCode + Python 开发踩坑:虚拟环境不在项目根目录导致包无法识别该怎么办
  • 【MCP】【FastMCP】[特殊字符] 使用 UV 创建 FastMCP 服务完整示例
  • 蓝绿部署(Blue-Green Deployment)介绍(一种用于降低软件发布风险的部署策略)流量切换(金丝雀发布)
  • 羽毛球地板:从专业运动场景到全民健身市场的技术跃迁与产业重构
  • 【实战】预警算法--噪声添加机制
  • Three.js 中如何给 3D 模型添加文字标签?
  • 贪心算法应用:NFV功能部署问题详解
  • 第八章:Jmeter 非GUl命令详解
  • 知识点17:多Agent系统架构设计模式
  • 作为学术工作者,利用沁言学术提升效率:集成化与一站式体验