Differences between BatchNorm1d, LayerNorm, and RevIN
Note:
BatchNorm1d normalizes per feature: for a given feature (channel), the mean and standard deviation are computed over all samples in the batch and all time steps.
LayerNorm normalizes per position: at each step of each sequence, the mean and standard deviation are computed over all features.
RevIN normalizes per sample: for a given sample and a given feature, the mean and standard deviation are computed over the time steps of that sequence, i.e. instance normalization. The three are contrasted in the sketch below.
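A minimal sketch of the three choices of statistics, assuming an input of shape (B, T, C) = (batch, sequence length, features); nn.BatchNorm1d and nn.LayerNorm are standard PyTorch modules, while the RevIN-style statistics are computed by hand for illustration:

import torch
import torch.nn as nn

B, T, C = 32, 96, 7                      # batch, sequence length, features
x = torch.randn(B, T, C)

# BatchNorm1d: stats per feature, over batch AND time -> C sets of stats
# nn.BatchNorm1d expects (B, C, T), hence the transposes
bn = nn.BatchNorm1d(C)
x_bn = bn(x.transpose(1, 2)).transpose(1, 2)

# LayerNorm: stats per (sample, time step), over the feature dim -> B*T sets
ln = nn.LayerNorm(C)
x_ln = ln(x)

# RevIN-style instance stats: per (sample, feature), over the time dim -> B*C sets
mean = x.mean(dim=1, keepdim=True)                   # shape (B, 1, C)
std = x.std(dim=1, keepdim=True, unbiased=False)
x_in = (x - mean) / (std + 1e-5)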
import torch
import torch.nn as nn


class RevIN(nn.Module):
    def __init__(self, num_features: int, eps=1e-5, affine=True):
        """
        :param num_features: the number of features or channels
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        """
        super(RevIN, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        if self.affine:
            self._init_params()

    def forward(self, x, mode: str):
        if mode == 'norm':
            self._get_statistics(x)
            x = self._normalize(x)
        elif mode == 'denorm':
            x = self._denormalize(x)
        else:
            raise NotImplementedError
        return x

    def _init_params(self):
        # initialize RevIN params: (C,)
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        # reduce over every dim except batch (dim 0) and channel (last dim),
        # i.e. per-instance, per-channel statistics over the time axis
        dim2reduce = tuple(range(1, x.ndim - 1))
        self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
        self.stdev = torch.sqrt(
            torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps
        ).detach()

    def _normalize(self, x):
        x = x - self.mean
        x = x / self.stdev
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        # invert the affine transform first, then restore the stored statistics
        if self.affine:
            x = x - self.affine_bias
            x = x / (self.affine_weight + self.eps * self.eps)
        x = x * self.stdev
        x = x + self.mean
        return x
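A typical usage sketch (the shapes and the dummy prediction are placeholder assumptions; RevIN itself is model-agnostic): normalize the input, run any forecasting model, then denormalize its output back to the original scale with the stored per-instance statistics.

revin = RevIN(num_features=7)
x = torch.randn(32, 96, 7)          # (batch, input length, channels)

x_norm = revin(x, mode='norm')      # normalize; stores per-instance mean/stdev
# y_hat = model(x_norm)             # any forecasting model (stand-in)
y_hat = x_norm[:, -24:, :]          # dummy "prediction" just for this sketch
y = revin(y_hat, mode='denorm')     # restore original scale via stored stats

Because mean and stdev are kept with keepdim=True, shape (B, 1, C), they broadcast over outputs of a different sequence length, which is what makes denormalizing a forecast horizon shorter or longer than the input work.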