当前位置: 首页 > news >正文

PyTorch中 nn.Linear详解和实战示例

1. nn.Linear 的作用

在 PyTorch 中,torch.nn.Linear 表示一个全连接层(Fully Connected Layer),也叫 仿射变换层(Affine Layer)
它的计算公式是:

y=xWT+b y = x W^T + b y=xWT+b

  • 输入 x:形状 [batch_size, in_features]
  • 权重矩阵 W:形状 [out_features, in_features]
  • 偏置 b:形状 [out_features]
  • 输出 y:形状 [batch_size, out_features]

2. 初始化方式

torch.nn.Linear(in_features: int,out_features: int,bias: bool = True,device=None,dtype=None
)
  • in_features:输入特征维度
  • out_features:输出特征维度
  • bias:是否使用偏置项(默认 True
  • device/dtype:指定设备与数据类型

3. 参数说明

一个 nn.Linear 层包含两个参数(均可训练):

  1. weight:形状 [out_features, in_features]
  2. bias:形状 [out_features](可选)

初始化时:

  • 权重 weight 默认使用 Kaiming 均匀分布初始化(a=√5)

  • 偏置 bias 默认使用 均匀分布 U(-bound, bound),其中

    bound=1in_features bound = \frac{1}{\sqrt{in\_features}} bound=in_features1


4. 前向传播公式

假设输入张量 x 形状为 [batch_size, in_features]

output[i]=∑j=1in_featuresx[j]⋅W[i][j]+b[i] \text{output}[i] = \sum_{j=1}^{in\_features} x[j] \cdot W[i][j] + b[i] output[i]=j=1in_featuresx[j]W[i][j]+b[i]

即对每个样本进行线性变换。

PyTorch 内部实现是:

output = input.matmul(weight.T) + bias

5. 反向传播(梯度)

PyTorch 自动求导会自动处理梯度,但核心推导如下:

  • 输入 x ∈ R^{B×I},权重 W ∈ R^{O×I},偏置 b ∈ R^{O}

  • 前向传播:

    Y=XWT+b Y = X W^T + b Y=XWT+b

  • 梯度:

    • 对权重:

      ∂L∂W=∂L∂YTX \frac{\partial L}{\partial W} = \frac{\partial L}{\partial Y}^T X WL=YLTX

    • 对偏置:

      ∂L∂b=∑samples∂L∂Y \frac{\partial L}{\partial b} = \sum_{samples} \frac{\partial L}{\partial Y} bL=samplesYL

    • 对输入:

      ∂L∂X=∂L∂YW \frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y} W XL=YLW


6. 使用示例

import torch
import torch.nn as nn# 定义线性层
fc = nn.Linear(in_features=5, out_features=3, bias=True)# 输入张量 [batch=2, in_features=5]
x = torch.randn(2, 5)# 前向传播
y = fc(x)
print("Input shape:", x.shape)   # [2, 5]
print("Output shape:", y.shape)  # [2, 3]# 查看参数
print(fc.weight.shape)  # [3, 5]
print(fc.bias.shape)    # [3]

7. 常见用法

  1. 作为全连接层

    model = nn.Sequential(nn.Linear(784, 128),nn.ReLU(),nn.Linear(128, 10)
    )
    
  2. 替代矩阵乘法

    W = torch.randn(3, 5)
    b = torch.randn(3)
    x = torch.randn(2, 5)y1 = x @ W.T + b
    y2 = nn.Linear(5, 3)(x)print(torch.allclose(y1, y2, atol=1e-6))  # True
    
  3. 作为嵌入层最后一步投影

    • Transformer 中 decoder 的输出用 nn.Linear 投影到词表大小 vocab_size

8. 源码关键点

PyTorch 源码(torch/nn/modules/linear.py)核心部分:

def forward(self, input: Tensor) -> Tensor:return F.linear(input, self.weight, self.bias)

其中 F.linear 实现就是 input.matmul(weight.T) + bias


9. 常见坑点

  1. 输入维度不对
    nn.Linear 要求输入最后一维是 in_features
    如果输入是 [batch, channels, height, width],要先 flattenpermute

    x = torch.randn(32, 3, 28, 28)
    fc = nn.Linear(3*28*28, 100)
    y = fc(x.view(32, -1))  # 展平
    
  2. 权重转置
    注意公式是 y = x @ W^T,而不是 x @ W

  3. 和卷积的区别

    • nn.Conv2d:局部连接 + 权重共享
    • nn.Linear:全连接,不共享权重

10. 总结

  • nn.Linear = 全连接层 = 仿射变换
  • 参数:weight [out, in]bias [out]
  • 前向公式:y = x @ W^T + b
  • 常用于:MLP、分类器最后一层、Transformer 投影层等
  • 注意输入最后一维要匹配 in_features

11. 综合应用示例

下面是一个完整的综合示例,涵盖以下内容:

定义 nn.Linear
模拟输入数据
前向传播
查看权重与偏置
反向传播 + 梯度查看
x @ W^T + b 对比验证一致性
图示化输入输出形状变化(文字+可视化)


综合示例:手写数字分类 MLP(含 nn.Linear

我们构建一个简单的 2 层感知机,模拟对输入向量进行分类。


Step 1:导入依赖 & 定义模型

import torch
import torch.nn as nn
import torch.nn.functional as F# 定义一个简单的两层感知机(MLP)
class SimpleMLP(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(16, 8)   # 第一层:16 -> 8self.fc2 = nn.Linear(8, 4)    # 第二层:8 -> 4(比如分类 4 类)def forward(self, x):x = self.fc1(x)x = F.relu(x)x = self.fc2(x)return x

Step 2:模拟输入数据

# 模拟输入数据:batch_size = 3,feature_dim = 16
x = torch.randn(3, 16)# 实例化模型
model = SimpleMLP()# 前向传播
output = model(x)print("输入形状:", x.shape)        # [3, 16]
print("第一层权重形状:", model.fc1.weight.shape)  # [8, 16]
print("第二层权重形状:", model.fc2.weight.shape)  # [4, 8]
print("输出形状:", output.shape)  # [3, 4]

Step 3:手动验证 x @ W^T + bnn.Linear 一致性

# 取第一层验证
fc = model.fc1
x_input = x# 手动计算:y = x @ W^T + b
manual_output = x_input @ fc.weight.T + fc.bias# 与 forward 一致性验证
auto_output = fc(x_input)print("是否一致:", torch.allclose(manual_output, auto_output, atol=1e-6))

Step 4:反向传播 + 查看梯度

# 假设一个简单的损失函数
target = torch.tensor([0, 1, 3])     # 假设分类标签
criterion = nn.CrossEntropyLoss()# 正向计算输出
out = model(x)# 计算损失
loss = criterion(out, target)# 反向传播
loss.backward()# 查看第一层权重梯度
print("fc1 权重梯度形状:", model.fc1.weight.grad.shape)
print("fc1 偏置梯度形状:", model.fc1.bias.grad.shape)

输入输出形状变化总结

层级输入形状权重形状输出形状
输入[3, 16][3, 16]
fc1[3, 16][8, 16][3, 8]
ReLU[3, 8][3, 8]
fc2[3, 8][4, 8][3, 4]

可视化理解(流程图)

下面这个示意图帮助你直观理解 nn.Linear 是如何做维度映射的:

输入张量 x:         [batch_size=3, in_features=16]│▼
nn.Linear(16 → 8):  权重 [8, 16],输出 [3, 8]│▼ReLU 激活│▼
nn.Linear(8 → 4):   权重 [4, 8],输出 [3, 4]│▼
分类输出 logits:    [batch_size=3, out_features=4]

模型结构和张量流动的视觉图


12.nn.Linear源码关键实现和典型应用

1. nn.Linear 源码关键实现

在 PyTorch 2.0+ 的源码中(torch/nn/modules/linear.py),核心实现非常精简:

class Linear(Module):__constants__ = ['in_features', 'out_features']in_features: intout_features: intweight: Tensorbias: Tensor | Nonedef __init__(self, in_features, out_features, bias=True, device=None, dtype=None):factory_kwargs = {'device': device, 'dtype': dtype}super().__init__()self.in_features = in_featuresself.out_features = out_features# 权重参数(out_features x in_features)self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))# 偏置参数(out_features)if bias:self.bias = Parameter(torch.empty(out_features, **factory_kwargs))else:self.register_parameter('bias', None)self.reset_parameters()def reset_parameters(self) -> None:# 权重 Kaiming 均匀初始化,偏置 U(-bound, bound)init.kaiming_uniform_(self.weight, a=math.sqrt(5))if self.bias is not None:fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0init.uniform_(self.bias, -bound, bound)def forward(self, input: Tensor) -> Tensor:return F.linear(input, self.weight, self.bias)

关键点拆解:

  1. 参数存储

    • self.weight[out_features, in_features]
    • self.bias[out_features]
  2. 初始化

    • 权重:Kaiming Uniform(适合 ReLU 激活)
    • 偏置:Uniform(-1/√fan_in, 1/√fan_in)
  3. 前向计算

    def linear(input, weight, bias=None):return input.matmul(weight.T) + bias
    
  4. 梯度计算
    PyTorch 自动在 C++/CUDA backend 里定义好了 matmuladd 的梯度传播,不需要 Python 层手写。


2. nn.Linear 在不同架构中的用途

(1) 在 MLP(多层感知机)
  • 作用:核心构建模块,层层映射特征维度。

  • 例子

    model = nn.Sequential(nn.Linear(784, 256),  # 输入层 (28*28)nn.ReLU(),nn.Linear(256, 128),  # 隐藏层nn.ReLU(),nn.Linear(128, 10)    # 输出层 (分类)
    )
    
  • 解释:每个 nn.Linear 就是一次仿射变换,把输入映射到新空间。最后一层通常对应分类 logits。


(2) 在 CNN 分类器
  • 作用:CNN 负责提取空间特征,最后通过 nn.Linear 将卷积特征映射到分类输出空间。

  • 例子(ResNet 中的最后一层):

    class CNNClassifier(nn.Module):def __init__(self, num_classes=10):super().__init__()self.conv = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)self.pool = nn.AdaptiveAvgPool2d((1, 1))self.fc   = nn.Linear(64, num_classes)  # 分类层def forward(self, x):x = F.relu(self.conv(x))x = self.pool(x)x = torch.flatten(x, 1)  # [batch, 64]x = self.fc(x)           # [batch, num_classes]return x
    
  • 解释nn.Linear 负责 卷积特征 → 类别预测


(3) 在 Transformer(如 BERT, GPT)

Transformer 内部大量用到 nn.Linear,主要场景有:

a. Attention 的 Q、K、V 投影
self.q_proj = nn.Linear(d_model, d_k)  # 生成 Query
self.k_proj = nn.Linear(d_model, d_k)  # 生成 Key
self.v_proj = nn.Linear(d_model, d_v)  # 生成 Value
  • 输入:[batch, seq_len, d_model]
  • 输出:[batch, seq_len, d_k]
  • 用于将 embedding 投影到不同子空间。
b. Attention 输出的投影
self.out_proj = nn.Linear(d_v, d_model)
  • 将多头拼接后的结果映射回 d_model 维度。
c. Feed-Forward 网络(FFN)

Transformer Block 里的 FFN 是:

FFN(x)=Linear(dmodel,dff)→ReLU/GELU→Linear(dff,dmodel) FFN(x) = \text{Linear}(d_{model}, d_{ff}) \to \text{ReLU/GELU} \to \text{Linear}(d_{ff}, d_{model}) FFN(x)=Linear(dmodel,dff)ReLU/GELULinear(dff,dmodel)

例子:

self.fc1 = nn.Linear(d_model, d_ff)
self.fc2 = nn.Linear(d_ff, d_model)
d. BERT 最后一层分类头
  • BERT 用 nn.Linear(hidden_size, vocab_size) 投影到词表维度,得到预测 logits。

  • 例如 Masked LM 任务:

    self.cls = nn.Linear(hidden_size, vocab_size)
    

3. 小结对比表

架构nn.Linear 用途输入维度输出维度
MLP特征逐层映射[batch, in_features][batch, out_features]
CNN卷积特征 → 分类[batch, channels][batch, num_classes]
Transformer(a) Q/K/V 投影
(b) Attention 输出投影
© FFN 映射
(d) 分类/词表投影
[batch, seq, d_model][batch, seq, d_k/d_ff/d_model/vocab]
BERTMasked LM / NSP 分类头[batch, hidden_size][batch, vocab_size]

http://www.dtcms.com/a/349889.html

相关文章:

  • Java全栈开发实战:从基础到微服务的深度探索
  • [Python]库Pandas应用总结
  • PE嵌入式签名检测方法
  • 阿里开源Vivid-VR:AI视频修复新标杆,解锁内容创作新可能
  • AR远程协助:能源电力行业智能化革新
  • 一键编译安装zabbix(centos)
  • Spark面试题
  • HTTP 协议与TCP 的其他机制
  • excel 破解工作表密码
  • Python之Flask快速入门
  • Redis类型之List
  • 自然语言处理——07 BERT、ELMO、GTP系列模型
  • lesson46-1:Linux 常用指令全解析:从基础操作到高效应用
  • Docker:常用命令、以及设置别名
  • 数据挖掘 6.1 其他降维方法(不是很重要)
  • 聊聊负载均衡架构
  • 关于窗口关闭释放内存,主窗口下的子窗口关闭释放不用等到主窗口关闭>setAttribute(Qt::WA_DeleteOnClose);而且无需手动释放
  • 【Python】QT(PySide2、PyQt5):列表视图、模型、自定义委托
  • 【芯片后端设计的灵魂:Placement的作用与重要性】
  • SQL 语句拼接在 C 语言中的实现与安全性分析
  • 跨语言统一语义真理及其对NLP深层分析影响
  • 2.3零基础玩转uni-app轮播图:从入门到精通 (咸虾米总结)
  • Python 实战:内网渗透中的信息收集自动化脚本(3)
  • 苹果公司即将启动一项为期三年的计划
  • Linux应急响应一般思路(三)
  • 蜗牛播放器 Android TV:解决大屏观影痛点的利器
  • C/C++ 指针与函数
  • Tesseract OCR之页面布局分析
  • 朴素贝叶斯:用 “概率思维” 解决分类问题的经典算法
  • ​Visual Studio + UE5 进行游戏开发的常见故障问题解决