当前位置: 首页 > news >正文

Class9简洁实现

Class9简洁实现

%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l
# 初始化训练样本、测试样本、样本特征维度和批量大小
n_train,n_test,num_inputs,batch_size = 20,100,200,5
# 设置真实权重和偏置
true_w,true_b = torch.ones((num_inputs,1)) * 0.01,0.05
# 生成训练数据
# d2l.synthetic_data():函数生成模拟的训练数据
# synthetic_data()L返回三元组(features,labels)
train_data = d2l.synthetic_data(true_w,true_b,n_train)
# 数据封装为训练数据迭代器
# d2l.load_array():把数据打包成一个笑屁刘昂迭代器,便于后续训练
# batch_size=5:每次迭代返回5个样本
train_iter = d2l.load_array(train_data,batch_size)
# 生成测试数据
test_data = d2l.synthetic_data(true_w,true_b,n_test)
# 数据封装为测试数据迭代器
test_iter = d2l.load_array(test_data,batch_size,is_train=False)
# 实现带权重衰减(L2正则)线性回归模型训练
# wd:L2正则化系数lambd
def train_concise(wd):# 构建一个全连接层,输入为num_inputs,输出为1net = nn.Sequential(nn.Linear(num_inputs, 1))for param in net.parameters():# 将参数用正态分布随机初始化param.data.normal_()# 样本的均方误差不求平均loss = nn.MSELoss(reduction='none')# 定义训练轮数和学习率num_epochs, lr = 100, 0.003# 使用随机梯度下降优化器trainer = torch.optim.SGD([# 权重参数,应用L2正则{"params":net[0].weight,'weight_decay': wd},# 偏置参数,不加正则{"params":net[0].bias}], lr=lr)# 定义可视化工具animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])# 循环训练for epoch in range(num_epochs):for X, y in train_iter:# 清空梯度,防止梯度累加trainer.zero_grad()# 计算每个样本的MSELossl = loss(net(X), y)# 进行反向传播l.mean().backward()# 更新模型参数trainer.step()# 每5轮评估训练集和测试集的loss损失函数if (epoch + 1) % 5 == 0:# 将当前loss加入到动态图中animator.add(epoch + 1,(d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))# 打印输出L2范数print('w的L2范数:', net[0].weight.norm().item())
train_concise(0)
train_concise(3)
http://www.dtcms.com/a/279852.html

相关文章:

  • HashMap的put过程以及hashMap的简单介绍
  • kt 中商品的金额字段使用double 还是 bigdecimal
  • 动态规划题解——最长递增子序列【LeetCode】记忆化搜索方法
  • 【每日刷题】杨辉三角
  • Git根据标签Tag强制回滚版本
  • 面试常问:如何在一个长度为n的无序数据快速获取前k个数值
  • 网络传输过程
  • GaussDB between的用法
  • 光伏板如何最大化铺设?
  • 【PostgreSQL异常解决】`PostgreSQL`异常之类型转换错误
  • 记录自己在将python文件变成可访问库文件是碰到的问题
  • vert.x 官网docs, vert.x中文文档地址 vertx文档
  • 文心4.5开源之路:引领技术开放新时代!
  • 【前端:Typst】--let关键字的用法
  • 高德开放平台携手阿里云,面向开发者推出地图服务产品MCP Server
  • 外部协作不力影响项目进度,如何加强外部沟通
  • 项目进度压缩影响质量,如何平衡进度与质量
  • LeetCode|Day11|557. 反转字符串中的单词 III|Python刷题笔记
  • 稀土化合物:助力高效种植与健康养殖
  • vue笔记3 VueRouter VueX详细讲解
  • 对象的使用
  • CAN终端电阻为什么是60R+60R,而不直接用120R?
  • 前端vue对接海康摄像头流程
  • Flink窗口处理函数
  • C++-linux 5.gdb调试工具
  • 【从语言幻觉看趋势】从语言幻觉到多智能体协作:GPT多角色系统的技术演进与实践路径
  • 判断端口处于监听状态的方法
  • 腾讯云WAF域名分级防护实战笔记
  • EPLAN 电气制图(八):宏应用与变频器控制回路绘制全攻略
  • ssm学习笔记day07mybatis