【大模型手撕】pytorch实现LayerNorm, RMSNorm
LayerNorm介绍请参考:【AI知识】归一化、批量归一化 、 层归一化 和 实例归一化
RMSNorm介绍请参考:【大模型知识点】RMSNorm(Root Mean Square Normalization)均方根归一化
LayerNorm实现:
import torch
import torch.nn as nn


class LayerNorm(nn.Module):
    """Layer normalization over the last (feature) dimension.

    Normalizes each position's feature vector to zero mean and unit
    variance, then applies a learnable per-feature scale (gamma) and,
    optionally, a learnable per-feature shift (bias).

    Args:
        dim:  size of the feature (last) dimension.
        eps:  small constant added to the variance for numerical stability.
        bias: if True, add a learnable per-feature shift after scaling.
    """

    def __init__(self, dim, eps=1e-5, bias=False):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # Learnable per-feature scale, initialized to ones.
        self.gamma = nn.Parameter(torch.ones(dim))
        # Learnable per-feature shift, initialized to zeros; None when disabled.
        self.bias = nn.Parameter(torch.zeros(dim)) if bias else None

    def forward(self, x):
        # x: (..., dim), e.g. (batch_size, seq_len, dim); statistics are
        # computed over the last dimension, shapes below are (..., 1).
        x_mean = x.mean(-1, keepdim=True)
        # BUG FIX: LayerNorm divides by the standard deviation
        # sqrt(Var[x] + eps). The original divided the centered input by
        # sqrt(E[x^2] + eps) (the RMS of x), which equals the std only
        # when the mean is already zero.
        var = (x - x_mean).pow(2).mean(-1, keepdim=True)
        x_norm = (x - x_mean) / torch.sqrt(var + self.eps)
        # BUG FIX: `if self.bias:` triggers tensor truthiness, which raises
        # RuntimeError for any multi-element Parameter; compare with None.
        if self.bias is not None:
            return self.gamma * x_norm + self.bias
        return self.gamma * x_norm
RMSNorm实现:
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last (feature) dimension.

    Divides each feature vector by its root mean square — unlike
    LayerNorm, no mean is subtracted — then applies a learnable
    per-feature scale (gamma) and, optionally, a learnable shift (bias).

    Args:
        dim:  size of the feature (last) dimension.
        eps:  small constant added under the square root for stability.
        bias: if True, add a learnable per-feature shift after scaling.
    """

    def __init__(self, dim, eps=1e-5, bias=False):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # Learnable per-feature scale, initialized to ones.
        self.gamma = nn.Parameter(torch.ones(dim))
        # Learnable per-feature shift, initialized to zeros; None when disabled.
        self.bias = nn.Parameter(torch.zeros(dim)) if bias else None

    def forward(self, x):
        # x: (..., dim), e.g. (batch_size, seq_len, dim).
        # rms: (..., 1) — root mean square over the feature dimension.
        rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        # BUG FIX: `if self.bias:` triggers tensor truthiness, which raises
        # RuntimeError for any multi-element Parameter; compare with None.
        if self.bias is not None:
            return self.gamma * (x / rms) + self.bias
        return self.gamma * (x / rms)