残差收缩模块
1. 多尺度阈值生成
创新思路:融合不同尺度的统计信息(如平均池化+最大池化)生成更鲁棒的阈值。
class MultiScaleShrinkage(nn.Module):
def __init__(self, channel, reduction=4):
super().__init__()
# 多尺度池化分支
self.avg_pool = nn.AdaptiveAvgPool1d(1)
self.max_pool = nn.AdaptiveMaxPool1d(1)
# 双分支融合
self.fc = nn.Sequential(
nn.Linear(channel*2, channel//reduction),
nn.ReLU(),
nn.Linear(channel//reduction, channel),
nn.Sigmoid()
)
def forward(self, x):
residual = x
x_abs = torch.abs(x)
# 双分支池化
avg = self.avg_pool(x_abs).squeeze(-1)
max_ = self.max_pool(x_abs).squeeze(-1)
combined = torch.cat([avg, max_], dim=-1) # (B, 2C)
threshold = self.fc(combined).unsqueeze(-1)
# 后续软阈值处理相同
2. 空间-通道协同阈值化
创新思路:引入空间注意力机制,实现通道与空间联合自适应。
class SpatioChannelShrinkage(nn.Module):
def __init__(self, channel, reduction=4):
super().__init__()
# 通道分支
self.channel_fc = nn.Sequential(
nn.Linear(channel, channel//reduction),
nn.ReLU(),
nn.Linear(channel//reduction, channel),
nn.Sigmoid()
)
# 空间分支(1D卷积)
self.spatial_conv = nn.Sequential(
nn.Conv1d(channel, 1, kernel_size=3, padding=1),
nn.Sigmoid()
)
def forward(self, x):
residual = x
x_abs = torch.abs(x)
# 通道阈值
channel_avg = x_abs.mean(-1) # (B,C)
channel_th = self.channel_fc(channel_avg).unsqueeze(-1) # (B,C,1)
# 空间阈值
spatial_th = self.spatial_conv(x_abs) # (B,1,L)
# 联合阈值
combined_th = channel_th * spatial_th # (B,C,L)
# 动态软阈值
sub = x_abs - combined_th
return torch.sign(residual) * torch.clamp_min(sub, 0)
3. 可微硬阈值化
创新思路:通过自适应选择软/硬阈值化,增强特征选择性。
class AdaptiveThreshold(nn.Module):
def __init__(self, channel):
super().__init__()
# 可学习阈值比例系数
self.alpha = nn.Parameter(torch.randn(1, channel, 1))
def forward(self, x):
x_abs = torch.abs(x)
threshold = self.alpha * x_abs.mean(-1, keepdim=True)
# 硬阈值直通式梯度
mask = (x_abs > threshold).float()
return x * mask
4. 轻量化动态卷积阈值
创新思路:用深度可分离卷积替代全连接层,减少参数量。
class LightShrinkage(nn.Module):
def __init__(self, channel):
super().__init__()
# 深度可分离卷积
self.dw_conv = nn.Sequential(
nn.Conv1d(channel, channel, kernel_size=3,
padding=1, groups=channel),
nn.AdaptiveAvgPool1d(1),
nn.Conv1d(channel, channel, kernel_size=1),
nn.Sigmoid()
)
def forward(self, x):
residual = x
x_abs = torch.abs(x)
# 通过卷积提取局部模式
threshold = self.dw_conv(x_abs)
sub = x_abs - threshold
return torch.sign(residual) * torch.relu(sub)
5. 残差收缩增强
创新思路:引入残差连接避免信息丢失,增强梯度流动。
class ResidualShrinkage(nn.Module):
def __init__(self, channel):
super().__init__()
self.shrink = Shrinkage(channel) # 原收缩模块
def forward(self, x):
return x + self.shrink(x) # 残差连接
创新方向总结
方向 | 关键改进 | 适用场景 |
---|---|---|
多尺度统计 | 融合平均/最大池化 | 高噪声数据 |
空间-通道协同 | 1D卷积+通道注意力 | 需要局部上下文的任务 |
软硬阈值结合 | 可学习阈值类型 | 需精确特征选择的场景 |
轻量化设计 | 深度可分离卷积 | 移动端/实时处理 |
残差增强 | 收缩结果与原始输入相加 | 深层网络训练稳定性 |
建议通过消融实验验证不同改进方案的有效性,根据具体任务选择最佳组合。例如,对于高噪声时序信号处理,多尺度+空间通道协同的方案可能更有效;而对于计算资源受限的场景,轻量化设计更为合适。