1. Custom model components
MultiHeadAttention class

- Implements the multi-head self-attention mechanism.
- Splits the input into multiple "heads" so the model can learn features of the input from different representation subspaces.
- Applies scaled dot-product attention to the projected queries and keys, and supports an optional padding mask (a small shape-check sketch follows the code below).
```python
import math

import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.o_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.shape
        # Project to Q/K/V and reshape to (batch_size, num_heads, seq_len, d_k)
        Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        # Scaled dot-product attention
        atten_scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # Broadcast the padding mask to (batch_size, 1, 1, seq_len)
            mask = mask.unsqueeze(1).unsqueeze(1)
            atten_scores = atten_scores.masked_fill(mask == 0, -1e9)
        atten_scores = torch.softmax(atten_scores, dim=-1)
        out = atten_scores @ V
        # Merge the heads back to (batch_size, seq_len, d_model)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        return self.dropout(self.o_proj(out))
```
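To make the tensor shapes and masking behaviour concrete, here is a minimal shape-check sketch with assumed toy sizes (d_model=16, num_heads=4, a batch of two padded sequences); positions where the mask is 0 are ignored by the attention.

```python
import torch

# Hypothetical toy sizes, just to verify shapes and masking.
mha = MultiHeadAttention(d_model=16, num_heads=4, dropout=0.1)
x = torch.randn(2, 6, 16)                 # (batch_size=2, seq_len=6, d_model=16)
mask = torch.tensor([[1, 1, 1, 1, 0, 0],  # 1 = real token, 0 = padding
                     [1, 1, 1, 0, 0, 0]])
out = mha(x, mask)
print(out.shape)                          # torch.Size([2, 6, 16])
```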
FeedForward class

- Implements the position-wise feed-forward network (FFN) used inside the Transformer block (a short shape check follows the code below).
```python
class FeedForward(nn.Module):
    def __init__(self, d_model, dff, dropout):
        super().__init__()
        self.W1 = nn.Linear(d_model, dff)
        self.act = nn.GELU()
        self.W2 = nn.Linear(dff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.W2(self.dropout(self.act(self.W1(x))))
```
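The FFN is applied independently at every position, so the output shape matches the input shape. A quick check with the same assumed toy sizes:

```python
import torch

ffn = FeedForward(d_model=16, dff=64, dropout=0.1)
x = torch.randn(2, 6, 16)
print(ffn(x).shape)   # torch.Size([2, 6, 16]) -- shape is preserved
```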
TransformerEncoderBlock class

- Combines the multi-head attention and feed-forward sub-layers, each wrapped in a residual connection followed by LayerNorm (post-LN), with dropout on both sub-layer outputs.
```python
class TransformerEncoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, dropout, dff):
        super().__init__()
        self.mha_block = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn_block = FeedForward(d_model, dff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Post-LN residual connections around attention and FFN
        res1 = self.norm1(x + self.dropout1(self.mha_block(x, mask)))
        res2 = self.norm2(res1 + self.dropout2(self.ffn_block(res1)))
        return res2
```
2. The BERT model
BERT class

- Implements a simplified BERT model: token embeddings, learned position embeddings, and a stack of Transformer encoder blocks (a forward-pass shape check follows the code below).
- Segment embeddings are omitted, since single-sentence text classification does not need to distinguish sentence pairs.
```python
class BERT(nn.Module):
    def __init__(self, vocab_size, d_model, seq_len, dff, dropout, num_heads, N_block):
        super().__init__()
        self.seq_len = seq_len
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(seq_len, d_model)
        self.layers = nn.ModuleList([
            TransformerEncoderBlock(d_model, num_heads, dropout, dff)
            for _ in range(N_block)
        ])

    def forward(self, x, mask):
        # Learned position indices 0..seq_len-1, created on the same device as the input
        position = torch.arange(0, self.seq_len, device=x.device)
        tok_emb = self.tok_emb(x)
        pos_emb = self.pos_emb(position)
        x = tok_emb + pos_emb
        for layer in self.layers:
            x = layer(x, mask)
        return x
```
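A forward-pass shape check under assumed toy hyperparameters (vocabulary of 100, 2 encoder blocks); the model returns one d_model-dimensional vector per input token:

```python
import torch

bert = BERT(vocab_size=100, d_model=16, seq_len=8, dff=64,
            dropout=0.1, num_heads=4, N_block=2)
token_ids = torch.randint(0, 100, (2, 8))        # (batch_size, seq_len)
attn_mask = torch.ones(2, 8, dtype=torch.long)   # all positions are real tokens here
out = bert(token_ids, attn_mask)
print(out.shape)                                 # torch.Size([2, 8, 16])
```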
BertClassification class

- Adds a classification head (self.cls) on top of the BERT model for text classification.
- The head is a single linear layer that maps the representation of the first token ([CLS]) to the number of classes (a usage sketch follows the code below).
```python
class BertClassification(nn.Module):
    def __init__(self, vocab_size, d_model, seq_len, dff, dropout, num_heads, N_block, classes):
        super().__init__()
        self.bert = BERT(vocab_size, d_model, seq_len, dff, dropout, num_heads, N_block)
        self.cls = nn.Linear(d_model, classes)

    def forward(self, x, mask):
        bert_out = self.bert(x, mask)   # (batch_size, seq_len, d_model)
        cls_repr = bert_out[:, 0, :]    # representation of the first ([CLS]) token
        return self.cls(cls_repr)       # (batch_size, classes)
```
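A short usage sketch with the same assumed toy hyperparameters and 15 classes; the classifier returns one logit vector per title, and an argmax over the last dimension gives the predicted class index:

```python
import torch

clf = BertClassification(vocab_size=100, d_model=16, seq_len=8, dff=64,
                         dropout=0.1, num_heads=4, N_block=2, classes=15)
token_ids = torch.randint(0, 100, (2, 8))
attn_mask = torch.ones(2, 8, dtype=torch.long)
logits = clf(token_ids, attn_mask)
print(logits.shape)              # torch.Size([2, 15])
print(logits.argmax(dim=-1))     # predicted class index for each title
```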
3. Data processing
read_data function

- Reads the data file and parses out the titles and labels.
- Uses collections.defaultdict to group the data by category.
- Randomly samples up to n titles per category to keep the classes balanced (the assumed line format is sketched after the code below).
```python
import collections
import random


def read_data(file, n=200):
    with open(file, "r", encoding="utf-8") as f:
        data = f.read().strip().split("\n")
    # Group titles by category; LABELS_MAP (defined in the main block below)
    # maps the raw category codes to contiguous class indices.
    category_data = collections.defaultdict(list)
    for line in data:
        segs = line.split("_!_")
        label = int(segs[1])
        title = segs[3]
        category_data[LABELS_MAP[label]].append(title)
    # Sample up to n titles per category so the classes stay balanced
    sampled_titles = []
    sampled_labels = []
    for label, titles in category_data.items():
        samples = random.sample(titles, min(n, len(titles)))
        sampled_labels.extend([label] * len(samples))
        sampled_titles.extend(samples)
    return sampled_titles, sampled_labels
```
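The exact column layout of toutiao_cat_data.txt is an assumption here, inferred from how read_data indexes the split line (segs[1] is the integer category code, segs[3] is the title); the line below is a made-up example in that assumed layout:

```python
# Assumed layout: news_id _!_ category_code _!_ category_name _!_ title _!_ keywords
line = "1000000000000000001_!_102_!_news_entertainment_!_某电影上映首日票房破亿_!_电影,票房"
segs = line.split("_!_")
print(int(segs[1]), segs[3])   # 102 某电影上映首日票房破亿
```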
Toutiao class

- Inherits from torch.utils.data.Dataset and is used to load and preprocess the data.
- Uses BertTokenizer to convert the raw titles into inputs the BERT model can consume.
- Each returned sample contains the title text, the title token ids, the attention mask, and the label (a single-sample sketch follows the code below).
```python
from torch.utils.data import Dataset
from transformers import BertTokenizer


class Toutiao(Dataset):
    def __init__(self, titles, labels, tokenizer: BertTokenizer, seq_len):
        self.titles = titles
        self.labels = labels
        self.seq_len = seq_len
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.titles)

    def __getitem__(self, idx):
        title_txt = self.titles[idx]
        label = self.labels[idx]
        # Pad or truncate every title to a fixed length of seq_len tokens
        tokenizer_out = self.tokenizer(
            title_txt,
            padding="max_length",
            max_length=self.seq_len,
            truncation=True,
        )
        title_ids = tokenizer_out["input_ids"]
        input_msk = tokenizer_out["attention_mask"]
        return {
            "title_txt": title_txt,
            "title_ids": torch.tensor(title_ids),
            "input_msk": torch.tensor(input_msk),
            "label": torch.tensor(label),
        }
```
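A single-sample sketch with two made-up titles and hand-picked label indices, assuming the bert-base-chinese tokenizer is available (from the Hub or a local path); it shows that every sample is padded to seq_len and returned as tensors:

```python
import torch
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")  # or a local path
ds = Toutiao(["一部新手机今天发布", "主队在客场赢下比赛"], [2, 3], tokenizer, seq_len=16)
sample = ds[0]
print(sample["title_ids"].shape, sample["input_msk"].shape)  # torch.Size([16]) torch.Size([16])
print(sample["label"])                                       # tensor(2)
```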
4. Training process
Data loading

- Uses random_split to divide the dataset into a training set and a validation set.
- Uses DataLoader to batch the data; the training loader shuffles the data, the validation loader does not (a batch-inspection sketch follows this list).
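A sketch of the split and the loaders, assuming `dataset` is the Toutiao instance built in the full script below (batch_size=16 for training, 8 for validation); each batch comes out as a dict of stacked tensors:

```python
from torch.utils.data import DataLoader, random_split

train_size = int(0.9 * len(dataset))
trainset, validset = random_split(dataset, [train_size, len(dataset) - train_size])
trainloader = DataLoader(trainset, batch_size=16, shuffle=True)   # shuffle training batches
validloader = DataLoader(validset, batch_size=8, shuffle=False)   # keep validation order fixed

batch = next(iter(trainloader))
print(batch["title_ids"].shape, batch["input_msk"].shape, batch["label"].shape)
# e.g. torch.Size([16, 50]) torch.Size([16, 50]) torch.Size([16])
```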
Model training

- Defines the learning rate, optimizer, and loss function.
- In each epoch, runs the forward pass over the training set, computes the loss, back-propagates, and updates the parameters.
- Evaluates the model on the validation set and reports the accuracy.
```python
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

if __name__ == "__main__":
    data_file = "/Users/azen/Desktop/llm/LLM-FullTime/dataset/text-classification/toutiao-text/toutiao_cat_data.txt"
    model_path = "/Users/azen/Desktop/llm/models/bert-base-chinese"

    # Map the raw Toutiao category codes to contiguous class indices 0..14
    labels = [100, 101, 102, 103, 104, 106, 107, 108, 109, 110, 112, 113, 114, 115, 116]
    LABELS_MAP = {k: v for v, k in enumerate(labels)}

    sampled_titles, sampled_labels = read_data(data_file, n=200)
    max_length = len(max(sampled_titles, key=len))
    print("Max length of sequence: {}".format(max_length))

    tokenizer = BertTokenizer.from_pretrained(model_path)
    seq_len = 50
    dataset = Toutiao(sampled_titles, sampled_labels, tokenizer, seq_len)

    # 90/10 train/validation split
    train_size = int(0.9 * len(dataset))
    valid_size = len(dataset) - train_size
    trainset, validset = random_split(dataset, [train_size, valid_size])
    trainloader = DataLoader(trainset, batch_size=16, shuffle=True)
    validloader = DataLoader(validset, batch_size=8, shuffle=False)

    vocab_size = tokenizer.vocab_size
    d_model = 256
    dff = 4 * d_model
    dropout = 0.1
    num_heads = 8
    N_block = 2
    classes = len(set(labels))
    model = BertClassification(vocab_size, d_model, seq_len, dff, dropout, num_heads, N_block, classes)

    lr = 1e-3
    optim = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    epochs = 20
    for epoch in range(epochs):
        # Training
        model.train()
        total_train_loss = 0
        for batch in tqdm(trainloader, desc="Training"):
            batch_title_ids = batch["title_ids"]
            batch_input_msk = batch["input_msk"]
            batch_labels = batch["label"]

            output = model(batch_title_ids, batch_input_msk)
            train_loss = loss_fn(output, batch_labels)
            train_loss.backward()
            optim.step()
            optim.zero_grad()
            total_train_loss += train_loss.item()

        # Validation
        model.eval()
        total_valid_loss = 0
        correct = 0
        with torch.no_grad():
            for batch in tqdm(validloader, desc="Validating"):
                batch_title_ids = batch["title_ids"]
                batch_input_msk = batch["input_msk"]
                batch_labels = batch["label"]

                output = model(batch_title_ids, batch_input_msk)
                valid_loss = loss_fn(output, batch_labels)
                total_valid_loss += valid_loss.item()
                correct += torch.sum(torch.argmax(output, dim=-1) == batch_labels).item()

        acc = correct / len(validloader.dataset)
        print("Epoch: {}, Train Loss: {}, Valid Loss: {}, Acc: {:.2f}%".format(
            epoch, total_train_loss / len(trainloader), total_valid_loss / len(validloader), acc * 100))
```
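Once training has finished, the model can be used on a new title like this; this is a sketch rather than part of the script above, and ID2LABEL is a hypothetical helper that simply inverts LABELS_MAP back to the raw category codes:

```python
# Classify one made-up headline with the trained model.
ID2LABEL = {v: k for k, v in LABELS_MAP.items()}

model.eval()
title = "新款手机今日正式发布"
enc = tokenizer(title, padding="max_length", max_length=seq_len, truncation=True)
ids = torch.tensor(enc["input_ids"]).unsqueeze(0)        # (1, seq_len)
msk = torch.tensor(enc["attention_mask"]).unsqueeze(0)   # (1, seq_len)
with torch.no_grad():
    pred = model(ids, msk).argmax(dim=-1).item()
print("Predicted category code:", ID2LABEL[pred])
```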