当前位置: 首页 > news >正文

《Pytorch深度学习实践》ch5-Logistic回归

                                                        ------B站《刘二大人》

1.Classification

  • 经典的分类数据集:MNIST(0 - 9)

  • 导入数据集:(路径,训练集/测试集,是否下载)
import torchvision
train_set = torchvision.datasets.MINIST(root='../dataset/mnist', train=True,  download=True)
test_set  = torchvision.datasets.MINIST(root='../dataset/mnist', train=False, download=True)

2.Sigmoid functions

  • 由于分类问题就是求概率的最大值,所以利用 S 函数将数值全部映射到 [0,1] 区间;
  • 最著名的就是这个 Logistic 函数:

  • 其它的 S 函数:

3.Logistic Regression Model

  • 就是在原函数基础上加一个 Sigmoid:

4.Loss and BCE

  • BCE:Binary Cross Entropy,二元交叉熵损失:

5.Implemetation

  • 导包:
import torch
import torch.nn.functional as F
  • 数据集:y 变为 {0,1},二分类
# 数据集
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])
  • 模型:F.sigmoid()函数
# 模型
class LogisticRegressionModel(torch.nn.Module): # Module 构建计算图def __init__(self):super(LogisticRegressionModel, self).__init__()self.linear = torch.nn.Linear(1, 1) def forward(self, x): # 前馈y_pred = F.sigmoid(self.linear(x))return y_predmodel = LogisticRegressionModel() # 实例化
  •  损失和优化器:BCELoss
# 损失函数和优化器
criterion = torch.nn.BCELoss(reduction = 'sum') # 计算损失,参数为(y_pred, y)optimizer = torch.optim.SGD(model.parameters(), lr = 0.01) # 进行更新
  • 训练:
# 训练
for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data) # 1.前馈print(epoch, loss)optimizer.zero_grad() # 梯度清零loss.backward() # 2.反馈optimizer.step() # 3.更新

6.Result

import numpy as np
import matplotlib.pyplot as pltx = np.linspace(0, 10, 200)
x_t = torch.Tensor(x).view((200,1)) # 将x数组转换为PyTorch张量,并将其形状调整为列向量(200x1)
y_t = model(x_t)
y = y_t.data.numpy() # 将输出张量y_t转换为NumPy数组yplt.plot(x, y)
plt.plot([0, 10], [0.5, 0.5], c='r') # 绘制一条从x=0到x=10的红色水平线,y值为0.5
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.show()
  • 绘图如下:


文章转载自:

http://soQtUWsQ.cwrpd.cn
http://McrUGz0E.cwrpd.cn
http://6RYQ1mut.cwrpd.cn
http://ZLHdtQSW.cwrpd.cn
http://wDrURBO7.cwrpd.cn
http://4CgTjyKX.cwrpd.cn
http://ARph4DM8.cwrpd.cn
http://Z5WJj5Ua.cwrpd.cn
http://Ecp67HiF.cwrpd.cn
http://TQHQUriK.cwrpd.cn
http://2w3SsIhW.cwrpd.cn
http://fPl8Njlu.cwrpd.cn
http://xIrV2RAJ.cwrpd.cn
http://3SezVXFN.cwrpd.cn
http://mtcjxKx1.cwrpd.cn
http://OddIxSa3.cwrpd.cn
http://nUuc0IRC.cwrpd.cn
http://xfGjcAEF.cwrpd.cn
http://8n5u5Yi2.cwrpd.cn
http://Ryy4v2eb.cwrpd.cn
http://r6UptYYw.cwrpd.cn
http://6H7Am4pS.cwrpd.cn
http://kLqrCc5w.cwrpd.cn
http://CIdGK4G9.cwrpd.cn
http://nqC72X3y.cwrpd.cn
http://StvDjKxN.cwrpd.cn
http://l3FWyeMc.cwrpd.cn
http://UP4GJY1l.cwrpd.cn
http://oSrPMG0M.cwrpd.cn
http://X6RYDdnp.cwrpd.cn
http://www.dtcms.com/a/229271.html

相关文章:

  • Ubuntu系统安装与配置NTP时间同步服务
  • 邢台山峰特种橡胶制品有限公司专题报道
  • 实战商品订单秒杀设计实现
  • 蜜獾算法(HBA,Honey Badger Algorithm)
  • LangChain核心之Runnable接口底层实现
  • matlab实现掺杂光纤放大器的模拟
  • Termux下如何使用MATLAB
  • GCC内存占用统计使用指南
  • 《深入解析SPI协议及其FPGA高效实现》-- 第三篇:FPGA实现关键技术与优化
  • TCP的粘包和拆包
  • mac环境下的python、pycharm和pip安装使用
  • Linux Maven Install
  • 网络攻防技术八:身份认证与口令攻击
  • Modbus转Ethernet IP赋能挤出吹塑机智能监控
  • OD 算法题 B卷【跳格子2】
  • 飞算 JavaAI 赋能老项目重构:破旧立新的高效利器
  • Go Gin框架深度解析:高性能Web开发实践
  • FLgo学习
  • 【Android】双指旋转手势
  • Lua和JS的继承原理
  • 后台管理系统八股
  • Python应用continue关键字初解
  • 前端验证下跨域问题(npm验证)
  • 隧道监测预警系统:构筑智慧交通的安全中枢
  • 香橙派3B学习笔记6:基本的Bash脚本学习_UTF-8格式问题
  • 定时线程池失效问题引发的思考
  • 前端导入Excel表格
  • 提升系统稳定性和可靠性的特殊线程(看门狗线程)
  • CppCon 2014 学习:Lightning Talk: Writing a Python Interpreter for Fun and Profit
  • 浮点数的位级表示转变为二进制表示