当前位置: 首页 > news >正文

【BERT_Pretrain】Wikipedia_Bookcorpus数据预处理(二)

上一篇介绍了wikipedia和bookcopus数据集,这一篇主要讲一下如何预处理数据,使其可以用于BERT的Pretrain任务MLM和NSP。

MLM是类似于完形填空的任务,NSP是判断两个句子是否连着。因此数据预处理的方式不同。首先,拿到原始数据集,两个数据集都是段落,因此要分成单句。然后,有了单句针对不同任务进行预处理。BERT原文的原始做法是。将数据集按照NSP任务预处理,再进行mask,得到MLM任务的数据。

段落变单句

利用nltk包分开句子。(代码接上一篇)

# pargraph->sentences
import nltk
from nltk.tokenize import sent_tokenizetry:sent_tokenize("Test sentence.")  # 尝试使用以触发错误
except LookupError:nltk.download('punkt')  # 自动下载所需资源nltk.download('punkt_tab')#'''
#test
#'''
text = "This is the first sentence. This is the second sentence."# 直接使用
sentences = sent_tokenize(text)
print(sentences)
print('success!')

将全部数据段落分成句子。

def preprocess_text(text):sentences = sent_tokenize(text)sentences = [s.strip() for s in sentences if len(s.strip()) > 10] #只有去除空白后长度超过 10 的句子才会被保留return sentencesdef preprocess_examples(examples):processed_texts = []for text in examples["text"]:processed_texts.extend(preprocess_text(text))return {"sentences": processed_texts}processed_dataset = dataset.map(preprocess_examples,batched=True,	#批量处理remove_columns=dataset.column_names,	#相当于keynum_proc=4  #CPU并行处理
)

打印出一些处理后数据的信息。

print(type(processed_dataset))
print(processed_dataset.num_rows)
print(processed_dataset.column_names)
print(processed_dataset.shape)

NSP+MLM数据预处理

from transformers import BertTokenizer
import random# 加载分词器
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')def create_nsp_mlm_examples(examples):input_ids = []token_type_ids = []attention_mask = []next_sentence_label = []masked_lm_labels = []sentences = examples["sentences"]#NSPfor i in range(len(sentences) - 1):if random.random() > 0.5:text_a = sentences[i]text_b = sentences[i + 1]label = 1else:text_a = sentences[i]text_b = random.choice(sentences)while text_b == sentences[i + 1]:  # 防止随机到了真实下一句text_b = random.choice(sentences)label = 0# 用分词器将两个句子拼接encoded = tokenizer(text_a,text_b,max_length=512,truncation=True,padding=False, # ‘max_length’	# 这里选的不带padding,可以减少一部分内存占用return_tensors='pt')#MLMinput_id_list = encoded['input_ids'][0].tolist()mlm_labels = [0] * len(input_id_list)# 由于要添加mask,先指定没有添加mask的时候是0(huggingface这里设置的是-100)special_token_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id}candidate_indices = [i for i, t_id in enumerate(input_id_list) if t_id not in special_token_ids]num_to_mask = max(1, int(len(candidate_indices) * 0.15))#句子中15%的随机mask_indices = random.sample(candidate_indices, min(512, num_to_mask))for idx in mask_indices:original_token = input_id_list[idx]mlm_labels[idx] = original_tokenprob = random.random()if prob < 0.8:#80%为maskinput_id_list[idx] = tokenizer.mask_token_idelif 0.8 <= prob < 0.9:#10%为随机input_id_list[idx] = random.randint(0, tokenizer.vocab_size - 1)# 10%保留原 token,不变input_ids.append(input_id_list)token_type_ids.append(encoded['token_type_ids'][0].tolist())attention_mask.append(encoded['attention_mask'][0].tolist())next_sentence_label.append(label)masked_lm_labels.append(mlm_labels)return {"input_ids": input_ids,"token_type_ids": token_type_ids,"attention_mask": attention_mask,"next_sentence_labels": next_sentence_label,"masked_lm_labels": masked_lm_labels}
# 将整个数据集处理一遍(耗时很长)
final_dataset = processed_dataset.map(create_nsp_mlm_examples,batched=True,remove_columns=processed_dataset.column_names,num_proc=4
)
#保存
final_dataset.save_to_disk("data_processed_nsp_mlm4gb", max_shard_size="4GB")

我这里的逻辑是,先预处理数据,再保存预处理数据直接用于训练,但是也可以一边训练一边预处理数据(尤其是你的磁盘大小不够保存额外数据集的情况)。

注意,上文padding由于选择了False,因此在用Dataloader包裹这个数据会遇到问题,因为torch.stack不能连接形状不一样的tensor,因此Dataloader里面的collate_fn还需要重写。(下一章介绍)

http://www.dtcms.com/a/264803.html

相关文章:

  • Electron 快速上手
  • vscode vim插件示例json意义
  • C++ 第四阶段 文件IO - 第一节:ifstream/ofstream操作
  • JavaScript---查询数组符合条件的元素
  • 解决 npm install canvas@2.11.2 失败的问题
  • 【公司环境下发布个人NPM包完整教程】
  • SPI、I2C和UART三种串行通信协议的--------简单总结
  • NLP:文本张量表示方法
  • 【安全工具】SQLMap 使用详解:从基础到高级技巧
  • 【字节跳动】数据挖掘面试题0001:打车场景下POI与ODR空间关联查询
  • C++实现状态机
  • 20250703|Leetcodehot100之739【】今天计划
  • Linux环境下使用 C++ 与 OpenCV 实现 ONNX 分类模型推理
  • 洛谷P2119 [NOIP 2016 普及组] 魔法阵【题解】【前缀和优化】
  • Java 大视界 -- Java 大数据在智能医疗健康管理中的慢性病风险预测与个性化干预(330)
  • Javaee 多线程 --进程和线程之间的区别和联系
  • nvm:NodeJs版本管理工具下载安装与使用教程
  • macOS挂载iOS应用沙盒文件夹
  • 飞算 JavaAI 智控引擎:全链路开发自动化新图景
  • 【字节跳动】数据挖掘面试题0003:有一个文件,每一行是一个数字,如何用 MapReduce 进行排序和求每个用户每个页面停留时间
  • 橡胶硬度计在不同领域中的应用
  • mybatis考试
  • 无人机一机多控技术的核心要点
  • 亿级物联网MQTT集群:OpenResty深度优化实践
  • Docker for Windows 设置国内镜像源教程
  • 基于spark的航班价格分析预测及可视化
  • v3 中的storeToRefs
  • AWS WebRTC:根据viewer端拉流日志推算视频帧率和音频帧率
  • uniapp实现图片预览,懒加载
  • 数据分类分级系统的建设思路