当前位置: 首页 > wzjs >正文

学习建设网站需要多久常州网油卷介绍

学习建设网站需要多久,常州网油卷介绍,郑州主城区,网站开发源代码知识产权归属本节代码将展示如何在预训练的BERT模型基础上进行微调,以适应特定的下游任务。 ⭐学习建议直接看文章最后的需复现代码,不懂得地方再回看 微调是自然语言处理中常见的方法,通过在预训练模型的基础上添加额外的层,并在特定任务的…

本节代码将展示如何在预训练的BERT模型基础上进行微调,以适应特定的下游任务。

学习建议直接看文章最后的需复现代码,不懂得地方再回看

微调是自然语言处理中常见的方法,通过在预训练模型的基础上添加额外的层,并在特定任务的数据集上进行训练,可以快速适应新的任务。以下是从模型微调的角度对代码的详细说明:
 

1. 加载预训练模型

self.bert = BertModel.from_pretrained(model_path)
  • 预训练模型:使用 transformers 库的 BertModel.from_pretrained 方法加载一个预训练的BERT模型。model_path 是预训练模型的路径或名称,例如 "bert-base-chinese"

  • 优势

    • 预训练模型已经在大规模语料上进行了训练,学习了通用的语言表示。

    • 微调可以利用这些预训练的参数,快速适应新的任务,通常只需要较少的数据和训练时间。

2. 添加任务特定的头

self.mlm_head = nn.Linear(d_model, vocab_size)
self.nsp_head = nn.Linear(d_model, 2)
  • MLM头mlm_head 是一个线性层,用于预测被掩盖的单词。输入是BERT模型的输出,输出是词汇表大小的预测概率。

  • NSP头nsp_head 是一个线性层,用于预测两个句子是否相邻。输入是BERT模型的 [CLS] 标记的输出,输出是二分类的概率。

3. 前向传播

def forward(self, mlm_tok_ids, seg_ids, mask):bert_out = self.bert(mlm_tok_ids, seg_ids, mask)output = bert_out.last_hidden_statecls_token = output[:, 0, :]mlm_logits = self.mlm_head(output)nsp_logits = self.nsp_head(cls_token)return mlm_logits, nsp_logits
  • BERT模型的输出

    • bert_out.last_hidden_state:BERT模型的输出,形状为 (batch_size, seq_len, d_model)

    • [CLS] 标记的输出:output[:, 0, :],用于NSP任务。

  • 任务特定的输出

    • mlm_logits:MLM任务的预测结果。

    • nsp_logits:NSP任务的预测结果。

4. 数据处理

class BERTDataset(Dataset):def __init__(self, nsp_dataset, tokenizer: BertTokenizer, max_length):self.nsp_dataset = nsp_datasetself.tokenizer = tokenizerself.max_length = max_lengthself.cls_id = tokenizer.cls_token_idself.sep_id = tokenizer.sep_token_idself.pad_id = tokenizer.pad_token_idself.mask_id = tokenizer.mask_token_iddef __getitem__(self, idx):sent1, sent2, nsp_label = self.nsp_dataset[idx]sent1_ids = self.tokenizer.encode(sent1, add_special_tokens=False)sent2_ids = self.tokenizer.encode(sent2, add_special_tokens=False)tok_ids = [self.cls_id] + sent1_ids + [self.sep_id] + sent2_ids + [self.sep_id]seg_ids = [0]*(len(sent1_ids)+2) + [1]*(len(sent2_ids) + 1)mlm_tok_ids, mlm_labels = self.build_mlm_dataset(tok_ids)mlm_tok_ids = self.pad_to_seq_len(mlm_tok_ids, 0)seg_ids = self.pad_to_seq_len(seg_ids, 2)mlm_labels = self.pad_to_seq_len(mlm_labels, -100)mask = (mlm_tok_ids != 0)return {"mlm_tok_ids": mlm_tok_ids,"seg_ids": seg_ids,"mask": torch.tensor(mask, dtype=torch.long),"mlm_labels": mlm_labels,"nsp_labels": torch.tensor(nsp_label)}
  • 数据处理

    • 将文本数据转换为词索引(tok_ids)。

    • 添加特殊标记([CLS][SEP])。

    • 生成段嵌入(seg_ids)。

    • 生成MLM任务的数据(mlm_tok_idsmlm_labels)。

    • 填充或截断序列到固定长度(max_length)。

  • 掩码:生成掩码,用于标记哪些位置是有效的输入(非填充部分)。

5. 训练过程

for epoch in range(epochs):for batch in tqdm(trainloader, desc="Training"):batch_mlm_tok_ids = batch["mlm_tok_ids"]batch_seg_ids = batch["seg_ids"]batch_mask = batch["mask"]batch_mlm_labels = batch["mlm_labels"]batch_nsp_labels = batch["nsp_labels"]mlm_logits, nsp_logits = model(batch_mlm_tok_ids, batch_seg_ids, batch_mask)loss_mlm = loss_fn(mlm_logits.view(-1, vocab_size), batch_mlm_labels.view(-1))loss_nsp = loss_fn(nsp_logits, batch_nsp_labels)loss = loss_mlm + loss_nsploss.backward()optim.step()optim.zero_grad()print("Epoch: {}, MLM Loss: {}, NSP Loss: {}".format(epoch, loss_mlm, loss_nsp))
  • 训练步骤

    • 前向传播:将输入数据通过模型,得到MLM和NSP任务的预测结果。

    • 计算损失:分别计算MLM和NSP任务的损失。

    • 反向传播:计算梯度并更新模型参数。

    • 优化器:使用Adam优化器,学习率设置为 1e-3

  • 进度条:使用 tqdm 显示训练进度,使训练过程更加直观。

6. 微调的优势

  • 快速适应新任务:预训练模型已经学习了通用的语言表示,微调可以快速适应新的任务,通常只需要较少的数据和训练时间。

  • 节省计算资源:从头训练BERT模型需要大量的计算资源和时间,而微调只需要在预训练模型的基础上进行少量的训练。

  • 更好的性能:预训练模型在大规模数据上进行了训练,通常具有更好的性能。微调可以进一步提升模型在特定任务上的表现。

需复现代码


import re
import math
import torch
import random
import torch.nn as nnfrom tqdm import tqdm
from transformers import BertTokenizer, BertModel
from torch.utils.data import Dataset, DataLoaderclass BERT(nn.Module):def __init__(self, vocab_size, d_model, seq_len, N_blocks, num_heads, dropout, dff):super().__init__()self.bert = BertModel.from_pretrained(model_path)self.mlm_head = nn.Linear(d_model, vocab_size)self.nsp_head = nn.Linear(d_model, 2)def forward(self, mlm_tok_ids, seg_ids, mask):bert_out = self.bert(mlm_tok_ids, seg_ids, mask)output = bert_out.last_hidden_statecls_token = output[:, 0, :]mlm_logits = self.mlm_head(output)nsp_logits = self.nsp_head(cls_token)return mlm_logits, nsp_logitsdef read_data(file):with open(file, "r", encoding="utf-8") as f:data = f.read().strip().replace("\n", "")corpus = re.split(r'[。,“”:;!、]', data)corpus = [sentence for sentence in corpus if sentence.strip()]return corpusdef create_nsp_dataset(corpus):nsp_dataset = []for i in range(len(corpus)-1):next_sentence = corpus[i+1]rand_id = random.randint(0, len(corpus) - 1)while abs(rand_id - i) <= 1:rand_id = random.randint(0, len(corpus) - 1)negt_sentence = corpus[rand_id]nsp_dataset.append((corpus[i], next_sentence, 1)) # 正样本nsp_dataset.append((corpus[i], negt_sentence, 0)) # 负样本return nsp_datasetclass BERTDataset(Dataset):def __init__(self, nsp_dataset, tokenizer: BertTokenizer, max_length):self.nsp_dataset = nsp_datasetself.tokenizer = tokenizerself.max_length = max_lengthself.cls_id = tokenizer.cls_token_idself.sep_id = tokenizer.sep_token_idself.pad_id = tokenizer.pad_token_idself.mask_id = tokenizer.mask_token_iddef __len__(self):return len(self.nsp_dataset)def __getitem__(self, idx):sent1, sent2, nsp_label = self.nsp_dataset[idx]sent1_ids = self.tokenizer.encode(sent1, add_special_tokens=False)sent2_ids = self.tokenizer.encode(sent2, add_special_tokens=False)tok_ids = [self.cls_id] + sent1_ids + [self.sep_id] + sent2_ids + [self.sep_id]seg_ids = [0]*(len(sent1_ids)+2) + [1]*(len(sent2_ids) + 1)mlm_tok_ids, mlm_labels = self.build_mlm_dataset(tok_ids)mlm_tok_ids = self.pad_to_seq_len(mlm_tok_ids, 0)seg_ids = self.pad_to_seq_len(seg_ids, 2)mlm_labels = self.pad_to_seq_len(mlm_labels, -100)mask = (mlm_tok_ids != 0)return {"mlm_tok_ids": mlm_tok_ids,"seg_ids": seg_ids,"mask": torch.tensor(mask, dtype=torch.long),"mlm_labels": mlm_labels,"nsp_labels": torch.tensor(nsp_label)}def pad_to_seq_len(self, seq, pad_value):seq = seq[:self.max_length]pad_num = self.max_length - len(seq)return torch.tensor(seq + pad_num * [pad_value], dtype=torch.long)def build_mlm_dataset(self, tok_ids):mlm_tok_ids = tok_ids.copy()mlm_labels = [-100] * len(tok_ids)for i in range(len(tok_ids)):if tok_ids[i] not in [self.cls_id, self.sep_id, self.pad_id]:if random.random() < 0.15:mlm_labels[i] = tok_ids[i]if random.random() < 0.8:mlm_tok_ids[i] = self.mask_idelif random.random() < 0.9:mlm_tok_ids[i] = random.randint(106, self.tokenizer.vocab_size - 1)return mlm_tok_ids, mlm_labelsif __name__ == "__main__":data_file = "4.10-BERT/背影.txt"model_path = "/Users/azen/Desktop/llm/models/bert-base-chinese"tokenizer = BertTokenizer.from_pretrained(model_path)corpus = read_data(data_file)max_length = 25 # len(max(corpus, key=len))print("Max length of dataset: {}".format(max_length))nsp_dataset = create_nsp_dataset(corpus)trainset = BERTDataset(nsp_dataset, tokenizer, max_length)batch_size = 16trainloader = DataLoader(trainset, batch_size, shuffle=True)vocab_size = tokenizer.vocab_sized_model = 768N_blocks = 2num_heads = 12dropout = 0.1dff = 4*d_modelmodel = BERT(vocab_size, d_model, max_length, N_blocks, num_heads, dropout, dff)lr = 1e-3optim = torch.optim.Adam(model.parameters(), lr=lr)loss_fn = nn.CrossEntropyLoss()epochs = 20for epoch in range(epochs):for batch in tqdm(trainloader, desc = "Training"):batch_mlm_tok_ids = batch["mlm_tok_ids"]batch_seg_ids = batch["seg_ids"]batch_mask = batch["mask"]batch_mlm_labels = batch["mlm_labels"]batch_nsp_labels = batch["nsp_labels"]mlm_logits, nsp_logits = model(batch_mlm_tok_ids, batch_seg_ids, batch_mask)loss_mlm = loss_fn(mlm_logits.view(-1, vocab_size), batch_mlm_labels.view(-1))loss_nsp = loss_fn(nsp_logits, batch_nsp_labels)loss = loss_mlm + loss_nsploss.backward()optim.step()optim.zero_grad()print("Epoch: {}, MLM Loss: {}, NSP Loss: {}".format(epoch, loss_mlm, loss_nsp))passpass


文章转载自:

http://or7nqICN.crxdn.cn
http://QGAv86rT.crxdn.cn
http://DytDaPbA.crxdn.cn
http://3hEtk3Vr.crxdn.cn
http://EtlWnS3B.crxdn.cn
http://HVS5I91J.crxdn.cn
http://V8ga5w2l.crxdn.cn
http://dyR5iyXE.crxdn.cn
http://QTgdjtsU.crxdn.cn
http://wLS7uyth.crxdn.cn
http://xaT76wTu.crxdn.cn
http://j3yDzR3c.crxdn.cn
http://UwSe74DA.crxdn.cn
http://S54AdTqg.crxdn.cn
http://fQht7poo.crxdn.cn
http://7gzQVasA.crxdn.cn
http://9BRFVSWn.crxdn.cn
http://vp9YWSfE.crxdn.cn
http://teRXyAWp.crxdn.cn
http://pfU1GqUq.crxdn.cn
http://IZ5oHtmv.crxdn.cn
http://q6fNOOoC.crxdn.cn
http://RP71HMj6.crxdn.cn
http://YDHrqre3.crxdn.cn
http://SQ3Ec5Xd.crxdn.cn
http://NtNZaTRj.crxdn.cn
http://PxdLQSod.crxdn.cn
http://TeSIx4U9.crxdn.cn
http://dba2XYt5.crxdn.cn
http://0M4Ou6FV.crxdn.cn
http://www.dtcms.com/wzjs/629334.html

相关文章:

  • 网站备案的要求肖云路那有做网站公司
  • 网站开发技术概述用asp做的网站有哪些
  • 免费化妆品网站模板下载重庆专业网站推广方案
  • 上海襄阳网站建设九龙坡区发布
  • 网站怎么加链接想开发一个旧物交易网站应该怎么做
  • 济宁企业网站建设嘉兴网站建设企业网站制作
  • 网站前台用什么做北京网站建设推广服务信息
  • 织梦如何做几种语言的网站高效的网站建设
  • 成都网站制作方案汝州网站建设汝州
  • 微信商城与网站一体门户网站界面设计模板下载
  • 宁波网站关键词排名提升辽宁省建设厅网站
  • 一些js特效的网站推荐安徽省城乡和住房建设厅网站
  • 免费做网站软件下载青岛推广软件
  • 网站网站建设专业建筑人才网筑才网
  • 哈尔滨建站模板商业网站的相关内容
  • 网站集群建设方案怎么创建小程序商店
  • 评论凡科网站建设怎么样南京网站设计哪家公司好
  • 太原注册公司在哪个网站申请wordpress添加视频解析
  • 维护网站是什么意思网站开发培训价格
  • 网络公司经营范围网站建设擼擼擼做最好的导航网站
  • 阿克苏网站建设福彩网网站建设方案
  • 中国企业500强公司排名单页网站seo怎么做
  • 鄞州区住房和城乡建设局网站南昌网站设计企业
  • 零陵网站建设wordpress用户个人资料
  • 微信网站怎么做下载附件物流软件开发工具
  • 建设营销型网站流程图国家机构网站建设
  • 海南网站建设案例织梦cms一键更新网站无法使用
  • 电子商务网站建设财务预算网站被黑 原因
  • 企业门户网站建设机构学院网站建设项目概述
  • 网站制作比较好的制作公司网站制作定制18