当前位置: 首页 > news >正文

自定义层和读写文件

自定义层

自定义一个没有任何参数的层

import torch
import torch.nn.functional as F
from torch import nnclass CenteredLayer(nn.Module):def __init__(self):super().__init__()def forward(self, X):return X - X.mean()layer = CenteredLayer()
layer(torch.FloatTensor([1, 2, 3, 4, 5]))

将层作为组件和冰岛构建更复杂的模型中

net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())Y = net(torch.rand(4m 8))
Y.mean()

带参数的层

class MyLinear(nn.Module):def __init__(self, in_units, units):super().__init__()self.weight = nn.Parameter(torch.randn(in_units, units))self.bias = nn.Parameter(torch.randn(units,))def forward(self, X):linear = torch.matmul(X, self.weight.data) + self.bias.datareturn F.relu(linear)dense = MyLinear(5, 3)
dense.weight

使用自定义的层执行传播计算

dense(torch.rand(2, 5))

读写文件

import torch
from torch import nn
from torch.nn import functional as Fx = torch.arange(4)
torch.save(x, 'x-file')
x2 = torch.load('x-file')
x2 == x

存储一个张量列表

y = torch.zeros(4)
torch.save([x, y], 'x-files')
x2, y2 = torch.load('x-files')

写入或读取字典

mydict = {'x': x, 'y': y}
torch.save(mydict, 'mydict')
mydict2 = torch.load('mydict')

加载和保存模型参数

class MLP(nn.Module):def __init__(self):super().__init__()self.hidden = nn.Linear(20, 256)self.output = nn.Linear(256, 10)def forward(self, x):return self.output(F.relu(self.hidden(x)))net = MLP()
X = torch.randn(size=(2, 20))
Y = net(X)

将模型存储为文件

torch.save(net.state_dict(), 'mlp.params')# 保存参数后需要我们自己保存MLP的定义, 需要有定义才能加载
clone = MLP()
clone.load_state_dict(torch.load('mlp.params'))
clone.eval()
http://www.dtcms.com/a/486189.html

相关文章:

  • SQL Server 2019实验 │ 存储过程和触发器的使用
  • Font Awesome 方向图标详解
  • 在源码之家下载的网站模板可以作为自己的网站吗怎么提升网站流量
  • MySQL客服端工具
  • ElementUi【饿了么ui】
  • 五点法求解相机的相对位姿
  • 外贸网站推广工作哈尔滨建站
  • 网站右边跳出的广告怎么做dw网站建设基本流程
  • Excelize 开源基础库发布 2.10.0 版本更新
  • iOS 26 系统流畅度测试实战分享,多工具组合辅助策略
  • 智尚房产中介小程序
  • Kuboard突然各种proxy访问401解决
  • 自己做卖假货网站小程序怎么制作开发
  • 专业网站优化山西省城乡住房建设厅网站首页
  • 后端Node知识框架图(Node、Express、KoaNest)
  • 数据结构3:线性表2-顺序存储的线性表
  • TaskIQ 是什么,怎么做异步任务
  • 服务器CPU达到100%解决思路
  • 在 Claude Code 中设置 MCP 服务器(技术总结)
  • 网站上传根目录如何制作线上投票
  • 移动端网站建设的请示东莞科技网站建设
  • EtherCAT转CCLKIE工业通讯网关突破:三菱PLC实时调度EtherCAT伺服完成精密加工
  • 深度学习实验一之图像特征提取和深度学习训练数据标注
  • 基于Matlab的深度堆叠自编码器(SAE)实现与分类应用
  • @Scope失效问题
  • Service 网络原理
  • 数据复制问题及其解决方案
  • Java-Spring入门指南(二十五)Android 的历史,认识移动应用和Android 基础知识
  • WPF依赖属性(Dependency Property)详解
  • 深度学习进阶(三)——生成模型的崛起:从自回归到扩散