什么是MOE?
混合专家模型太火了,这里必须学习一下。
这次学习资料来源于chaofa老师的B站视频。
https://www.bilibili.com/video/BV1ZbFpeHEYr?spm_id_from=333.788.videopod.sections&vd_source=30e4372a35b7112cb7f7d9cbc8fbac60
下面开始。
什么是MOE?
混合专家模型没什么神秘的,就是一个大号的注意力而已。
一个token 通过不同的专家,得到不同的向量,相当于不同专家给出不同的意见。然后有一个门网络,给出不同的权重,按照权重来考虑各个专家的思想。
这种混合专家模型代码也比较好写。
#第一个混合专家
class BasicExpert(nn.Module):def __init__(self, feature_in, feature_out):super(BasicExpert, self).__init__()self.fc = nn.Linear(feature_in, feature_out)def forward(self, x):return self.fc(x)class BasicMoe(nn.Module):def __init__(self, feature_in, feature_out, num_experts):super(BasicMoe, self).__init__()self.gate = nn.Linear(feature_in, num_experts)self.experts = nn.ModuleList(BasicExpert(feature_in, feature_out)for _ in range(num_experts))def forward(self, x):expert_weights = self.gate(x)expert_out_list = [expert(x) for expert in self.experts]expert_outs = [expert_out.unsqueeze(1) for expert_out in expert_out_list]expert_out = torch.concat(expert_outs, dim = 1)expert_weights = F.softmax(expert_weights, dim=1)expert_weights = expert_weights.unsqueeze(1)output = expert_weights @ expert_outreturn output.sequeeze(1) my_base_moe = BasicMoe(128, 128, 4)x = torch.ones((512, 128))
y = my_base_moe(x)
print(y)
这里写成了矩阵相乘的形式,一堆sequeeze比较难以看懂。
但是调试一遍保证懂。
2,spaseMoe
稀疏的混合专家模型。 相当于你去医院,加了一个导诊台, 导诊台会将你分流到部分的专家那里,而不是全部的专家。 每次只激活一部分的专家模型。选择TOP-K个专家。
class MOEconfig():def __init__(self, hidden_dim, expert_number, top_k, shared_experts_number=2):self.hidden_dim = hidden_dimself.expert_number = expert_numberself.top_k = top_kself.shared_experts_number = shared_experts_numberclass MoeRouter(nn.Module):def __init__(self, config):super(MoeRouter, self).__init__()self.gate = nn.Linear(config.hidden_dim, config.expert_number)self.expert_number = config.expert_numberself.top_k = config.top_kdef forward(self, x):router_logits = self.gate(x) #(每个token产生一个8维的打分,理解为看看哪个专家最适合这个token) #router_probs = F.softmax(router_logits, dim=1, dtype=torch.float) # 得到softmax结果,其实感觉可以后面再softmaxrouter_weights, select_expert_indices = torch.topk(router_probs, self.top_k, dim=-1) #topk是可以反向传播的#token数量* top_k, token_数量*top_krouter_weights = router_weights / router_weights.sum(dim=-1, keepdim=True) #继续归一化router_weights = router_weights.to(x.dtype)expert_masks = F.one_hot(select_expert_indices, num_classes=self.expert_number) #生成一个 (token数量, 专家数量, 总专家数量的矩阵) 对于每一个专家,产生一个mask,这个mask只有它在的下标是1.expert_masks = expert_masks.permute(2, 1, 0) #变成, 专家mask长度, 专家数量,token数量return router_logits, router_weights, expert_masks, select_expert_indices#这四个 分别是, 初试的专家倾向结果, 所有token归一化后的选择专家的权重, 每一个专家的mask转置后结果, 每个token选择的专家下标class Sparse_MOE(nn.Module):def __init__(self, config):super(Sparse_MOE, self).__init__()self.config = configself.top_k = config.top_kself.hidden_dim = config.hidden_dimself.expert_number = config.expert_numberself.experts = nn.ModuleList(BasicExpert(config.hidden_dim, config.hidden_dim)for _ in range(config.expert_number))self.router = MoeRouter(config)def forward(self, x):batch_size, seq_len, hidden_dim = x.size() #如果有第二维度,也就是token数量的维度,要变成上面那样hidden_states = x.view(batch_size * seq_len, hidden_dim) #变成所有token并列,一个token一个token的看。(65536*768)#做相关专家计算router_logits, router_weights, expert_masks, select_expert_indices = self.router(hidden_states)final_hidden_states = torch.zeros((batch_size*seq_len, hidden_dim),dtype = hidden_states.dtype,device = hidden_states.device)for expert_idx in range(self.expert_number): #对于每个专家进行一次循环?expert_layer = self.experts[expert_idx] # 选出每一个专家, 第一次只看专家0current_expert_mask = expert_masks[expert_idx] #只看当前专家对应的maskrouter_weight_idx, top_x = torch.where(current_expert_mask)# where 返回值为1的行索引和列索引。#idx用来选择weight, topk 选择hiddenstates( weight, 这是每个token的第几个,对应的weight hiddenstates, 选择哪些token)current_state = hidden_states.unsqueeze(0)[:, top_x, :].reshape(-1, hidden_dim) #很清楚的看到topx选择哪些token# shape是选中的token数量 * 对应的token向量current_state = expert_layer(current_state) #经过专家层current_token_router_weight = router_weights[top_x, router_weight_idx]#从这里可以看出来了, top_x对应的是第几个token, idx对应的是这人该token1的第几个专家,两个下标得到weightcurrent_token_router_weight = current_token_router_weight.unsqueeze(-1)current_hidden_states = current_state * current_token_router_weight #本专家对应的所有token, 和他们的向量final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))final_hidden_states = final_hidden_states.reshape(batch_size, seq_len, hidden_dim)return final_hidden_states
这个代码是比较难以看懂的,他的逻辑是这样的,
router_logits = self.gate(x) #(每个token产生一个8维的打分,理解为看看哪个专家最适合这个token) #router_probs = F.softmax(router_logits, dim=1, dtype=torch.float) # 得到softmax结果,其实感觉可以后面再softmax
在moerouter中,会针对每一个token进行打分,得到每一个专家的分数,
router_weights, select_expert_indices = torch.topk(router_probs, self.top_k, dim=-1) #topk是可以反向传播的
对于每一个token,从所有专家打分中挑选出最大的几个
router_weights = router_weights.to(x.dtype)expert_masks = F.one_hot(select_expert_indices, num_classes=self.expert_number) #生成一个 (token数量, 专家数量, 总专家数量的矩阵) 对于每一个专家,产生一个mask,这个mask只有它在的下标是1.expert_masks = expert_masks.permute(2, 1, 0) #变成, 专家mask长度, 专家数量,token数量
这里生成mask,针对每一个token, 对于每一个它选择的专家,让它那个专家的下标为处值为1。
第二步比较关键,就是让token和专家换换位置,转换为专家的角度来看。
return router_logits, router_weights, expert_masks, select_expert_indices#这四个 分别是, 初试的专家倾向结果, 所有token归一化后的选择专家的权重, 每一个专家的mask转置后结果, 每个token选择的专家下标
这里返回四个结果, 第一个是所有token的所有专家倾向分数,
第二个是,选择的专家的分数,注意是只有选择的专家的。
第三个是专家的mask, 这个可以用来看使用了哪个专家。
后面也可以看,每个token选择的专家的下标。
def forward(self, x):batch_size, seq_len, hidden_dim = x.size() #如果有第二维度,也就是token数量的维度,要变成上面那样hidden_states = x.view(batch_size * seq_len, hidden_dim) #变成所有token并列,一个token一个token的看。(65536*768)#做相关专家计算router_logits, router_weights, expert_masks, select_expert_indices = self.router(hidden_states)
传入x后,先变成token序列(把batch和len乘起来),通过router,来得到选择的专家。
final_hidden_states = torch.zeros((batch_size*seq_len, hidden_dim),dtype = hidden_states.dtype,device = hidden_states.device)
初始化一个盛放最终向量的容器。
for expert_idx in range(self.expert_number): #对于每个专家进行一次循环?expert_layer = self.experts[expert_idx] # 选出每一个专家, 第一次只看专家0current_expert_mask = expert_masks[expert_idx] #只看当前专家对应的maskrouter_weight_idx, top_x = torch.where(current_expert_mask)# where 返回值为1的行索引和列索引。#idx用来选择weight, topk 选择hiddenstates( weight, 这是每个token的第几个,对应的weight hiddenstates, 选择哪些token)
这一段非常难以理解。
上面说了,这里变成了专家角度, 专家变成了第一维,
因为要实现这个效果: 如果某个token,没有选择某个专家,就不需要让这个token经过这个专家。 但是如果直接让向量通过模型,那就是需要全部计算了。
所以是这样实现的,先挑出来一个专家,看看哪些token需要通过它,然后把这些token挑出来。然后让这些token通过专家层。
选一个专家层0, 挑出对应的mask
在mask中,一个token会有多个mask向量,对应不同专家。
比如挑了两个专家,那就是向量就是2 * 专家数量
那么反过来, 选定专家后, 对应的token也会有两个向量,其中某个向量可能对应当前专家。所以用where, top_x 是该专家被哪些token选择了,router_weight_idx表示的是 这个专家是 某个token选择的第几个专家
current_state = hidden_states.unsqueeze(0)[:, top_x, :].reshape(-1, hidden_dim) #很清楚的看到topx选择哪些token# shape是选中的token数量 * 对应的token向量current_state = expert_layer(current_state) #经过专家层
根据top_x选择那些token的特征,让这些特征通过当前专家。(李的疑问,这样确实省了计算资源,可是速度呢? 并行的矩阵运算要比循环快很多吧。)
current_token_router_weight = router_weights[top_x, router_weight_idx]#从这里可以看出来了, top_x对应的是第几个token, idx对应的是这人该token1的第几个专家,两个下标得到weightcurrent_token_router_weight = current_token_router_weight.unsqueeze(-1)
top_x 是该专家被哪些token选择了,router_weight_idx表示的是 这个专家是 某个token选择的第几个专家,所以根据这两个下标可以得到当前专家对应的那个权重是多少,让它和通过模型后得到的
向量相乘
current_hidden_states = current_state * current_token_router_weight #本专家对应的所有token, 和他们的向量final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
相乘后得到最终向量,然后加入到final向量中。根据top-k下标加入到对应位置,
final_hidden_states = final_hidden_states.reshape(batch_size, seq_len, hidden_dim)
再把batch拆分出来,结束。
3, deepseek的混合专家,share expert。
有一些固定的专家,每次都会固定被激活
class SharedMOE(nn.Module):def __init__(self, config):super(SharedMOE, self).__init__()self.config = configself.router_experts_moe = Sparse_MOE(config)self.shared_experts = nn.ModuleList([BasicExpert(self.config.hidden_dim, self.config.hidden_dim)for _ in range(config.shared_experts_number)])def forward(self, x):batch_size, seq_len, hidden_dim = x.size()shared_experts_out_list = [expert(x) for expert in self.shared_experts]shared_experts_output = torch.stack(shared_experts_out_list, dim=0)shared_experts_output = shared_experts_output.sum(dim=0)sparse_moe_out, router_logits = self.router_experts_moe(x)output = shared_experts_output + sparse_moe_out #这里的权重还值得商议return output, router_logits
这是一个sharedmoe 看起来就是额外加了一些共享的专家,值得注意的是,这里没有规定共享专家和选择专家的权重
此外最后还要返回router_logits 这和专家的选择相关,因为在loss中会有相关的计算。 让routerlogits相对均衡,使得各个专家被选择的比例也相对均衡。