当前位置: 首页 > news >正文

残差收缩模块

1. 多尺度阈值生成

创新思路:融合不同尺度的统计信息(如平均池化+最大池化)生成更鲁棒的阈值。

class MultiScaleShrinkage(nn.Module):
    def __init__(self, channel, reduction=4):
        super().__init__()
        # 多尺度池化分支
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.max_pool = nn.AdaptiveMaxPool1d(1)
        
        # 双分支融合
        self.fc = nn.Sequential(
            nn.Linear(channel*2, channel//reduction),
            nn.ReLU(),
            nn.Linear(channel//reduction, channel),
            nn.Sigmoid()
        )

    def forward(self, x):
        residual = x
        x_abs = torch.abs(x)
        
        # 双分支池化
        avg = self.avg_pool(x_abs).squeeze(-1)
        max_ = self.max_pool(x_abs).squeeze(-1)
        combined = torch.cat([avg, max_], dim=-1)  # (B, 2C)
        
        threshold = self.fc(combined).unsqueeze(-1)
        # 后续软阈值处理相同

2. 空间-通道协同阈值化

创新思路:引入空间注意力机制,实现通道与空间联合自适应。

class SpatioChannelShrinkage(nn.Module):
    def __init__(self, channel, reduction=4):
        super().__init__()
        # 通道分支
        self.channel_fc = nn.Sequential(
            nn.Linear(channel, channel//reduction),
            nn.ReLU(),
            nn.Linear(channel//reduction, channel),
            nn.Sigmoid()
        )
        
        # 空间分支(1D卷积)
        self.spatial_conv = nn.Sequential(
            nn.Conv1d(channel, 1, kernel_size=3, padding=1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        residual = x
        x_abs = torch.abs(x)
        
        # 通道阈值
        channel_avg = x_abs.mean(-1)  # (B,C)
        channel_th = self.channel_fc(channel_avg).unsqueeze(-1)  # (B,C,1)
        
        # 空间阈值
        spatial_th = self.spatial_conv(x_abs)  # (B,1,L)
        
        # 联合阈值
        combined_th = channel_th * spatial_th  # (B,C,L)
        
        # 动态软阈值
        sub = x_abs - combined_th
        return torch.sign(residual) * torch.clamp_min(sub, 0)

3. 可微硬阈值化

创新思路:通过自适应选择软/硬阈值化,增强特征选择性。

class AdaptiveThreshold(nn.Module):
    def __init__(self, channel):
        super().__init__()
        # 可学习阈值比例系数
        self.alpha = nn.Parameter(torch.randn(1, channel, 1))
        
    def forward(self, x):
        x_abs = torch.abs(x)
        threshold = self.alpha * x_abs.mean(-1, keepdim=True)
        
        # 硬阈值直通式梯度
        mask = (x_abs > threshold).float()
        return x * mask

4. 轻量化动态卷积阈值

创新思路:用深度可分离卷积替代全连接层,减少参数量。

class LightShrinkage(nn.Module):
    def __init__(self, channel):
        super().__init__()
        # 深度可分离卷积
        self.dw_conv = nn.Sequential(
            nn.Conv1d(channel, channel, kernel_size=3, 
                      padding=1, groups=channel),
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(channel, channel, kernel_size=1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        residual = x
        x_abs = torch.abs(x)
        
        # 通过卷积提取局部模式
        threshold = self.dw_conv(x_abs)
        sub = x_abs - threshold
        return torch.sign(residual) * torch.relu(sub)

5. 残差收缩增强

创新思路:引入残差连接避免信息丢失,增强梯度流动。

class ResidualShrinkage(nn.Module):
    def __init__(self, channel):
        super().__init__()
        self.shrink = Shrinkage(channel)  # 原收缩模块
        
    def forward(self, x):
        return x + self.shrink(x)  # 残差连接

创新方向总结

方向关键改进适用场景
多尺度统计融合平均/最大池化高噪声数据
空间-通道协同1D卷积+通道注意力需要局部上下文的任务
软硬阈值结合可学习阈值类型需精确特征选择的场景
轻量化设计深度可分离卷积移动端/实时处理
残差增强收缩结果与原始输入相加深层网络训练稳定性

建议通过消融实验验证不同改进方案的有效性,根据具体任务选择最佳组合。例如,对于高噪声时序信号处理,多尺度+空间通道协同的方案可能更有效;而对于计算资源受限的场景,轻量化设计更为合适。

相关文章:

  • 大数据测试中,数据仓库表类型有哪些?
  • 深度学习中关于超参数的解释
  • vm+centos虚拟机
  • Kotlin中RxJava用法
  • SQL 中为什么参数多了not in 比 in 慢多了,怎么优化
  • JavaScript系列05-现代JavaScript新特性
  • .NET10 - 预览版1新功能体验(一)
  • Generalized Sparse Additive Model with Unknown Link Function
  • vue全局注册组件
  • Y3学习打卡
  • 【3-3】springcloud
  • 【每日学点HarmonyOS Next知识】网络请求回调toast问题、Popup问题、禁止弹窗返回、navigation折叠屏不显示返回键、响应式布局
  • Deepseek:物理神经网络PINN入门教程
  • element-push el-date-picker日期时间选择器,禁用可选中的时间 精确到分钟
  • OpenCV计算摄影学(11)色调映射算法类cv::TonemapDrago
  • 【量化策略】网格交易策略
  • 本地安装git
  • Sass基础
  • Django框架下html文件无法格式化的解决方案
  • 初识Qt · Qt的基本认识和基本项目代码解释
  • 习近平会见缅甸领导人敏昂莱
  • 上海市委常委会会议暨市生态文明建设领导小组会议研究基层减负、生态环保等事项
  • 毗邻三市人均GDP全部超过20万元,苏锡常是怎样做到的?
  • 大风暴雨致湖南岳阳县6户房屋倒塌、100多户受损
  • 司法部:持续规范行政执法行为,加快制定行政执法监督条例
  • 化学家、台湾地区“中研院”原学术副院长陈长谦逝世