当前位置: 首页 > news >正文

Day26_【深度学习(6)—神经网络NN(2)前向传播的搭建案例】

自定义神经网络——前向传播

import torch
import torch.nn as nn
from torchsummary import summary
'''自定义神经网络步骤:1.继承2.实现两个方法__init__、forward()'''class Net(nn.Module):def __init__(self):super().__init__()  # 继承父类init# super(Net, self).__init__() 与上一行代码同作用# 定义隐藏层# 隐藏层1  初始化参数 一般只初始化w,不用初始化b(bias会默认初始)self.linear1 = nn.Linear(in_features=3, out_features=3)  # 输入特征数,输出特征数nn.init.xavier_uniform_(self.linear1.weight)nn.init.ones_(self.linear1.bias) # 练习写# 隐藏层2self.linear2 = nn.Linear(in_features=3, out_features=2)  # 输入特征数,输出特征数nn.init.kaiming_normal_(self.linear2.weight)nn.init.ones_(self.linear2.bias) # 练习写# 定义输出层self.out = nn.Linear(in_features=2, out_features=2)# 创建前向传播方法, 调用神经网络模型对象时自动执行forward()方法def forward(self, x):# 数据经过第一个线性层,使用sigmoid激活函数x = torch.sigmoid(self.linear1(x))# 数据经过第二个线性层,使用relu激活函数x = torch.relu(self.linear2(x))# 数据经过输出层,使用softmax激活函数# dim=-1:每一维度行数据相加为1x = torch.softmax(self.out(x), dim=-1)return xdef test_my_model():model = Net()print(model)#准备数据input=torch.randn(2,3) #两行三列,三个特征print(input.shape)# 给摸行喂数据output=model(input)print(output.shape)# 计算模型参数summary(model, input_size=(3,))# 查看模型参数for name, parameter in model.named_parameters():# print('name--->', name)# print('parameter--->', parameter)print(name, parameter)if __name__ == '__main__':test_my_model()

http://www.dtcms.com/a/388840.html

相关文章:

  • 古老的游戏之竞技体育
  • CURSOR平替(deepseek+VScode)方案实现自动化编程
  • java对电子发票是否原件的快速检查
  • 贪心算法应用:顶点覆盖问题详解
  • Odoo中非库存商品的高级自动化采购工作流程
  • 缺少自动化测试会对 DevOps 带来哪些风险
  • 深入解析 Python 中的 __pycache__与字节码编译机制
  • SEO 优化:元数据 (Metadata) API 和站点地图 (Sitemap) 生成
  • postman+Jenkins进行API automation集成
  • 【算法磨剑:用 C++ 思考的艺术・单源最短路收官】BF/SPFA 负环判断模板 + 四大算法全总结
  • Flink的介绍及应用
  • 微信小程序插屏广告(InterstitialAd)全解析与实战应用案例
  • 格雷希尔G70R系列快速密封连接器+GT系列软管组件的配套组合方案,在新能源汽车老化测试的应用
  • 【Debug日志| 随机下降】
  • 滑动窗口法的优化与实战——力扣209.长度最小的子数组
  • 【Spring Boot 报错已解决】org.yaml.snakeyaml.scanner.ScannerException 报错原因与解决方案
  • 国家统计局数据读取——数据读取——清洗数据06
  • 基于 scratch 构建简单镜像
  • Web安全的暗角:10大易忽略逻辑漏洞解析!
  • 矩阵奇异值分解算法(SVD)详解
  • 【FreeRTOS】 二值信号量与互斥量(CMSIS-RTOS v2 版本)
  • Qt C++ :Qt全局定义<QtGlobal>
  • 【STL源码剖析】从源码看 list:从迭代器到算法
  • MySQL 专题(三):事务与锁机制深度解析
  • 使用BLIP训练自己的数据集(图文描述)
  • Geoserver修行记--在geoserver中如何复制某个图层组内容
  • DBG数据库透明加密网关:SQLServer应用免改造的安全防护方案,不限制开发语言的加密网关
  • 不同上位开发语言、PLC下位平台、工业协议与操作系统平台下的数据类型通用性与差异性详解
  • 【入门篇|第二篇】从零实现选择、冒泡、插入排序(含对数器)
  • javaweb Servlet基本介绍及开发流程