BERT - Hands-On Toutiao News Classification
1. Custom Model Components
MultiHeadAttention class
- Implements multi-head self-attention.
- Splits the input into several "heads" so that features of the input are learned from different representation subspaces.
- Attention scores are computed with scaled dot-product attention, and a padding mask is supported.
import math

import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.o_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.shape
        # Project to Q/K/V and split into heads: (batch, num_heads, seq_len, d_k)
        Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        # Scaled dot-product attention
        atten_scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # (batch, seq_len) -> (batch, 1, 1, seq_len), broadcast over heads and query positions
            mask = mask.unsqueeze(1).unsqueeze(1)
            atten_scores = atten_scores.masked_fill(mask == 0, -1e9)
        atten_scores = torch.softmax(atten_scores, dim=-1)
        out = atten_scores @ V
        # Merge heads back to (batch, seq_len, d_model)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        return self.dropout(self.o_proj(out))
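A quick shape check (illustrative only, with assumed sizes, not part of the original script) shows that the module preserves the input shape and that a padding mask broadcasts correctly:

mha = MultiHeadAttention(d_model=256, num_heads=8, dropout=0.1)
x = torch.randn(2, 10, 256)                 # (batch, seq_len, d_model)
mask = torch.ones(2, 10, dtype=torch.long)  # 1 = real token, 0 = padding
mask[:, -2:] = 0                            # pretend the last two positions are padding
out = mha(x, mask)
print(out.shape)                            # torch.Size([2, 10, 256])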
FeedForward class
- Implements the position-wise feed-forward network (FFN) of the Transformer.
class FeedForward(nn.Module):
    def __init__(self, d_model, dff, dropout):
        super().__init__()
        self.W1 = nn.Linear(d_model, dff)
        self.act = nn.GELU()
        self.W2 = nn.Linear(dff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.W2(self.dropout(self.act(self.W1(x))))
TransformerEncoderBlock class
- Combines multi-head attention and the FFN, each wrapped in a residual connection followed by layer normalization (post-LN); see the shape check after the code.
class TransformerEncoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, dropout, dff):
        super().__init__()
        self.mha_block = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn_block = FeedForward(d_model, dff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        res1 = self.norm1(x + self.dropout1(self.mha_block(x, mask)))
        res2 = self.norm2(res1 + self.dropout2(self.ffn_block(res1)))
        return res2
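Since each block maps (batch, seq_len, d_model) back to the same shape, blocks can be stacked directly. A minimal sketch with assumed sizes (not from the original post):

blocks = nn.ModuleList([
    TransformerEncoderBlock(d_model=256, num_heads=8, dropout=0.1, dff=1024)
    for _ in range(2)
])
x = torch.randn(2, 10, 256)
mask = torch.ones(2, 10, dtype=torch.long)
for blk in blocks:
    x = blk(x, mask)
print(x.shape)  # torch.Size([2, 10, 256])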
2. The BERT Model
BERT class
- Implements a simplified BERT: token embeddings, position embeddings, and a stack of Transformer encoder blocks.
- Segment embeddings are omitted, since single-sentence classification does not need to distinguish sentence pairs.
class BERT(nn.Module):
    def __init__(self, vocab_size, d_model, seq_len, dff, dropout, num_heads, N_block):
        super().__init__()
        self.seq_len = seq_len
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(seq_len, d_model)
        self.layers = nn.ModuleList([
            TransformerEncoderBlock(d_model, num_heads, dropout, dff)
            for _ in range(N_block)
        ])

    def forward(self, x, mask):
        # Learned position embeddings, built on the same device as the input ids
        position = torch.arange(0, x.size(1), device=x.device)
        tok_emb = self.tok_emb(x)
        pos_emb = self.pos_emb(position)
        x = tok_emb + pos_emb
        for layer in self.layers:
            x = layer(x, mask)
        return x
BertClassification class
- Adds a classification head (self.cls) on top of the BERT model for text classification.
- The head is a single linear layer that maps the BERT output to the number of classes.
class BertClassification(nn.Module):
    def __init__(self, vocab_size, d_model, seq_len, dff, dropout, num_heads, N_block, classes):
        super().__init__()
        self.bert = BERT(vocab_size, d_model, seq_len, dff, dropout, num_heads, N_block)
        self.cls = nn.Linear(d_model, classes)

    def forward(self, x, mask):
        bert_out = self.bert(x, mask)   # (batch_size, seq_len, d_model)
        cls_repr = bert_out[:, 0, :]    # representation of the first ([CLS]) token
        return self.cls(cls_repr)       # (batch_size, classes)
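To see the whole classification path end to end, here is a small illustrative check with random token ids standing in for tokenized titles; the hyperparameters mirror the training script below, and 21128 is the bert-base-chinese vocabulary size:

clf = BertClassification(vocab_size=21128, d_model=256, seq_len=50,
                         dff=1024, dropout=0.1, num_heads=8, N_block=2, classes=15)
ids = torch.randint(0, 21128, (4, 50))      # (batch, seq_len), random ids for illustration
msk = torch.ones(4, 50, dtype=torch.long)   # all positions treated as real tokens
logits = clf(ids, msk)
print(logits.shape)                         # torch.Size([4, 15])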
3. Data Processing
read_data function
- Reads the data file and parses the title and label out of each line.
- Uses collections.defaultdict to group titles by category.
- Randomly samples up to n titles per category to keep the classes balanced.
import collections
import random


def read_data(file, n=200):
    with open(file, "r", encoding="utf-8") as f:
        data = f.read().strip().split("\n")
    category_data = collections.defaultdict(list)
    for line in data:
        segs = line.split("_!_")
        label = int(segs[1])   # numeric category code
        title = segs[3]        # news title
        category_data[LABELS_MAP[label]].append(title)
    sampled_titles = []
    sampled_labels = []
    for label, titles in category_data.items():
        samples = random.sample(titles, min(n, len(titles)))
        sampled_labels.extend([label] * len(samples))
        sampled_titles.extend(samples)
    return sampled_titles, sampled_labels
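Each line of toutiao_cat_data.txt contains fields separated by _!_; the function only uses the second field (the numeric category code) and the fourth field (the title), and the remaining fields (typically the news id, the category name, and keywords) are ignored. A hedged usage sketch, assuming data_file and LABELS_MAP are defined as in section 4:

titles, labels = read_data(data_file, n=200)
print(len(titles))           # up to 15 * 200 = 3000 sampled titles
print(titles[0], labels[0])  # one sampled headline and its class index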
Toutiao class
- Subclasses torch.utils.data.Dataset to load and preprocess the data.
- Uses BertTokenizer to convert each title into a format the BERT model can consume.
- Each sample contains the raw title text, the token ids, the attention mask, and the label.
from torch.utils.data import Dataset
from transformers import BertTokenizer


class Toutiao(Dataset):
    def __init__(self, titles, labels, tokenizer: BertTokenizer, seq_len):
        self.titles = titles
        self.labels = labels
        self.seq_len = seq_len
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.titles)

    def __getitem__(self, idx):
        title_txt = self.titles[idx]
        label = self.labels[idx]
        # Pad/truncate every title to a fixed length of seq_len tokens
        tokenizer_out = self.tokenizer(
            title_txt,
            padding="max_length",
            max_length=self.seq_len,
            truncation=True
        )
        title_ids = tokenizer_out["input_ids"]
        input_msk = tokenizer_out["attention_mask"]
        return {
            "title_txt": title_txt,
            "title_ids": torch.tensor(title_ids),
            "input_msk": torch.tensor(input_msk),
            "label": torch.tensor(label),
        }
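A short illustrative look at one sample, assuming the sampled titles, labels, and tokenizer from section 4 are already in scope:

dataset = Toutiao(sampled_titles, sampled_labels, tokenizer, seq_len=50)
sample = dataset[0]
print(sample["title_txt"])                                   # raw headline text
print(sample["title_ids"].shape, sample["input_msk"].shape)  # torch.Size([50]) torch.Size([50])
print(sample["label"])                                       # class index as a tensor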
4. Training
Data loading
- random_split divides the dataset into a training set and a validation set.
- DataLoader batches the data; the training loader shuffles, the validation loader does not.
Model training
- Defines the learning rate, the optimizer, and the loss function.
- In each epoch, runs the forward pass on the training set, computes the loss, backpropagates, and updates the parameters.
- Evaluates the model on the validation set and computes accuracy.
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm


if __name__ == "__main__":
    data_file = "/Users/azen/Desktop/llm/LLM-FullTime/dataset/text-classification/toutiao-text/toutiao_cat_data.txt"
    model_path = "/Users/azen/Desktop/llm/models/bert-base-chinese"
    labels = [100, 101, 102, 103, 104, 106, 107, 108, 109, 110, 112, 113, 114, 115, 116]
    LABELS_MAP = {k: v for v, k in enumerate(labels)}  # category code -> class index

    sampled_titles, sampled_labels = read_data(data_file, n=200)
    max_length = len(max(sampled_titles, key=len))
    print("Max length of sequence: {}".format(max_length))

    tokenizer = BertTokenizer.from_pretrained(model_path)
    seq_len = 50
    dataset = Toutiao(sampled_titles, sampled_labels, tokenizer, seq_len)
    train_size = int(0.9 * len(dataset))
    valid_size = len(dataset) - train_size
    trainset, validset = random_split(dataset, [train_size, valid_size])
    trainloader = DataLoader(trainset, batch_size=16, shuffle=True)
    validloader = DataLoader(validset, batch_size=8, shuffle=False)

    vocab_size = tokenizer.vocab_size
    d_model = 256
    dff = 4 * d_model
    dropout = 0.1
    num_heads = 8
    N_block = 2
    classes = len(set(labels))
    model = BertClassification(vocab_size, d_model, seq_len, dff, dropout, num_heads, N_block, classes)

    lr = 1e-3
    optim = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    epochs = 20

    for epoch in range(epochs):
        total_train_loss = 0
        model.train()
        for batch in tqdm(trainloader, desc="Training"):
            batch_title_txt = batch["title_txt"]
            batch_title_ids = batch["title_ids"]
            batch_input_msk = batch["input_msk"]
            batch_labels = batch["label"]

            output = model(batch_title_ids, batch_input_msk)
            train_loss = loss_fn(output, batch_labels)
            train_loss.backward()
            optim.step()
            optim.zero_grad()
            total_train_loss += train_loss.item()

        model.eval()
        total_valid_loss = 0
        correct = 0
        with torch.no_grad():
            for batch in tqdm(validloader, desc="Validating"):
                batch_title_txt = batch["title_txt"]
                batch_title_ids = batch["title_ids"]
                batch_input_msk = batch["input_msk"]
                batch_labels = batch["label"]

                output = model(batch_title_ids, batch_input_msk)
                correct += torch.sum(torch.argmax(output, dim=-1) == batch_labels).item()
                total_valid_loss += loss_fn(output, batch_labels).item()

        acc = correct / len(validloader.dataset)
        print("Epoch: {}, Train Loss: {:.4f}, Valid Loss: {:.4f}, Acc: {:.2f}%".format(
            epoch, total_train_loss / len(trainloader), total_valid_loss / len(validloader), acc * 100
        ))
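After training, the same tokenizer pipeline can be reused for inference on a new headline. A minimal sketch, assuming model, tokenizer, seq_len, and LABELS_MAP from the script above are in scope; the headline and the INDEX_TO_LABEL mapping are illustrative additions, not part of the original script:

INDEX_TO_LABEL = {v: k for k, v in LABELS_MAP.items()}  # class index -> original category code
model.eval()
title = "国产大飞机完成首次商业航班飞行"  # an arbitrary example headline
enc = tokenizer(title, padding="max_length", max_length=seq_len, truncation=True, return_tensors="pt")
with torch.no_grad():
    logits = model(enc["input_ids"], enc["attention_mask"])
pred = torch.argmax(logits, dim=-1).item()
print("Predicted category code:", INDEX_TO_LABEL[pred])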