当前位置: 首页 > news >正文

神经网络——线性层

在机器学习中,线性层(Linear Layer) 是一种基础的神经网络组件,也称为全连接层(Fully Connected Layer) 或密集层(Dense Layer)

其严格的数学定义为:对输入数据执行线性变换,生成输出向量。

具体形式为:
                Y=XW+b  
其中:

  • X 是输入张量,通常形状为 [批次大小, 输入维度]

  • W 是可学习的权重矩阵,形状为 [输入维度, 输出维度]

  • b 是可学习的偏置向量,形状为 [输出维度]

  • Y 是输出张量,形状为 [批次大小, 输出维度]

核心特性

  1. 参数共享:同一层内的所有输入神经元都通过权重矩阵 W 与输出神经元相连,权重在整个输入空间中共享。

  2. 线性变换:仅能表示线性函数,因此通常与非线性激活函数(如 ReLU)组合使用,以增强模型表达能力。

  3. 特征投影:本质上是将输入特征投影到新的特征空间,输出维度决定了新空间的维度。

线性网络: 

 


 

参数

  • in_features (int) – 每个输入样本的大小

  • out_features (int) – 每个输出样本的大小

  • bias (bool) – 如果设置为 False,该层将不学习加性偏置。默认值: True

 

 代码举例

import torch
import torchvision
from torch import nn
from torch.nn import Linear
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10("../torchvision_dataset", train=False,transform=torchvision.transforms.ToTensor(),download=True)dataloader = DataLoader(dataset, batch_size=64)class MyModule(nn.Module):def __init__(self):super().__init__()"""下文:展开前:torch.Size([64, 3, 32, 32])展开后:torch.Size([1, 1, 1, 196608])"""self.linear = Linear(196608, 10)def forward(self, input):output = self.linear(input)return outputmodule = MyModule()for data in dataloader:imgs, targets = dataprint("原本图像尺寸", imgs.shape)# 把二维图片展开成一维的# imgs=torch.reshape(imgs,(1,1,1,-1))imgs = torch.flatten(imgs)print("展平后图像尺寸", imgs.shape)output = module(imgs)print("经过线性层处理后图像尺寸", output.shape)

 

http://www.dtcms.com/a/290348.html

相关文章:

  • 混合遗传粒子群算法在光伏系统MPPT中的应用研究
  • imx6ull-系统移植篇15——U-Boot 图形化配置(下)
  • 蚂蚁数科AI数据产业基地正式投产,携手苏州推进AI产业落地
  • 使用Python绘制专业柱状图:Matplotlib完全指南
  • 《Linux服务与安全管理》| 安装拼音输入法
  • 【NLP舆情分析】基于python微博舆情分析可视化系统(flask+pandas+echarts) 视频教程 - 主页布局实现
  • “hidden act“:“gelu“在bert中作用
  • 经典神经网络(vgg resnet googlenet)
  • 家庭网络怎么进行公网IP获取,及内网端口映射外网访问配置,附无公网IP提供互联网连接方案
  • 03-虚幻引擎蓝图类的各父类作用讲解
  • el-table固定高度,数据多出现滚动条,表头和内容对不齐
  • Eltable tree形式,序号列实现左对齐,并且每下一层都跟上一层的错位距离拉大
  • 深入解析Hadoop MapReduce Shuffle过程:从环形缓冲区溢写到Sort与Merge源码
  • VMware Workstation Pro克隆虚拟机导致网络异常解决方法
  • 深度学习 pytorch图像分类(详细版)
  • 【设计模式】观察者模式 (发布-订阅模式,模型-视图模式,源-监听器模式,从属者模式)
  • HTTP性能优化:打造极速Web体验的关键策略
  • 从实践出发--探究C/C++空类的大小,真的是1吗?
  • 西门子 S7-1500 信号模块硬件配置全解析:从选型到实战
  • 如何快速比较excel两列,拿出不同的数据
  • 在.NET Core API 微服务中使用 gRPC:从通信模式到场景选型
  • 用 STM32 的 SYSTICK 定时器与端口复用重映射玩转嵌入式开发
  • 大模型高效适配:软提示调优 Prompt Tuning
  • The Survey of Few-shot Prompt Learning on Graph
  • AI Agent开发学习系列 - langchain之LCEL(3):Prompt+LLM
  • JavaScript Promise全解析
  • Prompt Engineering(提示词工程)基础了解
  • 【PTA数据结构 | C语言版】列出连通集
  • 归并排序:优雅的分治排序算法(C语言实现)
  • 什么是商业智能BI数据分析的指标爆炸?