当前位置: 首页 > news >正文

7.神经网络基础

7.1 构造网络模型

#自定义自己的网络模块
import torch
from torch import nn
from torch.nn import functional as F
class MLP(nn.Module):def __init__(self):super().__init__()self.hidden=nn.Linear(20,256)self.out=nn.Linear(256,10)#定义模型的前向传播,即如何根据输入X返回所需的模型输出def forward(self,x):x=self.hidden(x)x=F.relu(x)x=self.out(x)return x
X=torch.normal(mean=0.5,std=2.0,size=(2,20))
net=MLP()
net(X)

7.2 参数管理与初始化

#访问参数
import torch
from torch import nn
net=nn.Sequential(nn.Linear(4,8),nn.ReLU(),nn.Linear(8,1))
X=torch.rand(size=(2,4))
print(net[2].state_dict())
print(type(net[2].bias))
print(net[2].bias)
print(net[2].bias.data)
print(net[2].bias.grad)
print(*[(name,param.shape) for name,param in net[0].named_parameters()])
print(*[(name,param.shape) for name,param in net.named_parameters()])
print(net.state_dict()['2.bias'].data)
import torch
from torch import nn
net=nn.Sequential(nn.Linear(4,8),nn.ReLU(),nn.Linear(8,1))
def init_weight(m):if type(m)==nn.Linear:nn.init.normal_(m.weight,mean=0,std=0.01)nn.init.zeros_(m.bias)
def init_constant(m):if type(m)==nn.Linear:nn.init.constant_(m.weight,1)nn.init.zeros_(m.bias)
net.apply(init_weight)
print(net[0].weight.data[0],net[0].bias.data[0])
net.apply(init_constant)
print(net[0].weight.data[0],net[0].bias.data[0])
#参数绑定
shared=nn.Linear(8,8)
net=nn.Sequential(nn.Linear(4,8),nn.ReLU(),shared,nn.ReLU(),shared,nn.ReLU(),nn.Linear(8,1))
net(X)
print(net[2].weight.data[0]==net[4].weight.data[0])
net[2].weight.data[0,0]=100
print(net[2].weight.data[0]==net[4].weight.data[0])

7.3 自定义层

#如何自定义一个层
import torch
from torch import nn
class centeredLayer(nn.Module):def __init__(self):super().__init__()def forward(self,X):return X-X.mean()
layer=centeredLayer()
layer(torch.FloatTensor([1,2,3,4,5]))
net=nn.Sequential(nn.Linear(8,128),centeredLayer())
Y=net(torch.rand(4,8))
Y.mean()

7.4 文件读写与模型保存

#文件读写
import torch
from torch import nn
from torch.nn import functional as F
x=torch.arange(4)
y=torch.zeros(4)
torch.save(x,'x-file')
x2=torch.load('x-file')
print(x2)
torch.save([x,y],'x-files')
x2,y2=torch.load('x-files')
print(x2,y2)
#模型保存
import torch
from torch import nn
from torch.nn import functional as F
class MLP(nn.Module):def __init__(self):super().__init__()self.hidden=nn.Linear(20,256)self.out=nn.Linear(256,10)#定义模型的前向传播,即如何根据输入X返回所需的模型输出def forward(self,x):x=self.hidden(x)x=F.relu(x)x=self.out(x)return x
X=torch.normal(mean=0.5,std=2.0,size=(2,20))
net=MLP()
Y=net(X)
torch.save(net.state_dict(),'mlp.params')
#读取权重
clone=MLP()
clone.load_state_dict(torch.load('mlp.params'))
clone.eval()
Y_clone=clone(X)
Y_clone==Y
http://www.dtcms.com/a/270510.html

相关文章:

  • 【JavaEE进阶】图书管理系统(未完待续)
  • 【学习笔记】OkHttp源码架构解析:从设计模式到核心实现
  • 保姆级安装 Ruby 环境下载及安装教程, RubyInstaller下载及安装教程
  • Javaweb - 10.7 乱码和路径问题
  • 影石(insta360)X4运动相机视频删除的恢复方法
  • SHA-256算法详解——Github工程结合示例和动画演示
  • 中望CAD2026亮点速递(5):【相似查找】高效自动化识别定位
  • Python(30)基于itertools生成器的量子计算模拟技术深度解析
  • 【SQL】使用UPDATE修改表字段的时候,遇到1054 或者1064的问题怎么办?
  • (八)PS识别:使用 Python 自动化生成图像PS数据集
  • Linux驱动05 --- TCP 服务器
  • 分库分表之实战-sharding-JDBC绑定表配置实战
  • uniapp+vue3+ts项目:实现小程序文件下载、预览、进度监听(含项目、案例、插件)
  • PostgreSQL如何进行跨服务器迁移数据
  • ARIA UWB安全雷达主要产品型号与核心功能全解析
  • 【数字后端】- Standard Cell Status
  • 亚马逊广告进阶指南:CPC与竞价的底层逻辑
  • 游戏开发学习记录
  • 基于Flask 3.1和Python 3.13的简易CMS
  • LLM中 最后一个词语的表征(隐藏状态)通常会融合前面所有词语的信息吗?
  • Java项目集成Log4j2全攻略
  • 速卖通跨境运营破局:亚矩阵云手机如何用“本地化黑科技”撬动俄罗斯市场25%客单价增长
  • 今日行情明日机会——20250709
  • 伪装计算器软件,隐藏手机隐私文件
  • 3.常⽤控件
  • jmeter做跨线程组
  • 第二章:创建登录页面
  • 函数-3-日期函数
  • Java垃圾收集机制Test1
  • css 设置 input 插入光标样式