当前位置: 首页 > news >正文

基础神经网络模型搭建

nn 包提供通用深度学习网络的模块集合,接收输入张量,计算输出张量,并保存权重。通常使用两种途径搭建 PyTorch 中的模型:nn.Sequential和 nn.Module。

nn.Sequential通过线性层有序组合搭建模型;nn.Module通过__init__ 函数指定层,然后通过 forward 函数将层应用于输入,更灵活地构建自定义模型。

目录

搭建线性层

通过nn.Sequential搭建

通过nn.Module搭建

获取模型摘要


搭建线性层

使用 nn 包搭建线性层。线性层接收 64*1000 维的输入,保存 1000*100 维的权重,并计算 64*100 维的输出。

import torch
from torch import nn
input_tensor = torch.randn(64, 1000)
linear_layer = nn.Linear(1000, 100)
output = linear_layer(input_tensor)
print(input_tensor.size())
print(output.size())

通过nn.Sequential搭建

考虑一个两层的神经网络,四个节点作为输入,五个节点在隐藏层,一个节点作为输出

from torch import nn
model = nn.Sequential(nn.Linear(4, 5),nn.ReLU(),nn.Linear(5, 1),
)
print(model)

通过nn.Module搭建

在 PyTorch 中搭建模型的另一种方法是对 nn.Module 类进行子类化,通过__init__ 函数指定层,然后通过 forward 函数将层应用于输入,更灵活地构建自定义模型。

考虑两个卷积层和两个完全连接层搭建的模型:

import torch.nn.functional as F
class Net(nn.Module):def __init__(self):super(Net, self).__init__()def forward(self, x):pass

定义__init__ 函数和forward 函数

def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 20, 5, 1)self.conv2 = nn.Conv2d(20, 50, 5, 1)self.fc1 = nn.Linear(4*4*50, 500)self.fc2 = nn.Linear(500, 10)
def forward(self, x):x = F.relu(self.conv1(x))x = F.max_pool2d(x, 2, 2)x = F.relu(self.conv2(x))x = F.max_pool2d(x, 2, 2) x = x.view(-1, 4*4*50)x = F.relu(self.fc1(x))x = self.fc2(x)return F.log_softmax(x, dim=1)

重写两个类函数并打印模型

重写:子类中实现一个与父类的成员函数原型完全相同的函数

Net.__init__ = __init__
Net.forward = forward
model = Net()
print(model)

 查看模型位置

print(next(model.parameters()).device)

 

将模型移动至CUDA设备 

device = torch.device("cuda:0")
model.to(device)
print(next(model.parameters()).device)

获取模型摘要

借助torchsummary包查获取模型摘要

pip install torchsummary
from torchsummary import summary
summary(model, input_size=(1, 28, 28))

http://www.dtcms.com/a/289924.html

相关文章:

  • AI效能之AI单测(一)
  • MCP协议解析:如何通过Model Context Protocol 实现高效的AI客户端与服务端交互
  • c++ duiLib 使用xml文件编写界面布局
  • MyBatis Plus高效开发指南
  • 【PyTorch】图像二分类项目
  • JWT原理及利用手法
  • XTTS实现语音克隆:精确控制音频格式与生成流程【TTS的实战指南】
  • `SearchTransportService` 是 **协调节点与数据节点之间“搜索子请求”通信的运输层**
  • 如何用immich将苹果手机中的照片备份到指定文件夹
  • 开发工具缓存目录
  • 零基础学习性能测试第一章:核心性能指标-响应时间
  • 单链表的手动实现+相关OJ题
  • PostgreSQL 字段类型速查与 Java 枚举映射
  • 【硬件】GalaxyTabPro10.1(SM-T520)刷机/TWRP/LineageOS14/安卓7升级全过程
  • 讲座|人形机器人多姿态站起控制HoST及宇树G1部署
  • C++ 并发 future, promise和async
  • 2025年AIR SCI1区TOP,缩减因子分数阶蜣螂优化算法FORDBO,深度解析+性能实测
  • 基于51单片机的温湿度检测系统Protues仿真设计
  • 创建一个触发csrf的恶意html
  • 低速信号设计之I3C篇
  • windows11环境配置torch-points-kernels库编译安装详细教程
  • 【前端】懒加载(组件/路由/图片等)+预加载 汇总
  • NJU 凸优化导论(10) Approximation+Projection逼近与投影的应用(完结撒花)
  • InfluxDB 数据模型:桶、测量、标签与字段详解(二)
  • springboot --大事件--文章管理接口开发
  • 简洁高效的C++终端日志工具类
  • 响应式编程入门教程第七节:响应式架构与 MVVM 模式在 Unity 中的应用
  • SEO中关于关键词分类与布局的方法有那些
  • 【实战1】手写字识别 Pytoch(更新中)
  • Codes 通过创新的重新定义 SaaS 模式,专治 “原教旨主义 SaaS 的水土不服