CFA: Coupled-hypersphere-based Feature Adaptation 论文解析
CFA: Coupled-hypersphere-based Feature Adaptation 论文解析
论文概述
CFA (Coupled-hypersphere-based Feature Adaptation) 是一种面向目标的异常定位方法,主要用于工业缺陷检测。该方法通过耦合超球面特征适应来实现精确的异常定位,在MVTec AD数据集上取得了SOTA性能。
https://arxiv.org/pdf/2206.04325
https://github.com/sungwool/CFA_for_anomaly_localization/tree/main
核心创新点分析
1. 耦合超球面机制 (Coupled Hypersphere)
核心代码位置: utils/cfa.py
中的 DSVDD
类
1.1 双超球面设计
# 关键参数设置
self.K = 3 # 吸引邻居数量
self.J = 3 # 排斥邻居数量
self.r = nn.Parameter(1e-5*torch.ones(1), requires_grad=True) # 可学习的超球面半径
创新意义:
- 传统SVDD只使用单一超球面,CFA引入了吸引-排斥双机制
K
个最近邻用于吸引正常样本(拉向球心)J
个次近邻用于排斥边界样本(推离球心)- 这种设计使决策边界更加精确和鲁棒
1.2 软边界损失函数
def _soft_boundary(self, phi_p):# 计算到聚类中心的距离features = torch.sum(torch.pow(phi_p, 2), 2, keepdim=True)centers = torch.sum(torch.pow(self.C, 2), 0, keepdim=True)f_c = 2 * torch.matmul(phi_p, (self.C))dist = features + centers - f_c# 选择K+J个最近邻n_neighbors = self.K + self.Jdist = dist.topk(n_neighbors, largest=False).values# 吸引损失:拉近K个最近邻score = (dist[:, : , :self.K] - self.r**2) L_att = (1/self.nu) * torch.mean(torch.max(torch.zeros_like(score), score))# 排斥损失:推远J个次近邻score = (self.r**2 - dist[:, : , self.J:]) L_rep = (1/self.nu) * torch.mean(torch.max(torch.zeros_like(score), score - self.alpha))return L_att + L_rep
创新意义:
- L_att (吸引损失): 确保正常样本聚集在超球面内部
- L_rep (排斥损失): 防止边界样本过于接近球心,增强判别能力
- 双损失机制形成了更紧致和判别性的特征表示
2. 坐标感知特征适应 (Coordinate-aware Feature Adaptation)
核心代码位置: utils/coordconv.py
和 utils/cfa.py
中的 Descriptor
类
2.1 CoordConv机制
class CoordConv2d(conv.Conv2d):def __init__(self, in_channels, out_channels, kernel_size, ...):super().__init__(...)self.addcoords = AddCoords(self.rank, with_r, use_cuda=use_cuda)# 增加坐标通道self.conv = nn.Conv2d(in_channels + self.rank + int(with_r), out_channels, ...)def forward(self, input_tensor):# 添加空间坐标信息out = self.addcoords(input_tensor)out = self.conv(out)return out
创新意义:
- 传统卷积对空间位置不敏感,CoordConv通过显式添加坐标信息解决这个问题
- 在特征图中嵌入
(x, y)
坐标和可选的径向距离r
- 使模型能够感知空间位置关系,对异常定位至关重要
2.2 多尺度特征融合
def forward(self, p):sample = Nonefor o in p:# 不同层的特征进行平均池化o = F.avg_pool2d(o, 3, 1, 1) / o.size(1) if self.cnn == 'effnet-b5' else F.avg_pool2d(o, 3, 1, 1)# 插值到统一尺寸并拼接sample = o if sample is None else torch.cat((sample, F.interpolate(o, sample.size(2), mode='bilinear')), dim=1)# 通过CoordConv进行特征适应phi_p = self.layer(sample)return phi_p
创新意义:
- 融合CNN不同层的多尺度特征
- 低层特征捕获细节,高层特征捕获语义
- 通过CoordConv增强空间感知能力
3. 自适应聚类中心初始化
核心代码位置: utils/cfa.py
中的 _init_centroid
方法
def _init_centroid(self, model, data_loader):for i, (x, _, _) in enumerate(tqdm(data_loader)):x = x.to(self.device)p = model(x)self.scale = p[0].size(2) # 记录特征图尺寸phi_p = self.Descriptor(p)# 增量式更新聚类中心self.C = ((self.C * i) + torch.mean(phi_p, dim=0, keepdim=True).detach()) / (i+1)
后处理聚类:
if self.gamma_c > 1:self.C = self.C.cpu().detach().numpy()# 使用K-means进一步优化聚类中心self.C = KMeans(n_clusters=(self.scale**2)//self.gamma_c, max_iter=3000).fit(self.C).cluster_centers_self.C = torch.Tensor(self.C).to(device)
创新意义:
- 增量式初始化避免内存溢出
- gamma_c参数控制聚类中心数量,平衡精度和效率
- 自适应的聚类中心更好地表示正常样本分布
4. 软最小值异常评分
核心代码位置: utils/cfa.py
中的 forward
方法
def forward(self, p):phi_p = self.Descriptor(p) phi_p = rearrange(phi_p, 'b c h w -> b (h w) c')# 计算到所有聚类中心的距离features = torch.sum(torch.pow(phi_p, 2), 2, keepdim=True) centers = torch.sum(torch.pow(self.C, 2), 0, keepdim=True)f_c = 2 * torch.matmul(phi_p, (self.C))dist = features + centers - f_cdist = torch.sqrt(dist)# 选择K个最近邻n_neighbors = self.Kdist = dist.topk(n_neighbors, largest=False).values# 软最小值加权dist = (F.softmin(dist, dim=-1)[:, :, 0]) * dist[:, :, 0]dist = dist.unsqueeze(-1)# 重塑为空间维度score = rearrange(dist, 'b (h w) c -> b c h w', h=self.scale)return loss, score
创新意义:
- 使用softmin加权而非简单的最小距离
- 考虑多个近邻的贡献,使异常评分更加平滑和鲁棒
- 输出像素级异常得分图,实现精确定位
技术架构总览
整体流程
- 特征提取: 使用预训练CNN(ResNet/WideResNet/EfficientNet/VGG)
- 特征适应: 通过Descriptor模块融合多尺度特征并增加坐标信息
- 聚类中心学习: 在训练过程中学习正常样本的聚类中心
- 异常检测: 计算测试样本到聚类中心的距离,生成异常得分
关键参数
- gamma_c: 控制聚类中心数量的压缩比例
- gamma_d: 控制特征维度的压缩比例
- K: 吸引邻居数量(默认3)
- J: 排斥邻居数量(默认3)
实验性能
在MVTec AD数据集上的表现:
- 图像级AUROC: 99.5%
- 像素级AUROC: 98.5%
- 像素级AUPRO: 高性能表现
代码实现亮点
1. 高效的距离计算
使用矩阵运算而非循环,充分利用GPU并行计算能力。
2. 内存友好的设计
- 增量式聚类中心更新
- 合理的批处理大小设计
- 梯度计算的精细控制
3. 模块化架构
- 清晰的组件分离(DSVDD、Descriptor、CoordConv)
- 支持多种CNN骨干网络
- 易于扩展和修改
创新总结
CFA的核心创新在于:
- 耦合超球面机制: 吸引-排斥双机制提升决策边界精度
- 坐标感知特征: CoordConv增强空间位置感知能力
- 软边界优化: 平衡拟合能力和泛化性能
- 多尺度融合: 结合不同层次的特征信息
这些创新点协同工作,使CFA在工业异常检测任务中达到了SOTA性能,特别在精确定位方面表现突出。