1. 辅助函数
calc_ins_mean_std(x, eps=1e-12)
- 功能:计算每张图像每个通道的均值和标准差。
- 输入:一个4维张量(batch_size x channels x height x width)。
- 输出:两个3维张量,分别表示均值和标准差。均值的维度为 (batch_size, channels, 1)。
instance_norm_mix(content_feature, style_feature)
- 功能:将内容特征归一化到风格特征的分布下。
- 步骤:
- 计算content_feat和style_feat的均值和标准差。
- 使用style_feat的统计量对content_feat进行变换。
cn_rand_bbox(x, use_crop=True)
- 功能:生成随机边界框,用于crop操作。
- 参数:x为输入图像,use_crop决定是否应用 crop。
- 返回:如果use_crop为True,按比例缩放和裁剪图像,并返回新尺寸。
cn_op_2ins_space_CHAN(x, use_crop=True)
- 功能:跨实例或空间变换特征。
- 参数:输入x、是否使用crop。
- 步骤:
- 如果使用 crop,则生成随机边界框并裁剪图像。
- 将输入在通道维度上分割,交换特征图块,或按比例缩放。
2. 模块实现
CrossNorm
- 结构:包含一个操作函数
cn_op_2ins_space_CHAN
,用于跨实例的统计交换。 - 特点:
- 在训练时active为True时才应用。
- 可选参数影响是否应用crop和空间变换。
SelfNorm
- 结构:全连接层结合BN层,生成通道特定的缩放和平移参数。
- 特点:
- 使用两个1维卷积层(
g_fc
和可选的f_fc
)来生成调整因子。 - 应用sigmoid激活函数限制范围。
CNSN
- 结构:组合CrossNorm和SelfNorm模块。
- 流程:先应用CrossNorm(如果active),然后使用SelfNorm自适应调整特征。
3. 使用示例
crossnorm = CrossNorm()
selfnorm = SelfNorm(chan_num=3)
block = CNSN(crossnorm, selfnorm)
input = torch.rand(32, 3, 224, 224)
output = block(input)
print(output.size())
4. 总结
- 主要功能:CrossNorm和SelfNorm模块通过自适应调整特征图的统计量,增强模型对分布偏移的鲁棒性。
- 适用场景:在训练阶段使用这些模块可以提高模型泛化能力,而无需修改网络结构。
5. 注意事项
- 使用时需确保数据预处理标准化已正确完成,否则自适应调整可能无效。
- CrossNorm应仅在训练或特定需要的阶段启用,以避免推理时计算开销过大。
- SelfNorm提供了对每个通道更细致的控制,适用于需要通道间较强独立性的任务。
6. 代码解释
import torch
def calc_ins_mean_std(x, eps=1e-12):
mean = x.mean(dim=(0, 2, 3), keepdim=True)
std = x.std(dim=(0, 2, 3), keepdim=True).add(eps)
return mean, std
def instance_norm_mix(content_feature, style_feature):
content_mean, content_std = calc_ins_mean_std(content_feature)
style_mean, style_std = calc_ins_mean_std(style_feature)
normalized_content = (content_feature - content_mean) / content_std
transformed_content = normalized_content * style_std + style_mean
return transformed_content
def cn_rand_bbox(x, use_crop=True):
if not use_crop:
return x, None
_, _, h, w = x.size()
size = int(h // 2)
boxes = []
for i in range(0, h, size):
for j in range(0, w, size):
if i + size > h or j + size > w:
continue
boxes.append((i, j, i+size, j+size))
return x, None
def cn_op_2ins_space_CHAN(x, use_crop=True):
if not use_crop:
return x
else:
x, _ = cn_rand_bbox(x, use_crop)
return x
class CrossNorm(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
if self.training and self.active:
transformed_x = cn_op_2ins_space_CHAN(x)
return transformed_x
else:
return x
class SelfNorm(torch.nn.Module):
def __init__(self, chan_num):
super().__init__()
self.g_conv1d = torch.nn.Conv1d(chan_num, Chan_num, kernel_size=2)
self.bn = torch.nn.BatchNorm1d(Chan_num)
if is_two:
self.f_conv1d = torch.nn.Conv1d(chan_num, chan_num, kernel_size=2)
def forward(self, x):
batch_size, channels = x.size()[:2]
x_flat = x.view(batch_size * channels, -1).transpose(0, 1)
g_y = torch.sigmoid(self.g_conv1d(x_flat))
g_y = g_y.transpose(0, 1).contiguous().view(batch_size, channels, -1)
transformed_x = x * g_y
if hasattr(self, 'f_conv1d'):
f_y = ...
transformed_x += mean * f_y
return transformed_x
class CNSN(torch.nn.Module):
def __init__(self, cross_norm, self_norm):
super().__init__()
self.cross_norm = cross_norm
self.self_norm = self_norm
def forward(self, x):
if self.training:
x = self.cross_norm(x)
x_normalized = self.self_norm(x)
return x_normalized
7. 模块参数与激活控制
- CrossNorm的active属性:决定是否执行跨实例变换,在训练时通常设为True。
- SelfNorm的通道数匹配:需确保输入张量的通道数与卷积层一致,否则会导致维度不匹配。
8. 性能考量
- CrossNorm由于涉及特征图的操作,可能会增加计算开销。建议在性能允许的情况下使用。
- SelfNorm中的1维卷积相对高效,主要消耗为矩阵乘法和归一化操作。