当前位置: 首页 > news >正文

transformers基础Data Collator

Data Collator

  • 一、基本介绍
  • 二、 transformers.default_data_collator
  • 三、 DefaultDataCollator
  • 四、 DataCollatorWithPadding
  • 五、 DataCollatorForTokenClassification
  • 六、 DataCollatorForSeq2Seq
  • 七、 DataCollatorForLanguageModeling
  • 八、 总结

一、基本介绍

DataCollator是huggingface提供的transformers库中的数据整理器,主要用来将数据集中的数据处理成batch的形式。

二、 transformers.default_data_collator

default_data_collator是transformers中最基础且通用的数据整理器,他是一个函数,它知识简单的将数据打包成batch,不会进行任何的padding和truncation。

features = [{"input_ids": [101, 2345, 6789, 102], "label": 1},{"input_ids": [101, 3456, 7890, 4321, 102], "label": 0}
]

像这样的数据输入到default_data_collator中就会报错,因为input_ids的长度不一致,而default_data_collator又不会进行padding,就会报错。

使用default_data_collator的情况是所有样本都已经被padding到相同长度,不需要进行额外处理(如masking)等,仅仅需要打包成batch。

导入包

from transformers import default_data_collator

参数

default_data_collator(features: list, return_tensors: str = "pt")
  • features:包含多个样本的列表,每一个样本是一个dict
  • return_tensors:是返回数据的格式,pt代表返回pytorch的tensor张量。

三、 DefaultDataCollator

DefaultDataCollator是类形式的的数据整理器,它和上面default_data_collator的功能相同,但是是以类的形式实现的,它也不会进行padding处理,需要输入的数据长度是一样的。

由于不会进行默认padding,所以基本很少使用。

导入包

from transformer import DefaultDataCollator

参数

DefaultDataCollator(return_tensors='pt')

使用

features = [{"input_ids": [101, 2345, 6789, 2345, 102], "label": 1},{"input_ids": [101, 3456, 7890, 4321, 102], "label": 0}
]
from transformers import DefaultDataCollatorcollator = DefaultDataCollator(return_tensors='pt')  # 指定返回 PyTorch 张量
batch = collator(features)print(batch)

输出

{'input_ids': tensor([[101, 2345, 6789, 2345, 102],[101, 3456, 7890, 4321, 102]]),'labels': tensor([1, 0])
}

四、 DataCollatorWithPadding

这是一个动态padding的数据整理器类,它会使一个batch中所有样本的长度相同。

class transformers.DataCollatorWithPadding(tokenizer: PreTrainedTokenizerBase,padding: Union[bool, str, PaddingStrategy] = True,max_length: Optional[int] = None,pad_to_multiple_of: Optional[int] = None,return_tensors: str = 'pt'
)

参数

  • tokenizer:传入一个tokenizer。
  • padding:有三种参数可以选择bool为True的话默认padding到batch中的最大长度,‘max_length’:填充到指定的 max_length,False不做padding。
  • max_length:如果 padding=‘max_length’,则所有序列都会被填充或截断到该长度。
    如果没有设置,则使用模型允许的最大输入长度。
  • pad_to_multiple_of:将序列长度向上填充为某个整数的倍数,比如 8、64 等。
  • return_tensors:“pt”:返回 PyTorch 张量、“tf”:返回 TensorFlow 张量、“np”:返回 NumPy 数组。

五、 DataCollatorForTokenClassification

适用于token级别的序列标注任务,如命名实体识别,继承了DataCollatorWithPadding的功能,同时能够对齐标签的长度。
参数
在这里插入图片描述

六、 DataCollatorForSeq2Seq

专门为Seq2Seq任务设置的数据整理器,适用于翻译、摘要生成等encoder-decoder等任务。

class transformers.DataCollatorForSeq2Seq(tokenizer: PreTrainedTokenizerBase,model: Optional[Any] = None,padding: Union[bool, str, PaddingStrategy] = True,max_length: Optional[int] = None,pad_to_multiple_of: Optional[int] = None,label_pad_token_id: int = -100,return_tensors: str = 'pt'
)

参数
在这里插入图片描述

features = [{"input_ids": [101, 2345, 6789, 102],"labels": [2, 3, 4, 5, 6]},{"input_ids": [101, 3456, 7890, 4321, 102],"labels": [3, 4, 5, 6, 7, 8]}
]
from transformers import DataCollatorForSeq2Seqcollator = DataCollatorForSeq2Seq(tokenizer=tokenizer,model=model,  # 可选,如果你用了像 BART 或 T5 这样的模型padding="longest",label_pad_token_id=-100,return_tensors="pt"
)batch = collator(features)
print(batch.keys())
# 输出:dict_keys(['input_ids', 'attention_mask', 'labels', 'decoder_input_ids'])

输出结果:

{'input_ids': tensor([[101, 2345, 6789, 102,   0],[101, 3456, 7890, 4321, 102]]),'attention_mask': tensor([[1, 1, 1, 1, 0],[1, 1, 1, 1, 1]]),'labels': tensor([[-100, 2, 3, 4, 5],     # 第一个样本只有4个真实 label[3, 4, 5, 6, 7, 8]]),  # 第二个样本有6个 label'decoder_input_ids': tensor([[ 2,  3,  4,  5,  0],  # decoder_input_ids 自动构造[ 3,  4,  5,  6,  7]])
}

七、 DataCollatorForLanguageModeling

这个数据整理器适用于类似Bert的掩码语言模型。

class transformers.DataCollatorForLanguageModeling(tokenizer: PreTrainedTokenizerBase,mlm: bool = True,mlm_probability: float = 0.15,mask_replace_prob: float = 0.8,random_replace_prob: float = 0.1,pad_to_multiple_of: Optional[int] = None,tf_experimental_compile: bool = False,return_tensors: str = 'pt',seed: Optional[int] = None
)

在这里插入图片描述

八、 总结

transformers中的datacollator主要用于对数据进行padding、truncation并打包成一个batch。上面是常用的一些datacollator还有一些不太常用的datacollator可以参考[huggingface官方文档]。(https://huggingface.co/docs/transformers/main_classes/data_collator#transformers.DataCollatorWithPadding)

http://www.dtcms.com/a/283199.html

相关文章:

  • 教程:如何快速查询 A 股实时 K线和5档盘口
  • 今日行情明日机会——20250716
  • Redis深度解析:从缓存到分布式系统的核心引擎
  • 用python实现自动化布尔盲注
  • pytest--1--pytest-mock常用的方法
  • 代码随想录day36dp4
  • 震坤行获取商品SKU操作详解
  • 16路串口光纤通信FPGA项目实现指南
  • Kotlin获取集合中的元素操作
  • Java与Vue精心打造资产设备管理系统,提供源码,适配移动端与后台管理,助力企业高效掌控资产动态,提升管理效能
  • 【Java】JUC并发(synchronized进阶、ReentrantLock可重入锁)
  • 二重循环:输入行数,打印直角三角形和倒直角三角形
  • Java后端开发核心笔记:分层架构、注解与面向对象精髓
  • 基于Android的旅游计划App
  • Web基础 -MYSQL
  • 冷库耗电高的原因,冷链运营者的降本增效的方法
  • LVS四种模式及部署NAT、DR模式集群
  • CD53.【C++ Dev】模拟实现优先级队列(含仿函数)
  • 【计算机网络】数据通讯第二章 - 应用层
  • 深度学习之反向传播
  • 【迭代】PDF绘本录音播放,点读笔方案调研和初步尝试
  • leetcode 725 分割链表
  • 微算法科技研究量子视觉计算,利用量子力学原理提升传统计算机视觉任务的性能
  • Kafka入门
  • 语音增强论文汇总
  • Go基本数据类型
  • 81、面向服务开发方法
  • Redisson实现分布式锁
  • Redisson实现限流器详解:从原理到实践
  • HTML 入门教程:从零开始学习网页开发基础