当前位置: 首页 > news >正文

BERT - 今日头条新闻分类任务实战

1. 自定义模型组件 

MultiHeadAttention 类
  • 实现了多头自注意力机制。

  • 通过将输入分割成多个“头”,从不同角度学习输入数据的特征。

  • 注意力分数计算后应用了缩放点积注意力,并支持掩码操作。

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.o_proj = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.shape
        Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        atten_scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            mask = mask.unsqueeze(1).unsqueeze(1)
            atten_scores = atten_scores.masked_fill(mask == 0, -1e9)

        atten_scores = torch.softmax(atten_scores, dim=-1)
        out = atten_scores @ V
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)

        return self.dropout(self.o_proj(out))
FeedForward 类
  • 实现了Transformer中的前馈网络(Feed-Forward Network, FFN)。

class FeedForward(nn.Module):
    def __init__(self, d_model, dff, dropout):
        super().__init__()
        self.W1 = nn.Linear(d_model, dff)
        self.act = nn.GELU()
        self.W2 = nn.Linear(dff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.W2(self.dropout(self.act(self.W1(x))))
TransformerEncoderBlock 类

class TransformerEncoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, dropout, dff):
        super().__init__()
        self.mha_block = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn_block = FeedForward(d_model, dff, dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        res1 = self.norm1(x + self.dropout1(self.mha_block(x, mask)))
        res2 = self.norm2(res1 + self.dropout2(self.ffn_block(res1)))
        return res2

2. BERT模型

BERT 类
  • 实现了一个简化版的BERT模型,包括词嵌入、位置嵌入和多个Transformer编码器块。

  • 没有实现段嵌入(Segment Embedding),因为文本分类任务通常不需要区分句子对。

class BERT(nn.Module):
    def __init__(self, vocab_size, d_model, seq_len, dff, dropout, num_heads, N_block):
        super().__init__()
        self.seq_len = seq_len
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(seq_len, d_model)

        self.layers = nn.ModuleList([
            TransformerEncoderBlock(d_model, num_heads, dropout, dff)
            for _ in range(N_block)
        ])

    def forward(self, x, mask):
        
        position = torch.arange(0, self.seq_len)
        tok_emb = self.tok_emb(x)
        pos_emb = self.pos_emb(position)
        x = tok_emb + pos_emb

        for layer in self.layers:
            x = layer(x, mask)

        return x
BertClassification 类
  • 在BERT模型的基础上添加了一个分类头(self.cls),用于文本分类任务。

  • 分类头是一个线性层,将BERT模型的输出映射到类别数量。

class BertClassification(nn.Module):
    def __init__(self, vocab_size, d_model, seq_len, dff, dropout, num_heads, N_block, classes):
        super().__init__()
        self.bert = BERT(vocab_size, d_model, seq_len, dff, dropout, num_heads, N_block)
        self.cls = nn.Linear(d_model, classes)


    def forward(self, x, mask):
        bert_out = self.bert(x, mask) # (batch_size, seq_len, d_model)
        cls_repr = bert_out[:, 0, :] # 
        return self.cls(cls_repr)

3. 数据处理

read_data 函数
  • 读取数据文件,解析出标题和标签。

  • 使用 collections.defaultdict 按类别组织数据。

  • 随机采样每个类别的标题,确保数据平衡。

def read_data(file, n=200):

    with open(file, "r", encoding="utf-8") as f:
        data = f.read().strip().split("\n")

    category_data = collections.defaultdict(list)

    for line in data:
        segs = line.split("_!_")
        label = int(segs[1])
        title = segs[3]

        category_data[LABELS_MAP[label]].append(title)

    sampled_titles = []
    sampled_labels = []
    for label, titles in category_data.items():
        samples = random.sample(titles, min(n, len(titles)))
        sampled_labels.extend([label]*len(samples))
        sampled_titles.extend(samples)
        
    return sampled_titles, sampled_labels
Toutiao 类
  • 继承自 torch.utils.data.Dataset,用于加载和处理数据。

  • 使用 BertTokenizer 将文本转换为BERT模型可以处理的格式。

  • 返回的每个样本包括标题文本、标题ID、注意力掩码和标签。

class Toutiao(Dataset):
    def __init__(self, titles, labels, tokenizer: BertTokenizer, seq_len):
        self.titles = titles
        self.labels = labels

        self.seq_len = seq_len
        self.tokenzier = tokenizer

    def __len__(self):
        return len(self.titles)
    
    def __getitem__(self, idx):
        title_txt = self.titles[idx]
        label = self.labels[idx]

        tokenizer_out = self.tokenzier(
            title_txt,
            padding="max_length",
            max_length=self.seq_len,
            truncation=True
        )

        title_ids = tokenizer_out["input_ids"]
        input_msk = tokenizer_out["attention_mask"]
        return {
            "title_txt": title_txt,
            "title_ids": torch.tensor(title_ids),
            "input_msk": torch.tensor(input_msk),
            "label": torch.tensor(label),
        }

4. 训练过程

数据加载
  • 使用 random_split 将数据集分为训练集和验证集。

  • 使用 DataLoader 加载数据,训练集使用随机打乱,验证集不打乱。

模型训练
  • 定义了学习率、优化器和损失函数。

  • 在每个epoch中,对训练集进行前向传播、计算损失、反向传播和优化。

  • 在验证集上评估模型性能,计算准确率。


if __name__ == "__main__":
    data_file = "/Users/azen/Desktop/llm/LLM-FullTime/dataset/text-classification/toutiao-text/toutiao_cat_data.txt"
    model_path = "/Users/azen/Desktop/llm/models/bert-base-chinese"
    labels = [100, 101, 102, 103, 104, 106, 107, 108, 109, 110, 112, 113, 114, 115, 116]
    LABELS_MAP = {k:v for v, k in enumerate(labels)}
    sampled_titles, sampled_labels = read_data(data_file, n=200)

    max_length = len(max(sampled_titles, key=len))
    print("Max length of sequence: {}".format(max_length))

    tokenizer = BertTokenizer.from_pretrained(model_path)
    seq_len = 50
    dataset = Toutiao(sampled_titles, sampled_labels, tokenizer, seq_len)

    train_size = int(0.9*len(dataset))
    valid_size = len(dataset) - train_size
    trainset, validset = random_split(dataset, [train_size, valid_size])
    trainloader = DataLoader(trainset, batch_size=16, shuffle=True)
    validloader = DataLoader(validset, batch_size=8, shuffle=False)

    vocab_size = tokenizer.vocab_size
    d_model = 256
    dff = 4*d_model
    dropout = 0.1
    num_heads = 8
    N_block = 2
    classes = len(set(labels))
    model = BertClassification(vocab_size, d_model, seq_len, dff, dropout, num_heads, N_block, classes)

    lr = 1e-3
    optim = torch.optim.Adam(model.parameters(), lr=lr)

    loss_fn = nn.CrossEntropyLoss()

    epochs = 20

    for epoch in range(epochs):
        total_train_loss = 0
        model.train()
        for batch in tqdm(trainloader, desc="Training"):
            batch_title_txt = batch["title_txt"]
            batch_title_ids = batch["title_ids"]
            batch_input_msk = batch["input_msk"]
            batch_labels = batch["label"]

            output = model.forward(batch_title_ids, batch_input_msk)
            train_loss = loss_fn(output, batch_labels)
            train_loss.backward()
            optim.step()
            optim.zero_grad()
            total_train_loss += train_loss

        model.eval()
        total_valid_loss = 0
        correct = 0
        with torch.no_grad():
            for batch in tqdm(validloader, desc="Validating"):
                batch_title_txt = batch["title_txt"]
                batch_title_ids = batch["title_ids"]
                batch_input_msk = batch["input_msk"]
                batch_labels = batch["label"]

                output = model.forward(batch_title_ids, batch_input_msk)
                correct += torch.sum(torch.argmax(output, dim=-1) == batch_labels).item()
                valid_loss = loss_fn(output, batch_labels)
            acc = correct / len(validloader.dataset)
        
        print("Epoch: {}, Train Loss: {}, Valid Loss: {}, Acc: {:.2f}%".format(
            epoch, total_train_loss/len(trainloader), valid_loss, acc*100
            ))

    pass

相关文章:

  • 软件测试岗位:IT行业中的质量守护者
  • AI预测3D新模型百十个定位预测+胆码预测+去和尾2025年4月11日第49弹
  • STM32+EC600E 4G模块 与华为云平台通信
  • 多因素认证
  • P1439 【模板】最长公共子序列
  • STM32 模块化开发指南 · 第 4 篇 用状态机管理 BLE 应用逻辑:分层解耦的实践方式
  • poi-tl
  • 全能格式转换器v16.3.0.159绿色便携版
  • 合并区间、插入区间~
  • 【LLM基础】Megatron-LM相关知识(主要是张量并行机制)
  • 无线通信网
  • leetcode 322. Coin Change
  • 谷歌25年春季新课:15小时速成机器学习
  • 【2025年认证杯数学中国数学建模网络挑战赛】C题 数据预处理与问题一二求解
  • 如何使用CAPL解析YAML文件?
  • Python爬虫第13节-解析库pyquery 的使用
  • C++ | 时间日期
  • WEB 前端学 JAVA(一)
  • Qwen2.5-7B-Instruct FastApi 部署调用教程
  • YOLO学习笔记 | YOLOv8 全流程训练步骤详解(2025年4月更新)
  • 优质的网站建设推广/百度推广优化公司
  • 常州公司做网站的流程/网络营销公司网络推广
  • 美食林商业供应链管理系统登录/电脑清理优化大师
  • 苏州住房与城乡建设部网站/双11各大电商平台销售数据
  • 做网站seo的步骤/怎么让某个关键词排名上去
  • wordpress仿站全套/搜索百度网址网页