PyTorch中 nn.Linear详解和实战示例
1. nn.Linear
的作用
在 PyTorch 中,torch.nn.Linear
表示一个全连接层(Fully Connected Layer),也叫 仿射变换层(Affine Layer)。
它的计算公式是:
y=xWT+b y = x W^T + b y=xWT+b
- 输入
x
:形状[batch_size, in_features]
- 权重矩阵
W
:形状[out_features, in_features]
- 偏置
b
:形状[out_features]
- 输出
y
:形状[batch_size, out_features]
2. 初始化方式
torch.nn.Linear(in_features: int,out_features: int,bias: bool = True,device=None,dtype=None
)
in_features
:输入特征维度out_features
:输出特征维度bias
:是否使用偏置项(默认True
)device/dtype
:指定设备与数据类型
3. 参数说明
一个 nn.Linear
层包含两个参数(均可训练):
weight
:形状[out_features, in_features]
bias
:形状[out_features]
(可选)
初始化时:
-
权重
weight
默认使用 Kaiming 均匀分布初始化(a=√5) -
偏置
bias
默认使用 均匀分布 U(-bound, bound),其中bound=1in_features bound = \frac{1}{\sqrt{in\_features}} bound=in_features1
4. 前向传播公式
假设输入张量 x
形状为 [batch_size, in_features]
:
output[i]=∑j=1in_featuresx[j]⋅W[i][j]+b[i] \text{output}[i] = \sum_{j=1}^{in\_features} x[j] \cdot W[i][j] + b[i] output[i]=j=1∑in_featuresx[j]⋅W[i][j]+b[i]
即对每个样本进行线性变换。
PyTorch 内部实现是:
output = input.matmul(weight.T) + bias
5. 反向传播(梯度)
PyTorch 自动求导会自动处理梯度,但核心推导如下:
-
输入
x ∈ R^{B×I}
,权重W ∈ R^{O×I}
,偏置b ∈ R^{O}
-
前向传播:
Y=XWT+b Y = X W^T + b Y=XWT+b
-
梯度:
-
对权重:
∂L∂W=∂L∂YTX \frac{\partial L}{\partial W} = \frac{\partial L}{\partial Y}^T X ∂W∂L=∂Y∂LTX
-
对偏置:
∂L∂b=∑samples∂L∂Y \frac{\partial L}{\partial b} = \sum_{samples} \frac{\partial L}{\partial Y} ∂b∂L=samples∑∂Y∂L
-
对输入:
∂L∂X=∂L∂YW \frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y} W ∂X∂L=∂Y∂LW
-
6. 使用示例
import torch
import torch.nn as nn# 定义线性层
fc = nn.Linear(in_features=5, out_features=3, bias=True)# 输入张量 [batch=2, in_features=5]
x = torch.randn(2, 5)# 前向传播
y = fc(x)
print("Input shape:", x.shape) # [2, 5]
print("Output shape:", y.shape) # [2, 3]# 查看参数
print(fc.weight.shape) # [3, 5]
print(fc.bias.shape) # [3]
7. 常见用法
-
作为全连接层
model = nn.Sequential(nn.Linear(784, 128),nn.ReLU(),nn.Linear(128, 10) )
-
替代矩阵乘法
W = torch.randn(3, 5) b = torch.randn(3) x = torch.randn(2, 5)y1 = x @ W.T + b y2 = nn.Linear(5, 3)(x)print(torch.allclose(y1, y2, atol=1e-6)) # True
-
作为嵌入层最后一步投影
- Transformer 中 decoder 的输出用
nn.Linear
投影到词表大小vocab_size
。
- Transformer 中 decoder 的输出用
8. 源码关键点
PyTorch 源码(torch/nn/modules/linear.py
)核心部分:
def forward(self, input: Tensor) -> Tensor:return F.linear(input, self.weight, self.bias)
其中 F.linear
实现就是 input.matmul(weight.T) + bias
。
9. 常见坑点
-
输入维度不对
nn.Linear
要求输入最后一维是in_features
。
如果输入是[batch, channels, height, width]
,要先flatten
或permute
。x = torch.randn(32, 3, 28, 28) fc = nn.Linear(3*28*28, 100) y = fc(x.view(32, -1)) # 展平
-
权重转置
注意公式是y = x @ W^T
,而不是x @ W
。 -
和卷积的区别
nn.Conv2d
:局部连接 + 权重共享nn.Linear
:全连接,不共享权重
10. 总结
nn.Linear
= 全连接层 = 仿射变换- 参数:
weight [out, in]
,bias [out]
- 前向公式:
y = x @ W^T + b
- 常用于:MLP、分类器最后一层、Transformer 投影层等
- 注意输入最后一维要匹配
in_features
11. 综合应用示例
下面是一个完整的综合示例,涵盖以下内容:
定义 nn.Linear
模拟输入数据
前向传播
查看权重与偏置
反向传播 + 梯度查看
与 x @ W^T + b
对比验证一致性
图示化输入输出形状变化(文字+可视化)
综合示例:手写数字分类 MLP(含 nn.Linear
)
我们构建一个简单的 2 层感知机,模拟对输入向量进行分类。
Step 1:导入依赖 & 定义模型
import torch
import torch.nn as nn
import torch.nn.functional as F# 定义一个简单的两层感知机(MLP)
class SimpleMLP(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(16, 8) # 第一层:16 -> 8self.fc2 = nn.Linear(8, 4) # 第二层:8 -> 4(比如分类 4 类)def forward(self, x):x = self.fc1(x)x = F.relu(x)x = self.fc2(x)return x
Step 2:模拟输入数据
# 模拟输入数据:batch_size = 3,feature_dim = 16
x = torch.randn(3, 16)# 实例化模型
model = SimpleMLP()# 前向传播
output = model(x)print("输入形状:", x.shape) # [3, 16]
print("第一层权重形状:", model.fc1.weight.shape) # [8, 16]
print("第二层权重形状:", model.fc2.weight.shape) # [4, 8]
print("输出形状:", output.shape) # [3, 4]
Step 3:手动验证 x @ W^T + b
与 nn.Linear
一致性
# 取第一层验证
fc = model.fc1
x_input = x# 手动计算:y = x @ W^T + b
manual_output = x_input @ fc.weight.T + fc.bias# 与 forward 一致性验证
auto_output = fc(x_input)print("是否一致:", torch.allclose(manual_output, auto_output, atol=1e-6))
Step 4:反向传播 + 查看梯度
# 假设一个简单的损失函数
target = torch.tensor([0, 1, 3]) # 假设分类标签
criterion = nn.CrossEntropyLoss()# 正向计算输出
out = model(x)# 计算损失
loss = criterion(out, target)# 反向传播
loss.backward()# 查看第一层权重梯度
print("fc1 权重梯度形状:", model.fc1.weight.grad.shape)
print("fc1 偏置梯度形状:", model.fc1.bias.grad.shape)
输入输出形状变化总结
层级 | 输入形状 | 权重形状 | 输出形状 |
---|---|---|---|
输入 | [3, 16] | – | [3, 16] |
fc1 | [3, 16] | [8, 16] | [3, 8] |
ReLU | [3, 8] | – | [3, 8] |
fc2 | [3, 8] | [4, 8] | [3, 4] |
可视化理解(流程图)
下面这个示意图帮助你直观理解 nn.Linear
是如何做维度映射的:
输入张量 x: [batch_size=3, in_features=16]│▼
nn.Linear(16 → 8): 权重 [8, 16],输出 [3, 8]│▼ReLU 激活│▼
nn.Linear(8 → 4): 权重 [4, 8],输出 [3, 4]│▼
分类输出 logits: [batch_size=3, out_features=4]
模型结构和张量流动的视觉图
12.nn.Linear
的源码关键实现和典型应用
1. nn.Linear
源码关键实现
在 PyTorch 2.0+ 的源码中(torch/nn/modules/linear.py
),核心实现非常精简:
class Linear(Module):__constants__ = ['in_features', 'out_features']in_features: intout_features: intweight: Tensorbias: Tensor | Nonedef __init__(self, in_features, out_features, bias=True, device=None, dtype=None):factory_kwargs = {'device': device, 'dtype': dtype}super().__init__()self.in_features = in_featuresself.out_features = out_features# 权重参数(out_features x in_features)self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))# 偏置参数(out_features)if bias:self.bias = Parameter(torch.empty(out_features, **factory_kwargs))else:self.register_parameter('bias', None)self.reset_parameters()def reset_parameters(self) -> None:# 权重 Kaiming 均匀初始化,偏置 U(-bound, bound)init.kaiming_uniform_(self.weight, a=math.sqrt(5))if self.bias is not None:fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0init.uniform_(self.bias, -bound, bound)def forward(self, input: Tensor) -> Tensor:return F.linear(input, self.weight, self.bias)
关键点拆解:
-
参数存储
self.weight
:[out_features, in_features]
self.bias
:[out_features]
-
初始化
- 权重:
Kaiming Uniform
(适合 ReLU 激活) - 偏置:
Uniform(-1/√fan_in, 1/√fan_in)
- 权重:
-
前向计算
def linear(input, weight, bias=None):return input.matmul(weight.T) + bias
-
梯度计算
PyTorch 自动在 C++/CUDA backend 里定义好了matmul
和add
的梯度传播,不需要 Python 层手写。
2. nn.Linear
在不同架构中的用途
(1) 在 MLP(多层感知机) 中
-
作用:核心构建模块,层层映射特征维度。
-
例子:
model = nn.Sequential(nn.Linear(784, 256), # 输入层 (28*28)nn.ReLU(),nn.Linear(256, 128), # 隐藏层nn.ReLU(),nn.Linear(128, 10) # 输出层 (分类) )
-
解释:每个
nn.Linear
就是一次仿射变换,把输入映射到新空间。最后一层通常对应分类 logits。
(2) 在 CNN 分类器 中
-
作用:CNN 负责提取空间特征,最后通过
nn.Linear
将卷积特征映射到分类输出空间。 -
例子(ResNet 中的最后一层):
class CNNClassifier(nn.Module):def __init__(self, num_classes=10):super().__init__()self.conv = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)self.pool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(64, num_classes) # 分类层def forward(self, x):x = F.relu(self.conv(x))x = self.pool(x)x = torch.flatten(x, 1) # [batch, 64]x = self.fc(x) # [batch, num_classes]return x
-
解释:
nn.Linear
负责 卷积特征 → 类别预测。
(3) 在 Transformer(如 BERT, GPT) 中
Transformer 内部大量用到 nn.Linear
,主要场景有:
a. Attention 的 Q、K、V 投影
self.q_proj = nn.Linear(d_model, d_k) # 生成 Query
self.k_proj = nn.Linear(d_model, d_k) # 生成 Key
self.v_proj = nn.Linear(d_model, d_v) # 生成 Value
- 输入:
[batch, seq_len, d_model]
- 输出:
[batch, seq_len, d_k]
- 用于将 embedding 投影到不同子空间。
b. Attention 输出的投影
self.out_proj = nn.Linear(d_v, d_model)
- 将多头拼接后的结果映射回
d_model
维度。
c. Feed-Forward 网络(FFN)
Transformer Block 里的 FFN 是:
FFN(x)=Linear(dmodel,dff)→ReLU/GELU→Linear(dff,dmodel) FFN(x) = \text{Linear}(d_{model}, d_{ff}) \to \text{ReLU/GELU} \to \text{Linear}(d_{ff}, d_{model}) FFN(x)=Linear(dmodel,dff)→ReLU/GELU→Linear(dff,dmodel)
例子:
self.fc1 = nn.Linear(d_model, d_ff)
self.fc2 = nn.Linear(d_ff, d_model)
d. BERT 最后一层分类头
-
BERT 用
nn.Linear(hidden_size, vocab_size)
投影到词表维度,得到预测 logits。 -
例如 Masked LM 任务:
self.cls = nn.Linear(hidden_size, vocab_size)
3. 小结对比表
架构 | nn.Linear 用途 | 输入维度 | 输出维度 |
---|---|---|---|
MLP | 特征逐层映射 | [batch, in_features] | [batch, out_features] |
CNN | 卷积特征 → 分类 | [batch, channels] | [batch, num_classes] |
Transformer | (a) Q/K/V 投影 (b) Attention 输出投影 © FFN 映射 (d) 分类/词表投影 | [batch, seq, d_model] | [batch, seq, d_k/d_ff/d_model/vocab] |
BERT | Masked LM / NSP 分类头 | [batch, hidden_size] | [batch, vocab_size] |