当前位置: 首页 > news >正文

Day26_【深度学习(6)—神经网络NN(1.2)前向传播的搭建案例】

自定义神经网络——前向传播

import torch
import torch.nn as nn
from torchsummary import summary
'''自定义神经网络步骤:1.继承2.实现两个方法__init__、forward()'''class Net(nn.Module):def __init__(self):super().__init__()  # 继承父类init# super(Net, self).__init__() 与上一行代码同作用# 定义隐藏层# 隐藏层1  初始化参数 一般只初始化w,不用初始化b(bias会默认初始)self.linear1 = nn.Linear(in_features=3, out_features=3)  # 输入特征数,输出特征数nn.init.xavier_uniform_(self.linear1.weight)nn.init.ones_(self.linear1.bias) # 练习写# 隐藏层2self.linear2 = nn.Linear(in_features=3, out_features=2)  # 输入特征数,输出特征数nn.init.kaiming_normal_(self.linear2.weight)nn.init.ones_(self.linear2.bias) # 练习写# 定义输出层self.out = nn.Linear(in_features=2, out_features=2)# 创建前向传播方法, 调用神经网络模型对象时自动执行forward()方法def forward(self, x):# 数据经过第一个线性层,使用sigmoid激活函数x = torch.sigmoid(self.linear1(x))# 数据经过第二个线性层,使用relu激活函数x = torch.relu(self.linear2(x))# 数据经过输出层,使用softmax激活函数# dim=-1:每一维度行数据相加为1x = torch.softmax(self.out(x), dim=-1)return xdef test_my_model():model = Net()print(model)#准备数据input=torch.randn(2,3) #两行三列,三个特征print(input.shape)# 给摸行喂数据output=model(input)print(output.shape)# 计算模型参数summary(model, input_size=(3,))# 查看模型参数for name, parameter in model.named_parameters():# print('name--->', name)# print('parameter--->', parameter)print(name, parameter)if __name__ == '__main__':test_my_model()

http://www.dtcms.com/a/390080.html

相关文章:

  • 河南省 ERA5 气象数据处理教程(2020–2025 每月均值)
  • IIS短文件漏洞修复全攻略
  • jdk-7u25-linux-x64.tar.gz 安装教程(Linux下JDK 7 64位解压配置详细步骤附安装包)
  • 边界值分析法的测试用例数量:一般边界值分析(4n+1)和健壮性测试(6n+1)计算依据
  • 基于飞算AI的图书管理系统设计与实现
  • Day26_【深度学习(6)—神经网络NN(1)重点概念浓缩、前向传播】
  • 软考 系统架构设计师系列知识点之杂项集萃(151)
  • Python基础 2》运算符
  • docker 部署 sftp
  • 数字ic笔试
  • 武汉火影数字|数字展厅设计制作:多媒体数字内容打造
  • LLM模型的参数量估计
  • STM32H743-学习HAL库
  • 一键防范假票入账-发票识别接口-发票查验接口-信息提取
  • RTEMS 控制台驱动
  • flutter在列表页面中通过监听列表滑动偏移量控制页面中某个控件的透明度
  • linux上升级nginx版本
  • WINCC结构变量/公共弹窗
  • 信息化项目验收计划方案书
  • 1.数据库概述和三种主要控制语言
  • 找到nohup启动的程序并杀死
  • 电磁干扰EMI (Electromagnetic Interference)是什么?
  • python提取域名
  • PR工具timing report中setup time的计算过程
  • 低延迟垃圾收集器:挑战“不可能三角”
  • 【测试】发版测试准入准出标准
  • 第一部分:HTML
  • 贪心算法应用:带权任务间隔调度问题详解
  • 视频监控大数据建模分析
  • IP的重要性