
BERT - MLM and NSP

The code in this section implements the two main pretraining tasks of the BERT model: the masked language model (Masked Language Model, MLM) and next sentence prediction (Next Sentence Prediction, NSP).

1. The create_nsp_dataset function

This function generates the dataset for the NSP task.

def create_nsp_dataset(corpus):
    nsp_dataset = []
    for i in range(len(corpus)-1):
        next_sentence = corpus[i+1]

        rand_id = random.randint(0, len(corpus) - 1)
        while abs(rand_id - i) <= 1:
            rand_id = random.randint(0, len(corpus) - 1)
        
        negt_sentence = corpus[rand_id]
        nsp_dataset.append((corpus[i], next_sentence, 1))  # positive sample
        nsp_dataset.append((corpus[i], negt_sentence, 0))  # negative sample

    return nsp_dataset
  • Positive samples: corpus[i] and corpus[i+1] form a consecutive sentence pair and are labeled 1, indicating that they are adjacent sentences.

  • Negative samples: corpus[i] is paired with a randomly chosen sentence corpus[rand_id] and labeled 0, indicating that the two are not adjacent.

  • Random negative sampling: negatives are drawn uniformly at random, re-sampling whenever the index falls within one position of i, so the model must learn to distinguish adjacent from non-adjacent sentence pairs. A small usage sketch follows this list.

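A minimal usage sketch (the five-sentence toy corpus below is made up purely for illustration; in the full code the corpus comes from read_data):

# Assumes create_nsp_dataset (and its `import random`) from above is in scope.
toy_corpus = ["sentence A", "sentence B", "sentence C", "sentence D", "sentence E"]
pairs = create_nsp_dataset(toy_corpus)

print(len(pairs))  # 8 samples: (5 - 1) adjacent pairs, each paired with one random negative
for sent1, sent2, label in pairs[:4]:
    print(sent1, "|", sent2, "| label =", label)
# e.g. sentence A | sentence B | label = 1
#      sentence A | sentence D | label = 0   (the negative partner is chosen at random)
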
2. The BERTDataset class

This class inherits from torch.utils.data.Dataset and is used to load and process the data for the BERT pretraining tasks.

def __init__(self, nsp_dataset, tokenizer: BertTokenizer, max_length):
    self.nsp_dataset = nsp_dataset
    self.tokenizer = tokenizer
    self.max_length = max_length

    self.cls_id = tokenizer.cls_token_id
    self.sep_id = tokenizer.sep_token_id
    self.pad_id = tokenizer.pad_token_id
    self.mask_id = tokenizer.mask_token_id
  • nsp_dataset: the NSP dataset; each sample is a triple (sent1, sent2, nsp_label).

  • tokenizer: converts text into token indices (token IDs).

  • max_length: the maximum sequence length, used for padding and truncation.

  • Special tokens (their concrete values can be checked as in the sketch after this list):

    • self.cls_id: index of the [CLS] token.

    • self.sep_id: index of the [SEP] token.

    • self.pad_id: index of the [PAD] token.

    • self.mask_id: index of the [MASK] token.

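These ids are fixed by the tokenizer's vocabulary. A quick way to check them, assuming the standard bert-base-chinese vocabulary (where [PAD]=0, [CLS]=101, [SEP]=102, [MASK]=103):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")  # or a local copy, as in the full code
print(tokenizer.pad_token_id)   # 0
print(tokenizer.cls_token_id)   # 101
print(tokenizer.sep_token_id)   # 102
print(tokenizer.mask_token_id)  # 103
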
The __len__ method
def __len__(self):
    return len(self.nsp_dataset)
  • Returns the size of the dataset, i.e. the number of samples.

The __getitem__ method
def __getitem__(self, idx):
    sent1, sent2, nsp_label = self.nsp_dataset[idx]

    sent1_ids = self.tokenizer.encode(sent1, add_special_tokens=False)
    sent2_ids = self.tokenizer.encode(sent2, add_special_tokens=False)

    tok_ids = [self.cls_id] + sent1_ids + [self.sep_id] + sent2_ids + [self.sep_id]
    seg_ids = [0]*(len(sent1_ids)+2) + [1]*(len(sent2_ids) + 1)
    
    mlm_tok_ids, mlm_labels = self.build_mlm_dataset(tok_ids)

    mlm_tok_ids = self.pad_to_seq_len(mlm_tok_ids, 0)
    seg_ids = self.pad_to_seq_len(seg_ids, 2)
    mlm_labels = self.pad_to_seq_len(mlm_labels, -100)

    mask = (mlm_tok_ids != 0)

    return {
        "mlm_tok_ids": mlm_tok_ids,
        "seg_ids": seg_ids,
        "mask": mask,
        "mlm_labels": mlm_labels,
        "nsp_labels": torch.tensor(nsp_label)
    }
  • Sentence encoding

    • sent1_ids and sent2_ids are the token-index lists of the two sentences.

    • self.tokenizer.encode converts each sentence into token indices; add_special_tokens=False means no special tokens ([CLS], [SEP]) are added at this stage.

  • Building the input sequence (see the sketch after this list)

    • tok_ids: the two token-index lists are concatenated into one sequence, separated by [SEP], with [CLS] prepended at the front.

    • seg_ids: segment-embedding indices; positions belonging to the first sentence (including [CLS] and its [SEP]) get 0, positions belonging to the second sentence get 1.

  • MLM task

    • mlm_tok_ids and mlm_labels are produced by the build_mlm_dataset method and are used for the MLM task.

  • Padding and truncation

    • pad_to_seq_len pads or truncates mlm_tok_ids, seg_ids, and mlm_labels to max_length.

  • Masking

    • mask: a boolean mask marking which positions are valid (non-padding) input.

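A sketch of how tok_ids and seg_ids line up for one sample (the token ids below are made up; real ids come from tokenizer.encode(..., add_special_tokens=False)):

cls_id, sep_id = 101, 102            # bert-base-chinese special token ids

sent1_ids = [2769, 4263]             # hypothetical ids for a 2-token first sentence
sent2_ids = [1920, 2157, 1962]       # hypothetical ids for a 3-token second sentence

tok_ids = [cls_id] + sent1_ids + [sep_id] + sent2_ids + [sep_id]
seg_ids = [0] * (len(sent1_ids) + 2) + [1] * (len(sent2_ids) + 1)

print(tok_ids)  # [101, 2769, 4263, 102, 1920, 2157, 1962, 102]
print(seg_ids)  # [0,   0,    0,    0,   1,    1,    1,    1]
# Both lists have length len(sent1_ids) + len(sent2_ids) + 3 and stay aligned position by position.
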
The pad_to_seq_len method
def pad_to_seq_len(self, seq, pad_value):
    seq = seq[:self.max_length]
    pad_num = self.max_length - len(seq)
    return torch.tensor(seq + pad_num * [pad_value])
Design rationale
  • Truncate the sequence to max_length, then pad it with pad_value up to max_length, so every sample ends up with the same fixed length (see the example below).

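For example, with max_length = 6 (a value chosen only for this illustration), a short sequence is right-padded and a long one is truncated:

import torch

max_length = 6

def pad_to_seq_len(seq, pad_value):          # standalone version of the method above
    seq = seq[:max_length]                   # truncate if the sequence is too long
    pad_num = max_length - len(seq)          # number of padding positions needed
    return torch.tensor(seq + pad_num * [pad_value])

print(pad_to_seq_len([101, 7, 8, 102], 0))          # tensor([101,   7,   8, 102,   0,   0])
print(pad_to_seq_len([1, 2, 3, 4, 5, 6, 7, 8], 0))  # tensor([1, 2, 3, 4, 5, 6])
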
The build_mlm_dataset method
def build_mlm_dataset(self, tok_ids):
    mlm_tok_ids = tok_ids.copy()
    mlm_labels = [-100] * len(tok_ids)

    for i in range(len(tok_ids)):
        if tok_ids[i] not in [self.cls_id, self.sep_id, self.pad_id]:
            if random.random() < 0.15:          # select 15% of the non-special tokens
                mlm_labels[i] = tok_ids[i]      # remember the original token as the label

                p = random.random()             # a single draw decides the 80/10/10 split
                if p < 0.8:
                    mlm_tok_ids[i] = self.mask_id  # 80%: replace with [MASK]
                elif p < 0.9:
                    # 10%: replace with a random token (106 skips the special/unused
                    # ids at the front of the bert-base-chinese vocabulary)
                    mlm_tok_ids[i] = random.randint(106, self.tokenizer.vocab_size - 1)
                # remaining 10%: keep the original token unchanged
    return mlm_tok_ids, mlm_labels
  • MLM task

    • Each non-special token is selected with 15% probability; a selected token is replaced with [MASK] 80% of the time, with a random token 10% of the time, and kept unchanged 10% of the time (the single draw p in the code makes this split exact).

    • mlm_labels stores the original index of each selected token; unselected positions are set to -100, the value that PyTorch's cross-entropy loss ignores, so they do not contribute to the MLM loss (see the illustration below).

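The -100 convention matters because nn.CrossEntropyLoss ignores targets equal to -100 by default (ignore_index=-100), so only the roughly 15% of selected positions contribute to the MLM loss. A minimal illustration with made-up logits and labels:

import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()            # ignore_index defaults to -100

vocab_size = 10
logits = torch.randn(4, vocab_size)        # 4 positions over a 10-token toy vocabulary
labels = torch.tensor([-100, 3, -100, 7])  # only positions 1 and 3 were selected for MLM

print(loss_fn(logits, labels))             # loss averaged over the 2 non-ignored positions
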
Full BERT code (the parts discussed in this section are create_nsp_dataset and the BERTDataset class)

import re
import math
import torch
import random
import torch.nn as nn


from transformers import BertTokenizer
from torch.utils.data import Dataset, DataLoader

# nn.TransformerEncoderLayer


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.o_proj = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.shape
        Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        atten_scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            mask = mask.unsqueeze(1).unsqueeze(1)
            atten_scores = atten_scores.masked_fill(mask == 0, -1e9)

        atten_scores = torch.softmax(atten_scores, dim=-1)
        out = atten_scores @ V
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)

        return self.dropout(self.o_proj(out))


class FeedForward(nn.Module):
    def __init__(self, d_model, dff, dropout):
        super().__init__()
        self.W1 = nn.Linear(d_model, dff)
        self.act = nn.GELU()
        self.W2 = nn.Linear(dff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.W2(self.dropout(self.act(self.W1(x))))


class TransformerEncoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, dropout, dff):
        super().__init__()
        self.mha_block = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn_block = FeedForward(d_model, dff, dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        res1 = self.norm1(x + self.dropout1(self.mha_block(x, mask)))
        res2 = self.norm2(res1 + self.dropout2(self.ffn_block(res1)))
        return res2
    
class BertModel(nn.Module):
    def __init__(self, vocab_size, d_model, seq_len, N_blocks, num_heads, dropout, dff):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.seg_emb = nn.Embedding(3, d_model)
        self.pos_emb = nn.Embedding(seq_len, d_model)

        self.layers = nn.ModuleList([
            TransformerEncoderBlock(d_model, num_heads, dropout, dff)
            for _ in range(N_blocks)
        ])
        
        self.norm = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, seg_ids, mask):
        pos = torch.arange(x.shape[1])

        tok_emb = self.tok_emb(x)
        seg_emb = self.seg_emb(seg_ids)
        pos_emb = self.pos_emb(pos)

        x = tok_emb + seg_emb + pos_emb
        
        for layer in self.layers:
            x = layer(x, mask)

        x = self.norm(x)
        return x
    
class BERT(nn.Module):
    def __init__(self, vocab_size, d_model, seq_len, N_blocks, num_heads, dropout, dff):
        super().__init__()
        self.bert = BertModel(vocab_size, d_model, seq_len, N_blocks, num_heads, dropout, dff)
        
        self.mlm_head = nn.Linear(d_model, vocab_size)
        self.nsp_head = nn.Linear(d_model, 2)

    def forward(self, mlm_tok_ids, seg_ids, mask):
        bert_out = self.bert(mlm_tok_ids, seg_ids, mask)
        cls_token = bert_out[:, 0, :]
        mlm_logits = self.mlm_head(bert_out)
        nsp_logits = self.nsp_head(cls_token)
        return mlm_logits, nsp_logits

def read_data(file):

    with open(file, "r", encoding="utf-8") as f:
        data = f.read().strip().replace("\n", "")
    corpus = re.split(r'[。,“”:;!、]', data)
    corpus = [sentence for sentence in corpus if sentence.strip()]
    return corpus


def create_nsp_dataset(corpus):

    nsp_dataset = []
    for i in range(len(corpus)-1):
        next_sentence = corpus[i+1]

        rand_id = random.randint(0, len(corpus) - 1)
        while abs(rand_id - i) <= 1:
            rand_id = random.randint(0, len(corpus) - 1)
        
        negt_sentence = corpus[rand_id]
        nsp_dataset.append((corpus[i], next_sentence, 1)) # positive sample
        nsp_dataset.append((corpus[i], negt_sentence, 0)) # negative sample

    return nsp_dataset


class BERTDataset(Dataset):
    def __init__(self, nsp_dataset, tokenizer: BertTokenizer, max_length):
        self.nsp_dataset = nsp_dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.cls_id = tokenizer.cls_token_id
        self.sep_id = tokenizer.sep_token_id
        self.pad_id = tokenizer.pad_token_id
        self.mask_id = tokenizer.mask_token_id

    def __len__(self):
        return len(self.nsp_dataset)

    def __getitem__(self, idx):
        sent1, sent2, nsp_label = self.nsp_dataset[idx]

        sent1_ids = self.tokenizer.encode(sent1, add_special_tokens=False)
        sent2_ids = self.tokenizer.encode(sent2, add_special_tokens=False)

        tok_ids = [self.cls_id] + sent1_ids + [self.sep_id] + sent2_ids + [self.sep_id]
        seg_ids = [0]*(len(sent1_ids)+2) + [1]*(len(sent2_ids) + 1)
        
        mlm_tok_ids, mlm_labels = self.build_mlm_dataset(tok_ids)

        mlm_tok_ids = self.pad_to_seq_len(mlm_tok_ids, 0)
        seg_ids = self.pad_to_seq_len(seg_ids, 2)
        mlm_labels = self.pad_to_seq_len(mlm_labels, -100)

        mask = (mlm_tok_ids != 0)

        return {
            "mlm_tok_ids": mlm_tok_ids,
            "seg_ids": seg_ids,
            "mask": mask,
            "mlm_labels": mlm_labels,
            "nsp_labels": torch.tensor(nsp_label)
        }
    
    def pad_to_seq_len(self, seq, pad_value):
        seq = seq[:self.max_length]
        pad_num = self.max_length - len(seq)
        return torch.tensor(seq + pad_num * [pad_value])
    
    def build_mlm_dataset(self, tok_ids):
        mlm_tok_ids = tok_ids.copy()
        mlm_labels = [-100] * len(tok_ids)

        for i in range(len(tok_ids)):
            if tok_ids[i] not in [self.cls_id, self.sep_id, self.pad_id]:
                if random.random() < 0.15:          # select 15% of the non-special tokens
                    mlm_labels[i] = tok_ids[i]      # remember the original token as the label

                    p = random.random()             # a single draw decides the 80/10/10 split
                    if p < 0.8:
                        mlm_tok_ids[i] = self.mask_id  # 80%: replace with [MASK]
                    elif p < 0.9:
                        # 10%: replace with a random token (106 skips the special/unused
                        # ids at the front of the bert-base-chinese vocabulary)
                        mlm_tok_ids[i] = random.randint(106, self.tokenizer.vocab_size - 1)
                    # remaining 10%: keep the original token unchanged
        return mlm_tok_ids, mlm_labels



if __name__ == "__main__":

    data_file = "4.10-BERT/背影.txt"
    model_path = "/Users/azen/Desktop/llm/models/bert-base-chinese"
    tokenizer = BertTokenizer.from_pretrained(model_path)

    corpus = read_data(data_file)
    max_length = 25 # len(max(corpus, key=len))
    print("Max length of dataset: {}".format(max_length))
    nsp_dataset = create_nsp_dataset(corpus)

    trainset = BERTDataset(nsp_dataset, tokenizer, max_length)
    batch_size = 16
    trainloader = DataLoader(trainset, batch_size, shuffle=True)

    vocab_size = tokenizer.vocab_size
    d_model = 768
    N_blocks = 2
    num_heads = 12
    dropout = 0.1
    dff = 4*d_model
    model = BERT(vocab_size, d_model, max_length, N_blocks, num_heads, dropout, dff)
    
    lr = 1e-3
    optim = torch.optim.Adam(model.parameters(), lr=lr)

    loss_fn = nn.CrossEntropyLoss()
    epochs = 20

    for epoch in range(epochs):
        for batch in trainloader:
            batch_mlm_tok_ids = batch["mlm_tok_ids"]
            batch_seg_ids = batch["seg_ids"]
            batch_mask = batch["mask"]
            batch_mlm_labels = batch["mlm_labels"]
            batch_nsp_labels = batch["nsp_labels"]

            mlm_logits, nsp_logits = model(batch_mlm_tok_ids, batch_seg_ids, batch_mask)

            loss_mlm = loss_fn(mlm_logits.view(-1, vocab_size), batch_mlm_labels.view(-1))
            loss_nsp = loss_fn(nsp_logits, batch_nsp_labels)

            loss = loss_mlm + loss_nsp
            loss.backward()
            optim.step()
            optim.zero_grad()

        print("Epoch: {}, MLM Loss: {}, NSP Loss: {}".format(epoch, loss_mlm, loss_nsp))
    
        pass
    pass
