Building a GPT Mixture-of-Experts (MoE) Chat Model Completely from Scratch with PyTorch
1. Building a GPT Mixture-of-Experts (MoE) Chat Model from Scratch with PyTorch
A Mixture-of-Experts (MoE) model is a variant of the Transformer architecture, exemplified by Switch Transformers. It uses a gating network to dynamically select a small subset of "expert" sub-networks for each input, raising model capacity and computational efficiency through sparse activation: even when the total parameter count becomes very large, the computation of a single forward pass stays within a controllable range. Its defining trait is the sparsity of "high parameters, low computation". Unlike a dense model, which activates all of its parameters for every input, an MoE model activates only a small fraction of its total parameters, and as experts are added it can absorb richer knowledge and generalize more strongly. Mixtral 8x7B and the currently popular DeepSeek both use the MoE architecture, which is clear evidence of its potential.
The MoE architecture stands in sharp contrast to the traditional dense Transformer Decoder architecture. An ordinary Transformer Decoder layer consists of multi-head self-attention (MultiHeadAttention) and a feed-forward network (FFN). The design is simple, stable, and easy to parallelize, and is widely used in models such as GPT and BART. Its computation and parameter activation are dense, however: every input token activates all parameters of the FFN, so the compute cost grows linearly as the model scales up.
The MoE architecture keeps the self-attention module but replaces the FFN with a mixture-of-experts module, the MoE layer. The module consists of a lightweight routing gate network (Router) and n expert networks (Experts). The Router dynamically dispatches each input token to its Top-K experts; each expert is typically similar to an FFN, and the experts that are not selected are skipped entirely, which is how sparse activation is achieved.
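In formula form (my notation, the standard way to write the Top-K routing described above): for a token representation x, the MoE layer computes

y = \sum_{i \in \mathrm{TopK}(g(x))} g_i(x) \, E_i(x)

where g_i(x) is the normalized gate weight of expert i and E_i is the i-th expert network; only the Top-K terms of the sum are ever evaluated.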
In an earlier article in this column I introduced building a GPT Transformer chat model from scratch, which used this traditional Transformer Decoder architecture throughout; article:
Building a GPT Transformer Chat Model from Scratch with PyTorch
That article built the GPTModel network structure from scratch, along with the vocabulary. With only about 37 million total parameters it can hardly be called a "large model", but the overall architecture is well worth studying. This article rebuilds the network on top of that one, switching to the MoE mixture-of-experts architecture; the training dataset and vocabulary are simply reused from the previous article and are not described again.
Likewise, for the computations and formulas behind details such as the scaled dot-product attention layer, the multi-head attention layer, the triangular causal mask, and the positional encoding, please refer to the previous article. The final result achieved in this article is shown below:
The main dependency versions used in the experiment:
torch==2.6.0
tensorboard==2.19.0
2. Building the GPTMoEModel Network Architecture
2.1 Implementing Scaled Dot-Product and Multi-Head Attention
The scaled dot-product and multi-head attention logic is the same as in the previous article, shown below with the key parts commented:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


# Scaled dot-product attention
class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k):
        super(ScaledDotProductAttention, self).__init__()
        self.d_k = d_k

    def forward(self, q, k, v, attention_mask):
        # q: [batch_size, n_heads, len_q, d_k]
        # k: [batch_size, n_heads, len_k, d_k]
        # v: [batch_size, n_heads, len_v, d_v]
        # attention_mask: [batch_size, n_heads, seq_len, seq_len]
        # Score of each Q against each K: [batch_size, n_heads, len_q, len_q]
        scores = torch.matmul(q, k.transpose(-1, -2)) / np.sqrt(self.d_k)
        # Fill masked positions with a very large negative number; after softmax
        # they become ~0 and contribute nothing to q
        scores.masked_fill_(attention_mask, -1e9)
        attn = nn.Softmax(dim=-1)(scores)
        # Context after attention: [batch_size, n_heads, len_q, d_v]
        context = torch.matmul(attn, v)
        return context, attn


# Multi-head attention
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, d_k, d_v):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v
        self.w_q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.w_k = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.w_v = nn.Linear(d_model, d_v * n_heads, bias=False)
        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)
        self.layernorm = nn.LayerNorm(d_model)

    def forward(self, q, k, v, attention_mask):
        # q: [batch_size, seq_len, d_model]
        # k: [batch_size, seq_len, d_model]
        # v: [batch_size, seq_len, d_model]
        # attention_mask: [batch_size, seq_len, seq_len]
        # Keep the original input for the residual connection
        residual, batch_size = q, q.size(0)
        # Project q, k, v, then split into heads
        # q: [batch_size, n_heads, len_q, d_k]
        q = self.w_q(q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        # k: [batch_size, n_heads, len_k, d_k]
        k = self.w_k(k).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        # v: [batch_size, n_heads, len_v(=len_k), d_v]
        v = self.w_v(v).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)
        # attention_mask: [batch_size, n_heads, seq_len, seq_len]
        attention_mask = attention_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1)
        # Scaled dot-product attention: [batch_size, n_heads, len_q, d_v]
        context, attn = ScaledDotProductAttention(self.d_k)(q, k, v, attention_mask)
        # context: [batch_size, len_q, n_heads * d_v]
        context = context.transpose(1, 2).reshape(batch_size, -1, self.n_heads * self.d_v)
        # Project back to the original size
        output = self.fc(context)
        # LayerNorm + residual
        return self.layernorm(output + residual), attn
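As a quick sanity check of the shapes (my own addition, not part of the original article), a random batch can be pushed through MultiHeadAttention:

# Hypothetical smoke test: random values, shapes only
mha = MultiHeadAttention(d_model=768, n_heads=8, d_k=64, d_v=64)
x = torch.randn(2, 10, 768)                      # [batch_size, seq_len, d_model]
mask = torch.zeros(2, 10, 10, dtype=torch.bool)  # nothing masked
out, attn = mha(x, x, x, mask)
print(out.shape)   # torch.Size([2, 10, 768])
print(attn.shape)  # torch.Size([2, 8, 10, 10])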
2.2 Implementing the Gating Network (Router)
The gating network is just a lightweight neural network. Its job: for each token, predict which experts it should be assigned to, and give each selected expert a weight used to blend the experts' outputs.
The gating network has one well-known problem, expert imbalance: it may keep assigning samples to the few experts that are strongest or best initialized, so the remaining experts never get trained and the whole system degenerates until only a handful of experts are ever used. To counter this, a trainable noise term can be added during routing, and an auxiliary loss, the load-balancing loss, is introduced as well; here the load-balancing loss follows the approach of the Mixtral model.
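Concretely, with N experts, let f_i be the fraction of tokens that select expert i and P_i the router's mean probability for expert i. The auxiliary loss computed by compute_load_balancing_loss below is (my notation)

\mathcal{L}_{\text{aux}} = N \sum_{i=1}^{N} f_i \, P_i

which is minimized when routing is uniform: with f_i ≈ P_i ≈ 1/N for every expert, the loss is about 1.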
The implementation:
# Gating network
class Router(nn.Module):
    def __init__(self, d_model, num_experts, top_k=2):
        super(Router, self).__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.gate = nn.Linear(d_model, num_experts)
        # Noise projection used for load balancing
        self.noise_linear = nn.Linear(d_model, num_experts)

    def forward(self, x):
        logits = self.gate(x)
        # Add noise during training
        if self.training:
            noise = torch.randn_like(logits).to(x.device)
            noise = self.noise_linear(x) * noise
            noisy_logits = logits + noise
        else:
            noisy_logits = logits
        gates_prob = F.softmax(noisy_logits, dim=-1)
        # Top-k selection
        top_k_probs, top_k_indices = torch.topk(gates_prob, self.top_k, dim=-1)
        # Renormalize so the selected experts' weights sum to 1
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
        # Load-balancing loss
        load_balancing_loss = self.compute_load_balancing_loss(gates_prob, top_k_indices)
        return top_k_probs, top_k_indices, load_balancing_loss

    def compute_load_balancing_loss(self, gates_prob, top_k_indices):
        """Load-balancing loss:
        num_experts * sum( mean router probability per expert * fraction of tokens per expert )
        """
        batch_size, seq_len, _ = gates_prob.shape
        # Mean router probability for each expert
        router_prob_per_expert = gates_prob.mean(dim=(0, 1))
        # Fraction of tokens actually dispatched to each expert
        expert_mask = torch.zeros_like(gates_prob)
        expert_mask.scatter_(2, top_k_indices, 1)
        tokens_per_expert = expert_mask.float().mean(dim=(0, 1))
        # Auxiliary loss
        return self.num_experts * torch.sum(tokens_per_expert * router_prob_per_expert)
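A quick look at what the router returns (again my own sketch, not from the original article):

router = Router(d_model=768, num_experts=8, top_k=2)
x = torch.randn(2, 10, 768)
probs, indices, aux_loss = router(x)
print(probs.shape)    # torch.Size([2, 10, 2]) - weights of the 2 chosen experts, summing to 1
print(indices.shape)  # torch.Size([2, 10, 2]) - ids of the 2 chosen experts per token
print(aux_loss)       # scalar tensor, close to 1.0 when routing is balanced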
2.3 Implementing the Expert Network
Each expert is essentially a feed-forward network; here it mimics a SwiGLU FFN.
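Written out (standard SwiGLU form, matching the weights in the code below):

\text{Expert}(x) = W_{\text{out}}\left(\text{SiLU}(W_1 x) \odot W_2 x\right)

where ⊙ is element-wise multiplication and SiLU(z) = z · σ(z).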
# Expert network
class Expert(nn.Module):
    def __init__(self, d_model, d_ff):
        super(Expert, self).__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_model, d_ff, bias=False)
        self.w_out = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        return self.w_out(F.silu(self.w1(x)) * self.w2(x))
2.4 Assembling the Router and the Experts into the MoE Layer
The layer is made up of one gating Router and several experts. The Router outputs the top-k expert IDs and weights, the tokens are fed into the corresponding experts, and the expert outputs are then blended with those weights.
To make the logic easier to follow, expert selection here uses a double loop with a per-expert check, which cannot exploit the GPU's parallelism efficiently; a more efficient formulation in the style of the Mixtral implementation is sketched after the code below.
# MoE layer
class MoELayer(nn.Module):
    def __init__(self, d_model, d_ff, num_experts=8, top_k=2):
        super(MoELayer, self).__init__()
        self.d_model = d_model
        self.num_experts = num_experts
        self.top_k = top_k
        # Gating router that decides which experts are activated
        self.router = Router(d_model, num_experts, top_k)
        # Create the experts
        self.experts = nn.ModuleList([Expert(d_model, d_ff) for _ in range(num_experts)])
        # Layer Norm
        self.layernorm = nn.LayerNorm(d_model)

    def forward(self, x):
        """x: [batch_size, seq_len, d_model]"""
        residual = x
        batch_size, seq_len, d_model = x.shape
        # Routing decision
        # gates: [batch_size, seq_len, top_k]
        # selected_experts: [batch_size, seq_len, top_k]
        gates, selected_experts, load_balancing_loss = self.router(x)
        # Initialize the output
        output = torch.zeros_like(x)
        # Apply the selected experts to each token
        for i in range(self.top_k):
            # Expert index for this top-k slot
            expert_idx = selected_experts[:, :, i]  # [batch_size, seq_len]
            # Gate weight for this slot
            expert_gate = gates[:, :, i]  # [batch_size, seq_len]
            # Iterate over the experts
            for expert_id in range(self.num_experts):
                # Positions of the tokens that chose this expert
                mask = (expert_idx == expert_id).unsqueeze(-1)  # [batch_size, seq_len, 1]
                if mask.any():
                    # Tokens assigned to this expert
                    expert_input = x * mask  # [batch_size, seq_len, d_model]
                    # Apply the expert
                    expert_output = self.experts[expert_id](expert_input)  # [batch_size, seq_len, d_model]
                    # Weighted output
                    weighted_output = expert_output * expert_gate.unsqueeze(-1) * mask
                    output += weighted_output
        # Residual connection and Layer Norm
        output = self.layernorm(output + residual)
        return output, load_balancing_loss
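For reference, here is a minimal sketch of the flatter, Mixtral-style dispatch mentioned above: flatten all tokens, gather only the tokens routed to each expert with torch.where, and scatter the weighted results back with index_add_. This is my own simplified adaptation under those assumptions, not the official Mixtral code:

# Mixtral-style forward pass for MoELayer (sketch; could replace the double loop above)
def moe_forward_fast(self, x):
    residual = x
    batch_size, seq_len, d_model = x.shape
    gates, selected_experts, load_balancing_loss = self.router(x)
    # Flatten tokens: [batch_size * seq_len, d_model]
    x_flat = x.view(-1, d_model)
    gates_flat = gates.view(-1, self.top_k)
    experts_flat = selected_experts.view(-1, self.top_k)
    output = torch.zeros_like(x_flat)
    for expert_id in range(self.num_experts):
        # token_idx: flat token positions routed to this expert; slot: which top-k slot chose it
        token_idx, slot = torch.where(experts_flat == expert_id)
        if token_idx.numel() == 0:
            continue
        # Run the expert only on its own tokens
        expert_out = self.experts[expert_id](x_flat[token_idx])
        # Weight by the gate value and scatter back into the flat output
        output.index_add_(0, token_idx, expert_out * gates_flat[token_idx, slot, None])
    output = output.view(batch_size, seq_len, d_model)
    return self.layernorm(output + residual), load_balancing_loss

Each expert now sees a dense [num_tokens_for_expert, d_model] batch instead of a zero-padded full batch, so no compute is wasted on unselected tokens.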
2.5 Implementing the MoE Decoder Layer
This is just like a traditional Transformer Decoder Layer, except the FFN is swapped for the MoE layer.
# Decoder layer
class MoEDecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, d_k, d_v, num_experts=8, top_k=2):
        super(MoEDecoderLayer, self).__init__()
        # Multi-head attention
        self.attention = MultiHeadAttention(d_model, n_heads, d_k, d_v)
        # MoE
        self.pos_ffn = MoELayer(d_model, d_ff, num_experts, top_k)

    def forward(self, inputs, attention_mask):
        # Multi-head attention
        outputs, self_attn = self.attention(inputs, inputs, inputs, attention_mask)
        # MoE
        outputs, load_balancing_loss = self.pos_ffn(outputs)
        return outputs, self_attn, load_balancing_loss
2.6 Stacking MoE Decoder Layers into the MoE Decoder
Multiple decoder layers are stacked into a feature-extraction chain. To keep the comparison with the previous article fair, the positional encoding still follows GPT-2's approach, and a triangular causal mask is again needed so the model cannot see future tokens.
The masking process looks like this:

Raw attention score matrix (no mask):
[[q1k1, q1k2, q1k3, q1k4],
 [q2k1, q2k2, q2k3, q2k4],
 [q3k1, q3k2, q3k3, q3k4],
 [q4k1, q4k2, q4k3, q4k4]]

Upper-triangular mask:
[[0, 1, 1, 1],
 [0, 0, 1, 1],
 [0, 0, 0, 1],
 [0, 0, 0, 0]]

Score matrix after applying the mask:
[[q1k1, -inf, -inf, -inf],
 [q2k1, q2k2, -inf, -inf],
 [q3k1, q3k2, q3k3, -inf],
 [q4k1, q4k2, q4k3, q4k4]]
The implementation:
# Positional encoding, following GPT-2's learned-embedding approach
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_pos, device):
        super(PositionalEncoding, self).__init__()
        self.device = device
        self.pos_embedding = nn.Embedding(max_pos, d_model)

    def forward(self, inputs):
        seq_len = inputs.size(1)
        pos = torch.arange(seq_len, dtype=torch.long, device=self.device)
        # [seq_len] -> [batch_size, seq_len]
        pos = pos.unsqueeze(0).expand_as(inputs)
        return self.pos_embedding(pos)


# Padding mask
def get_attn_pad_mask(attention_mask):
    batch_size, len_seq = attention_mask.size()
    attention_mask = attention_mask.data.eq(0).unsqueeze(1)
    # Attention scores have shape [batch_size, n_heads, len_q, len_q],
    # so expand the mask to [batch_size, len_seq, len_seq]
    return attention_mask.expand(batch_size, len_seq, len_seq)


# Causal mask: keeps the model from seeing future tokens
def get_attn_subsequence_mask(seq, device):
    # Attention scores have shape [batch_size, n_heads, len_seq, len_seq],
    # so generate a [batch_size, len_seq, len_seq] mask
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
    # Build an upper-triangular matrix
    subsequence_mask = np.triu(np.ones(attn_shape), k=1)
    subsequence_mask = torch.from_numpy(subsequence_mask).byte()
    subsequence_mask = subsequence_mask.to(device)
    return subsequence_mask


# Decoder
class MoEDecoder(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, d_k, d_v, vocab_size, max_pos, n_layers,
                 device, num_experts=8, top_k=2):
        super(MoEDecoder, self).__init__()
        self.device = device
        # Token embedding
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Positional encoding
        self.pos_encoding = PositionalEncoding(d_model, max_pos, device)
        # Stack the MoE decoder layers
        self.layers = nn.ModuleList()
        for i in range(n_layers):
            self.layers.append(MoEDecoderLayer(d_model, n_heads, d_ff, d_k, d_v,
                                               num_experts, top_k))

    def forward(self, inputs, attention_mask):
        # Embedding plus positional encoding
        outputs = self.embedding(inputs) + self.pos_encoding(inputs)
        # Build the masks
        subsequence_mask = get_attn_subsequence_mask(inputs, self.device)
        if attention_mask is not None:
            attention_mask = get_attn_pad_mask(attention_mask)
            attention_mask = torch.gt((attention_mask + subsequence_mask), 0)
        else:
            attention_mask = subsequence_mask.bool()
        # Run every layer
        self_attns = []
        total_load_balancing_loss = 0.0
        for layer in self.layers:
            layer_output = layer(outputs, attention_mask)
            outputs, self_attn, load_balancing_loss = layer_output
            total_load_balancing_loss += load_balancing_loss
            self_attns.append(self_attn)
        return outputs, self_attns, total_load_balancing_loss
2.7 Wrapping the Decoder into GPTMoEModel
Note that the loss function must take the load-balancing loss from above into account, so the overall loss is the sum of the two.
# GPT MoE model
class GPTMoEModel(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, d_k, d_v, vocab_size, max_pos, n_layers,
                 device, num_experts=8, top_k=2, load_balancing_weight=0.01):
        super(GPTMoEModel, self).__init__()
        self.load_balancing_weight = load_balancing_weight
        # Decoder
        self.decoder = MoEDecoder(d_model, n_heads, d_ff, d_k, d_v, vocab_size, max_pos,
                                  n_layers, device, num_experts, top_k)
        # Projection to vocabulary size
        self.projection = nn.Linear(d_model, vocab_size)

    def forward(self, inputs, attention_mask=None, targets=None):
        # Forward pass
        outputs, self_attns, load_balancing_loss = self.decoder(inputs, attention_mask)
        # Project to the vocabulary
        logits = self.projection(outputs)
        logits = logits.view(-1, logits.size(-1))
        if targets is not None:
            # Load-balancing loss
            load_balancing_loss = load_balancing_loss * self.load_balancing_weight
            # Language-modeling (task) loss
            lm_loss = F.cross_entropy(logits, targets.view(-1), ignore_index=0)
            # The MoE total loss is the weighted sum of the task loss and the load-balancing loss
            total_loss = lm_loss + load_balancing_loss
            return logits, self_attns, total_loss
        return logits, self_attns
2.8 The Overall Network Architecture
All of the network code above goes into model_moe.py.
import torch
from model_moe import GPTMoEModel


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Model parameters
    model_param = {
        "d_model": 768,          # embedding size
        "d_ff": 2048,            # expert network hidden size
        "d_k": 64,               # size of K
        "d_v": 64,               # size of V
        "n_layers": 6,           # number of decoder layers
        "n_heads": 8,            # number of attention heads
        "max_pos": 1800,         # positional encoding length
        "device": device,        # device
        "vocab_size": 4825,      # vocabulary size, built in the previous article
        "num_experts": 8,        # 8 experts
        "top_k": 2,              # each token picks 2 experts
        "load_balancing_weight": 0.01  # load-balancing loss weight
    }
    model = GPTMoEModel(**model_param)
    total_params = sum(p.numel() for p in model.parameters())
    print(model)
    print("total_params: ", total_params)


if __name__ == '__main__':
    main()
Output:
GPTMoEModel(
  (decoder): MoEDecoder(
    (embedding): Embedding(4825, 768)
    (pos_encoding): PositionalEncoding(
      (pos_embedding): Embedding(1800, 768)
    )
    (layers): ModuleList(
      (0-5): 6 x MoEDecoderLayer(
        (attention): MultiHeadAttention(
          (w_q): Linear(in_features=768, out_features=512, bias=False)
          (w_k): Linear(in_features=768, out_features=512, bias=False)
          (w_v): Linear(in_features=768, out_features=512, bias=False)
          (fc): Linear(in_features=512, out_features=768, bias=False)
          (layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (pos_ffn): MoELayer(
          (router): Router(
            (gate): Linear(in_features=768, out_features=8, bias=True)
            (noise_linear): Linear(in_features=768, out_features=8, bias=True)
          )
          (experts): ModuleList(
            (0-7): 8 x Expert(
              (fc1): Linear(in_features=768, out_features=2048, bias=False)
              (fc2): Linear(in_features=2048, out_features=768, bias=False)
              (activation): ReLU()
            )
          )
          (layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
  )
  (projection): Linear(in_features=768, out_features=4825, bias=True)
)
total_params:  173028409
The total parameter count is about 173 million, i.e. 0.17B; compared with the network built in the previous article, it can hold considerably more knowledge.
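What makes the "high parameters, low computation" claim concrete: only the experts are sparsely activated, and with top_k = 2 out of num_experts = 8, each token computes with just a quarter of the expert parameters in every layer. A quick sketch of the arithmetic (my own, assuming the SwiGLU Expert from section 2.3):

# Per-token expert compute vs. stored expert capacity (sketch)
d_model, d_ff = 768, 2048
num_experts, top_k = 8, 2

params_per_expert = 3 * d_model * d_ff    # w1, w2, w_out of one SwiGLU expert
print(params_per_expert)                  # 4718592 parameters per expert
print(num_experts * params_per_expert)    # 37748736 stored per MoE layer
print(top_k * params_per_expert)          # 9437184 actually used per token per layer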
3. Model Training
The training set and training procedure are basically the same as in the previous article. As before, the training data also gets a few custom identity examples appended to give the model some personality:
{"question": "你是谁", "answer": "我是小毕超,一个简易的小助手"}
{"question": "你叫什么", "answer": "我是小毕超,一个简易的小助手"}
{"question": "你的名字是什么", "answer": "我是小毕超,一个简易的小助手"}
{"question": "你叫啥", "answer": "我是小毕超,一个简易的小助手"}
{"question": "你名字是啥", "answer": "我是小毕超,一个简易的小助手"}
{"question": "你是什么身份", "answer": "我是小毕超,一个简易的小助手"}
{"question": "你的全名是什么", "answer": "我是小毕超,一个简易的小助手"}
{"question": "你自称什么", "answer": "我是小毕超,一个简易的小助手"}
{"question": "你的称号是什么", "answer": "我是小毕超,一个简易的小助手"}
{"question": "你的昵称是什么", "answer": "我是小毕超,一个简易的小助手"}
3.1 Building the Dataset
qa_dataset.py:
# -*- coding: utf-8 -*-
from torch.utils.data import Dataset
import torch
import json
import numpy as np


class QADataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = []
        if data_path:
            with open(data_path, "r", encoding='utf-8') as f:
                for line in f:
                    if not line or line == "":
                        continue
                    json_line = json.loads(line)
                    question = json_line["question"]
                    answer = json_line["answer"]
                    self.data.append({
                        "question": question,
                        "answer": answer
                    })
        print("data load , size:", len(self.data))

    def preprocess(self, question, answer):
        encode, att_mask = self.tokenizer.encode(question, answer,
                                                 max_length=self.max_length,
                                                 pad_to_max_length=True)
        input_ids = encode[:-1]
        att_mask = att_mask[:-1]
        labels = encode[1:]
        return input_ids, att_mask, labels

    def __getitem__(self, index):
        item_data = self.data[index]
        input_ids, att_mask, labels = self.preprocess(**item_data)
        return {
            "input_ids": torch.LongTensor(np.array(input_ids)),
            "attention_mask": torch.LongTensor(np.array(att_mask)),
            "labels": torch.LongTensor(np.array(labels))
        }

    def __len__(self):
        return len(self.data)
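Note how preprocess builds the next-token-prediction pair: input_ids drop the last token and labels drop the first, so position i of the input is trained to predict token i+1. For example (the token ids here are made up purely for illustration):

# Hypothetical encoding of one question/answer pair; 101/102 stand in for special tokens
encode = [101, 8, 9, 10, 102, 21, 22, 23, 102]
input_ids = encode[:-1]  # [101, 8, 9, 10, 102, 21, 22, 23]
labels = encode[1:]      # [8, 9, 10, 102, 21, 22, 23, 102]
# input_ids[i] is trained to predict labels[i], i.e. the token that follows it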
3.2 Training
# -*- coding: utf-8 -*-
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tokenizer import Tokenizer
from model_moe import GPTMoEModel
from qa_dataset import QADataset
from tqdm import tqdm
import time, sys, os


def train_model(model, train_loader, val_loader, optimizer,
                device, num_epochs, model_output_dir, writer):
    batch_step = 0
    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        time1 = time.time()
        model.train()
        for index, data in enumerate(tqdm(train_loader, file=sys.stdout,
                                          desc="Train Epoch: " + str(epoch))):
            input_ids = data['input_ids'].to(device, dtype=torch.long)
            attention_mask = data['attention_mask'].to(device, dtype=torch.long)
            labels = data['labels'].to(device, dtype=torch.long)
            optimizer.zero_grad()
            outputs, dec_self_attns, loss = model(input_ids, attention_mask, labels)
            loss.backward()
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()
            writer.add_scalar('Loss/train', loss, batch_step)
            batch_step += 1
            # Print the loss every 50 steps
            if index % 50 == 0 or index == len(train_loader) - 1:
                time2 = time.time()
                tqdm.write(
                    f"{index}, epoch: {epoch} -loss: {str(loss)} ; "
                    f"lr: {optimizer.param_groups[0]['lr']} ; "
                    f"each step's time spent: {(str(float(time2 - time1) / float(index + 0.0001)))}")
        # Validation
        model.eval()
        val_loss = validate_model(model, device, val_loader)
        writer.add_scalar('Loss/val', val_loss, epoch)
        print(f"val loss: {val_loss} , epoch: {epoch}")
        # Save the best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_path = os.path.join(model_output_dir, "best.pt")
            print("Save Best Model To ", best_model_path, ", epoch: ", epoch)
            torch.save(model.state_dict(), best_model_path)
        # Save the latest model
        last_model_path = os.path.join(model_output_dir, "last.pt")
        print("Save Last Model To ", last_model_path, ", epoch: ", epoch)
        torch.save(model.state_dict(), last_model_path)


def validate_model(model, device, val_loader):
    running_loss = 0.0
    with torch.no_grad():
        for _, data in enumerate(tqdm(val_loader, file=sys.stdout, desc="Validation Data")):
            input_ids = data['input_ids'].to(device, dtype=torch.long)
            attention_mask = data['attention_mask'].to(device, dtype=torch.long)
            labels = data['labels'].to(device, dtype=torch.long)
            outputs, dec_self_attns, loss = model(input_ids, attention_mask, labels)
            running_loss += loss.item()
    return running_loss / len(val_loader)


def main():
    train_json_path = "data/train.json"  # training set
    val_json_path = "data/val.json"      # validation set
    vocab_path = "data/vocab.json"       # vocabulary file
    max_length = 120                     # max sequence length
    epochs = 15                          # number of epochs
    batch_size = 128                     # batch size
    lr = 1e-4                            # learning rate
    model_output_dir = "output"          # model save directory
    logs_dir = "logs"                    # log directory
    # Device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Load the tokenizer
    tokenizer = Tokenizer(vocab_path)
    # Model parameters
    model_param = {
        "d_model": 768,          # embedding size
        "d_ff": 2048,            # expert network hidden size
        "d_k": 64,               # size of K
        "d_v": 64,               # size of V
        "n_layers": 6,           # number of decoder layers
        "n_heads": 8,            # number of attention heads
        "max_pos": 1800,         # positional encoding length
        "device": device,        # device
        "vocab_size": tokenizer.get_vocab_size(),  # vocabulary size
        "num_experts": 8,        # 8 experts
        "top_k": 2,              # each token picks 2 experts
        "load_balancing_weight": 0.01  # load-balancing loss weight
    }
    model = GPTMoEModel(**model_param)
    print("Start Load Train Data...")
    train_params = {
        "batch_size": batch_size,
        "shuffle": True,
        "num_workers": 4,
    }
    training_set = QADataset(train_json_path, tokenizer, max_length)
    training_loader = DataLoader(training_set, **train_params)
    print("Start Load Validation Data...")
    val_params = {
        "batch_size": batch_size,
        "shuffle": False,
        "num_workers": 4,
    }
    val_set = QADataset(val_json_path, tokenizer, max_length)
    val_loader = DataLoader(val_set, **val_params)
    # Logging
    writer = SummaryWriter(logs_dir)
    # Optimizer
    optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)
    model = model.to(device)
    # Start training
    print("Start Training...")
    train_model(model=model, train_loader=training_loader, val_loader=val_loader,
                optimizer=optimizer, device=device, num_epochs=epochs,
                model_output_dir=model_output_dir, writer=writer)
    writer.close()


if __name__ == '__main__':
    main()
The training process:
After training, check the loss trend with tensorboard:
Over 15 epochs of training, the validation loss kept falling for the first 9 epochs and started climbing from epoch 10, which points to overfitting; a follow-up optimization is to add some dropout to the network for random deactivation, as sketched below.
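A minimal sketch of one place dropout could go, for instance inside the Expert from section 2.3 (the drop probability 0.1 is an assumed value, not from the original):

# Expert with dropout added (sketch; p=0.1 is an assumption)
class ExpertWithDropout(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(ExpertWithDropout, self).__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_model, d_ff, bias=False)
        self.w_out = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)  # randomly zeroes activations during training

    def forward(self, x):
        return self.w_out(self.dropout(F.silu(self.w1(x)) * self.w2(x)))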
4. Model Inference and Testing
import torch
from model_moe import GPTMoEModel
from tokenizer import Tokenizer


def generate(model, tokenizer, text, max_length, device):
    input, att_mask = tokenizer.encode(text)
    input = torch.tensor(input, dtype=torch.long, device=device).unsqueeze(0)
    stop = False
    input_len = len(input[0])
    while not stop:
        if len(input[0]) - input_len > max_length:
            next_symbol = tokenizer.sep_token
            input = torch.cat(
                [input.detach(), torch.tensor([[next_symbol]], dtype=input.dtype, device=device)], -1)
            break
        projected, self_attns = model(input)
        # Greedy decoding: take the highest-probability token at the last position
        prob = projected.squeeze(0).max(dim=-1, keepdim=False)[1]
        next_word = prob.data[-1]
        next_symbol = next_word
        if next_symbol == tokenizer.sep_token:
            stop = True
        input = torch.cat(
            [input.detach(), torch.tensor([[next_symbol]], dtype=input.dtype, device=device)], -1)
    decode = tokenizer.decode(input[0].tolist())
    decode = decode[len(text):]
    return "".join(decode)


def main():
    model_path = "output/last.pt"
    vocab_path = "data/vocab.json"  # vocabulary file
    max_length = 120                # max generation length
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Load the tokenizer
    tokenizer = Tokenizer(vocab_path)
    # Model parameters
    model_param = {
        "d_model": 768,          # embedding size
        "d_ff": 2048,            # expert network hidden size
        "d_k": 64,               # size of K
        "d_v": 64,               # size of V
        "n_layers": 6,           # number of decoder layers
        "n_heads": 8,            # number of attention heads
        "max_pos": 1800,         # positional encoding length
        "device": device,        # device
        "vocab_size": tokenizer.get_vocab_size(),  # vocabulary size
        "num_experts": 8,        # 8 experts
        "top_k": 2,              # each token picks 2 experts
        "load_balancing_weight": 0.01  # load-balancing loss weight
    }
    model = GPTMoEModel(**model_param)
    model.load_state_dict(torch.load(model_path))
    model.to(device)
    # Switch to eval mode so the Router stops adding routing noise during inference
    model.eval()
    while True:
        text = input("Input: ")
        if not text:
            continue
        if text == "q":
            break
        res = generate(model, tokenizer, text, max_length, device)
        print("AI: ", res)


if __name__ == '__main__':
    main()
Prediction results:
5. Summary
This article has only run a basic experiment with the MoE architecture, and plenty of room for optimization remains: using RoPE rotary positional encoding, adding RMSNorm, trying more advanced routing strategies, adding dropout, and so on. You can keep iterating on and improving the model from here.
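As one example of these directions, here is a minimal RMSNorm sketch that could stand in for the nn.LayerNorm used above (standard RMSNorm, my own code rather than anything from this article):

# Minimal RMSNorm (sketch): normalize by the root mean square; no mean-centering, no bias
class RMSNorm(nn.Module):
    def __init__(self, d_model, eps=1e-6):
        super(RMSNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x):
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * x / rms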