1. 辅助函数
calc_ins_mean_std(x, eps=1e-12)
功能 :计算每张图像每个通道的均值和标准差。输入 :一个4维张量(batch_size x channels x height x width)。输出 :两个3维张量,分别表示均值和标准差。均值的维度为 (batch_size, channels, 1)。
instance_norm_mix(content_feature, style_feature)
功能 :将内容特征归一化到风格特征的分布下。步骤 : 计算content_feat和style_feat的均值和标准差。 使用style_feat的统计量对content_feat进行变换。
cn_rand_bbox(x, use_crop=True)
功能 :生成随机边界框,用于crop操作。参数 :x为输入图像,use_crop决定是否应用 crop。返回 :如果use_crop为True,按比例缩放和裁剪图像,并返回新尺寸。
cn_op_2ins_space_CHAN(x, use_crop=True)
功能 :跨实例或空间变换特征。参数 :输入x、是否使用crop。步骤 : 如果使用 crop,则生成随机边界框并裁剪图像。 将输入在通道维度上分割,交换特征图块,或按比例缩放。
2. 模块实现
CrossNorm
结构 :包含一个操作函数cn_op_2ins_space_CHAN
,用于跨实例的统计交换。特点 : 在训练时active为True时才应用。 可选参数影响是否应用crop和空间变换。
SelfNorm
结构 :全连接层结合BN层,生成通道特定的缩放和平移参数。特点 : 使用两个1维卷积层(g_fc
和可选的f_fc
)来生成调整因子。 应用sigmoid激活函数限制范围。
CNSN
结构 :组合CrossNorm和SelfNorm模块。流程 :先应用CrossNorm(如果active),然后使用SelfNorm自适应调整特征。
3. 使用示例
crossnorm = CrossNorm( )
selfnorm = SelfNorm( chan_num= 3 )
block = CNSN( crossnorm, selfnorm) input = torch. rand( 32 , 3 , 224 , 224 )
output = block( input )
print ( output. size( ) )
4. 总结
主要功能 :CrossNorm和SelfNorm模块通过自适应调整特征图的统计量,增强模型对分布偏移的鲁棒性。适用场景 :在训练阶段使用这些模块可以提高模型泛化能力,而无需修改网络结构。
5. 注意事项
使用时需确保数据预处理标准化已正确完成,否则自适应调整可能无效。 CrossNorm应仅在训练或特定需要的阶段启用,以避免推理时计算开销过大。 SelfNorm提供了对每个通道更细致的控制,适用于需要通道间较强独立性的任务。
6. 代码解释
import torchdef calc_ins_mean_std ( x, eps= 1e-12 ) : mean = x. mean( dim= ( 0 , 2 , 3 ) , keepdim= True ) std = x. std( dim= ( 0 , 2 , 3 ) , keepdim= True ) . add( eps) return mean, stddef instance_norm_mix ( content_feature, style_feature) : content_mean, content_std = calc_ins_mean_std( content_feature) style_mean, style_std = calc_ins_mean_std( style_feature) normalized_content = ( content_feature - content_mean) / content_stdtransformed_content = normalized_content * style_std + style_meanreturn transformed_contentdef cn_rand_bbox ( x, use_crop= True ) : if not use_crop: return x, None _, _, h, w = x. size( ) size = int ( h // 2 ) boxes = [ ] for i in range ( 0 , h, size) : for j in range ( 0 , w, size) : if i + size > h or j + size > w: continue boxes. append( ( i, j, i+ size, j+ size) ) return x, None def cn_op_2ins_space_CHAN ( x, use_crop= True ) : if not use_crop: return xelse : x, _ = cn_rand_bbox( x, use_crop) return xclass CrossNorm ( torch. nn. Module) : def __init__ ( self) : super ( ) . __init__( ) def forward ( self, x) : if self. training and self. active: transformed_x = cn_op_2ins_space_CHAN( x) return transformed_xelse : return xclass SelfNorm ( torch. nn. Module) : def __init__ ( self, chan_num) : super ( ) . __init__( ) self. g_conv1d = torch. nn. Conv1d( chan_num, Chan_num, kernel_size= 2 ) self. bn = torch. nn. BatchNorm1d( Chan_num) if is_two: self. f_conv1d = torch. nn. Conv1d( chan_num, chan_num, kernel_size= 2 ) def forward ( self, x) : batch_size, channels = x. size( ) [ : 2 ] x_flat = x. view( batch_size * channels, - 1 ) . transpose( 0 , 1 ) g_y = torch. sigmoid( self. g_conv1d( x_flat) ) g_y = g_y. transpose( 0 , 1 ) . contiguous( ) . view( batch_size, channels, - 1 ) transformed_x = x * g_yif hasattr ( self, 'f_conv1d' ) : f_y = . . . transformed_x += mean * f_yreturn transformed_xclass CNSN ( torch. nn. Module) : def __init__ ( self, cross_norm, self_norm) : super ( ) . __init__( ) self. cross_norm = cross_normself. self_norm = self_normdef forward ( self, x) : if self. training: x = self. cross_norm( x) x_normalized = self. self_norm( x) return x_normalized
7. 模块参数与激活控制
CrossNorm的active属性 :决定是否执行跨实例变换,在训练时通常设为True。SelfNorm的通道数匹配 :需确保输入张量的通道数与卷积层一致,否则会导致维度不匹配。
8. 性能考量
CrossNorm由于涉及特征图的操作,可能会增加计算开销。建议在性能允许的情况下使用。 SelfNorm中的1维卷积相对高效,主要消耗为矩阵乘法和归一化操作。