「日拱一码」125 Multi-Level Feature Fusion
Multi-Level Feature Fusion in Machine Learning
Core Concept
Multi-level feature fusion integrates features from different depths of a neural network, combining shallow detail information (e.g., edges, textures) with deep semantic information (e.g., object categories) to improve model performance. It is commonly seen in:
- Computer vision: U-Net and FPN (Feature Pyramid Network); see the sketch after this list
- Natural language processing: multi-head attention fusion in Transformers
- Multimodal learning: fusing visual and textual features
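For intuition, here is a minimal FPN-style summation sketch. The tensor shapes, the proj layer, and the variable names are illustrative assumptions, not taken from any specific model: a deep, low-resolution map is projected to the shallow map's channel count with a 1x1 convolution, upsampled, and added element-wise.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical feature maps from two depths of some backbone
shallow = torch.randn(1, 64, 32, 32)   # detail: high resolution, few channels
deep = torch.randn(1, 256, 8, 8)       # semantics: low resolution, many channels

proj = nn.Conv2d(256, 64, 1)           # 1x1 conv aligns channel counts (256 -> 64)
deep_up = F.interpolate(proj(deep), size=(32, 32), mode="nearest")
fused = shallow + deep_up              # element-wise sum requires matching shapes
print(fused.shape)                     # torch.Size([1, 64, 32, 32])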
Typical Fusion Methods
| Method | Characteristics | Typical Use Cases |
|---|---|---|
| Concatenation | Stacks feature maps along the channel dimension | When all original information must be preserved |
| Summation | Element-wise addition; requires identical feature-map shapes | Residual connections, feature enhancement |
| Attention-weighted fusion | Dynamically learns weights for each feature level | Extracting key features against complex backgrounds (see the sketch below) |
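Concatenation is demonstrated end to end in the code example below. Attention-weighted fusion is not, so here is a minimal sketch assuming an SE-style channel gate; the class name AttentionFusion and the gate design are illustrative choices, not a standard API.

import torch
import torch.nn as nn

class AttentionFusion(nn.Module):
    """Blend two same-shape feature maps with learned per-channel weights."""
    def __init__(self, channels):
        super().__init__()
        self.gate = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),               # global context: [B,2C,1,1]
            nn.Conv2d(2 * channels, channels, 1),  # per-channel weight logits
            nn.Sigmoid()                           # weights in (0, 1)
        )

    def forward(self, feat_a, feat_b):
        w = self.gate(torch.cat([feat_a, feat_b], dim=1))  # [B,C,1,1]
        return w * feat_a + (1 - w) * feat_b               # convex per-channel blend

fuse = AttentionFusion(channels=64)
out = fuse(torch.randn(2, 64, 32, 32), torch.randn(2, 64, 32, 32))
print(out.shape)  # torch.Size([2, 64, 32, 32])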
Code Example
import torch
import torch.nn as nn
import torch.nn.functional as F

# Baseline model without feature fusion
class BaselineModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=2, padding=1),    # [B,64,H/2,W/2]
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),  # [B,128,H/4,W/4]
            nn.ReLU(),
            nn.Conv2d(128, 256, 3, stride=2, padding=1)  # [B,256,H/8,W/8]
        )
        self.fc = nn.Linear(256 * 32 * 32, 10)  # with 256x256 input, the final feature map is 32x32

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

# Improved model with multi-level feature fusion
class FusionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)   # [B,64,256,256]
        self.conv2 = nn.Sequential(
            nn.MaxPool2d(2),                          # [B,64,128,128]
            nn.Conv2d(64, 128, 3, padding=1)          # [B,128,128,128]
        )
        self.conv3 = nn.Sequential(
            nn.MaxPool2d(2),                          # [B,128,64,64]
            nn.Conv2d(128, 256, 3, padding=1)         # [B,256,64,64]
        )
        self.fusion = nn.Sequential(
            nn.Conv2d(64 + 128 + 256, 256, 1),        # [B,256,64,64]
            nn.ReLU()
        )
        self.avgpool = nn.AdaptiveAvgPool2d((8, 8))   # [B,256,8,8]
        self.fc = nn.Linear(256 * 8 * 8, 10)          # classifier input fixed at 256*8*8

    def forward(self, x):
        feat1 = torch.relu(self.conv1(x))      # [B,64,256,256]
        feat2 = torch.relu(self.conv2(feat1))  # [B,128,128,128]
        feat3 = torch.relu(self.conv3(feat2))  # [B,256,64,64]
        # resize all levels to a common 64x64 resolution
        feat1_down = F.max_pool2d(feat1, 4)                # [B,64,64,64]
        feat2_up = F.interpolate(feat2, scale_factor=0.5)  # [B,128,64,64]
        # fuse by concatenation along the channel dimension
        fused = torch.cat([feat1_down, feat2_up, feat3], dim=1)  # [B,448,64,64]
        x = self.fusion(fused)     # [B,256,64,64]
        x = self.avgpool(x)        # [B,256,8,8]
        x = x.view(x.size(0), -1)  # [B,256*8*8]
        return self.fc(x)          # [B,10]

inputs = torch.randn(4, 3, 256, 256)
labels = torch.randint(0, 10, (4,))

def compare_models():
    models = {"Baseline": BaselineModel(), "Fusion": FusionModel()}
    for name, model in models.items():
        outputs = model(inputs)
        print(f"{name} Output:", outputs)

compare_models()
# Baseline Output: tensor([[ 0.0843, -0.0305, 0.0694, 0.0418, -0.1137, -0.0166, 0.0289, -0.0310,
# 0.0460, 0.0113],
# [ 0.1234, 0.0051, 0.0059, 0.0220, -0.0471, 0.0165, 0.0634, 0.0209,
# 0.0759, -0.0172],
# [ 0.1295, -0.1292, 0.0489, 0.0084, -0.0774, -0.0134, 0.1163, 0.0403,
# -0.0300, -0.0682],
# [ 0.0558, -0.1122, -0.0177, 0.0377, -0.0714, -0.0063, 0.0815, -0.0592,
# 0.0967, 0.0269]], grad_fn=<AddmmBackward0>)
# Fusion Output: tensor([[ 0.0748, 0.0692, -0.0581, -0.0391, -0.0746, 0.1133, 0.1727, 0.0033,
# -0.0519, -0.0664],
# [ 0.0730, 0.0777, -0.0526, -0.0416, -0.0747, 0.1068, 0.1723, -0.0005,
# -0.0663, -0.0601],
# [ 0.0692, 0.0694, -0.0557, -0.0384, -0.0776, 0.1115, 0.1748, 0.0015,
# -0.0587, -0.0629],
# [ 0.0655, 0.0752, -0.0608, -0.0405, -0.0844, 0.1281, 0.1759, -0.0033,
# -0.0564, -0.0591]], grad_fn=<AddmmBackward0>)
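Both untrained models produce the expected [4, 10] output shape; the raw values above carry no accuracy signal. One concrete difference is size: because AdaptiveAvgPool2d fixes the classifier input at 256*8*8 rather than Baseline's flattened 256*32*32, FusionModel ends up with far fewer parameters. A quick check (the count_params helper is our addition, not part of the example above):

def count_params(model):
    return sum(p.numel() for p in model.parameters())

print(f"Baseline params: {count_params(BaselineModel()):,}")  # 2,992,266
print(f"Fusion params:   {count_params(FusionModel()):,}")    # 649,610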