当前位置: 首页 > news >正文

PyTorch中Linear全连接层

在 PyTorch 中,torch.nn.Linear 是一个实现全连接层(线性变换)的模块,用于神经网络中的线性变换操作。它的数学表达式为:

其中:

  • x是输入数据

  • W是权重矩阵

  • b是偏置项

  • y是输出数据

基本用法

import torch
import torch.nn as nn

# 创建一个线性层,输入特征数为5,输出特征数为3
linear_layer = nn.Linear(in_features=5, out_features=3)

# 创建一个随机输入张量(batch_size=2, 特征数=5)
input_tensor = torch.randn(2, 5)

# 前向传播
output = linear_layer(input_tensor)
print(output.shape)  # 输出 torch.Size([2, 3])

主要参数

  1. in_features - 输入特征的数量

  2. out_features - 输出特征的数量

  3. bias - 是否使用偏置项(默认为True)

重要属性

  1. weight - 可学习的权重参数(形状为[out_features, in_features])

  2. bias - 可学习的偏置参数(形状为[out_features])

示例:构建简单神经网络

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(10, 20)  # 输入10维,输出20维
        self.fc2 = nn.Linear(20, 2)   # 输入20维,输出2维
        
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = SimpleNet()
input_data = torch.randn(5, 10)  # batch_size=5
output = model(input_data)
print(output.shape)  # torch.Size([5, 2])

初始化权重

# 自定义权重初始化
nn.init.xavier_uniform_(linear_layer.weight)
nn.init.zeros_(linear_layer.bias)

# 或者使用PyTorch内置初始化
linear_layer = nn.Linear(5, 3)
torch.nn.init.kaiming_normal_(linear_layer.weight, mode='fan_out')

注意事项

  1. 输入数据的最后一维必须等于in_features

  2. 线性层通常与激活函数配合使用(如ReLU)

  3. 在GPU上使用时,确保数据和模型都在同一设备上。

http://www.dtcms.com/a/108628.html

相关文章:

  • 视频设备轨迹回放平台EasyCVR如何搭建公共娱乐场所远程视频监控系统
  • 铁路语义分割数据下载RailSem19: A Dataset for Semantic Rail Scene Understanding
  • 使用Android 原生LocationManager获取经纬度
  • 教育软件 UI 设计:打造吸睛又实用的学习入口
  • SELinux
  • Leetcode-100 二分查找常见操作总结
  • 数据点燃创新引擎:数据驱动的产品开发如何重塑未来?
  • Airflow量化入门系列:第一章 Apache Airflow 基础
  • 红宝书第二十五讲:客户端存储(Cookie、localStorage、IndexedDB):浏览器里的“记忆盒子”
  • Leetcode 6233 -- DFS序列 | 两遍DFS
  • Vue中JSEncrypt 数据加密和解密处理
  • Firefox账号同步书签不一致(火狐浏览器书签同步不一致)
  • wireshak抓手机包 wifi手机抓包工具
  • linux 时钟
  • 【爬虫】网页抓包工具--Fiddler
  • 【Audio开发二】Android原生音量曲线调整说明
  • LInux基础指令(二)
  • 【VS+Qt】vs2022打开 vs2015项目
  • FastAPI中Pydantic异步分布式唯一性校验
  • 机器视觉调试——现场链接相机(解决各种相机链接问题)
  • 自然语言处理(22:(第六章2.)​seq2seq模型的实现​)
  • 图片懒加载、无限滚动加载、监听元素进入视口加载数据。「IntersectionObserver」
  • scala编程语言
  • 服务器数据恢复—Raid6阵列硬盘故障掉线,上层虚拟机数据如何恢复?
  • linux-firewalld防火墙允许端口
  • 【SLAM经典算法详解】Ubuntu 20.04部署LeGO-LOAM:从环境配置到KITTI适配,解决常见编译错误
  • 从零开发美颜SDK:美颜滤镜API的核心技术与实现
  • 多视图几何--立体校正--Fusiello方法
  • CMake学习--如何在CMake中编译静态库、动态库并在主程序中调用
  • rag精细化测试