TRL - Transformer Reinforcement Learning: The Dataset Passed to SFTTrainer
flyfish
Taking JSON format as the example.

Step 1: A local JSON file

The file apple_data.json contains:
[{"instruction": "介绍苹果的基本分类(按成熟季节)","input": "苹果的季节分类","output": "苹果按成熟季节可分为三类:①早熟品种(7-8月成熟),如嘎啦、美八,口感偏酸脆;②中熟品种(9-10月成熟),如红富士、秦冠,甜度较高,耐储存;③晚熟品种(11月后成熟),如粉红女士,果肉紧实,风味浓郁,适合长期保存。"},{"instruction": "用拟人化手法描述一个苹果","input": "一个红苹果","output": "我是一个圆滚滚的红苹果,穿着半红半黄的“外套”,表面光滑得像打了蜡。头顶的果柄是我的“小辫子”,轻轻一摸,还能感受到皮肤下饱满的汁水在悄悄晃动。凑近闻,一股清甜的香气会钻进你的鼻子——那是我在邀请你咬一口呀!"},{"instruction": "解答关于苹果储存的问题","input": "如何延长苹果的保鲜期?","output": "延长苹果保鲜期的方法有:①冷藏储存:放入冰箱冷藏室(温度0-4℃),用保鲜袋包裹(留小孔透气),可保存1-2个月;②避免混放:与香蕉、梨等释放乙烯的水果分开,防止加速成熟;③阴凉通风处:常温下置于阴凉干燥处,可保存1-2周,适合短期食用。"},{"instruction": "列举苹果的3种常见食用方法(非直接生食)","input": "苹果的加工吃法","output": "苹果的非生食做法有:①苹果派:将苹果切丁,加黄油、糖炒软,与面粉、鸡蛋等制作的派皮烘焙,外酥内软;②苹果银耳汤:苹果切块与银耳、红枣同煮,加冰糖调味,润肺养颜;③烤苹果:苹果去核,填入蜂蜜和坚果,烤箱烤制,果香浓郁,适合秋冬食用。"},{"instruction": "对比苹果与梨的核心区别(从口感和营养侧重)","input": "苹果和梨的区别","output": "从口感看:苹果果肉多脆嫩(除面苹果外),果皮较薄;梨果肉偏细腻多汁,部分品种(如酥梨)果皮更光滑。从营养侧重看:苹果膳食纤维(尤其是果胶)含量更高,有助于肠道蠕动;梨的水分和梨醇含量更丰富,润肺生津效果更突出,适合干燥季节食用。"}
]
Step 2: Load the local JSON file with `load_dataset`

When calling `load_dataset`, set the data format to `'json'` (since the file is JSON) and point the `data_files` argument at the local file path.
```python
from datasets import load_dataset

# Load the local JSON file
dataset = load_dataset(
    path='json',                  # data format: JSON
    data_files='apple_data.json'  # local JSON file path (use a fuller path such as './data/apple_data.json' if the file is not in the current directory)
)

# Inspect the result
print(dataset)              # print the dataset structure
print("\nFirst sample:")
print(dataset['train'][0])  # print the first sample
```
Output

Running this returns a `DatasetDict` object with a single split named `'train'` (the JSON file declares no splits, so `load_dataset` assigns all the data to the `'train'` split by default).

```
Generating train split: 5 examples [00:00, 766.53 examples/s]
DatasetDict({
    train: Dataset({
        features: ['instruction', 'input', 'output'],
        num_rows: 5
    })
})

First sample:
{'instruction': '介绍苹果的基本分类(按成熟季节)', 'input': '苹果的季节分类', 'output': '苹果按成熟季节可分为三类:①早熟品种(7-8月成熟),如嘎啦、美八,口感偏酸脆;②中熟品种(9-10月成熟),如红富士、秦冠,甜度较高,耐储存;③晚熟品种(11月后成熟),如粉红女士,果肉紧实,风味浓郁,适合长期保存。'}
```
Notes

- Data format requirement: the data is a JSON array (`[{}, {}, ...]`), the standard format `load_dataset` supports; each object becomes one sample.
- Split handling: if the data has multiple splits (e.g. train/test), store each split in its own JSON file (e.g. `train.json`, `test.json`) and map them with `data_files={'train': 'train.json', 'test': 'test.json'}`.
- Downstream use: the loaded `dataset['train']` can serve directly as training data for `SFTTrainer`, provided the fields match. For example, `SFTTrainer` may expect a `'text'` field, in which case you need a preprocessing step that concatenates `instruction+input+output` into a `'text'` field (see the sketch below).

This completes loading the local JSON data; you can then preprocess it as needed (field concatenation, tokenization, etc.) before using it for model training.
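As a concrete illustration of the field-matching point, here is a minimal sketch that concatenates the three fields into a single `text` column. The `Instruction:/Input:/Output:` prompt layout is my own assumption for illustration, not something `SFTTrainer` mandates; use whatever template your model expects.

```python
from datasets import load_dataset

dataset = load_dataset('json', data_files='apple_data.json')

def to_text(example):
    # Hypothetical prompt layout, chosen only for illustration
    text = (
        f"Instruction: {example['instruction']}\n"
        f"Input: {example['input']}\n"
        f"Output: {example['output']}"
    )
    return {"text": text}

# Replace the three original columns with a single 'text' column
dataset = dataset.map(to_text, remove_columns=['instruction', 'input', 'output'])
print(dataset['train'][0]['text'])
```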
An example that traces, step by step, how the dataset changes
```python
import argparse
import pprint

from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
from trl import (
    ModelConfig,
    ScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    clone_chat_template,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)


def main(script_args, training_args, model_args):
    ################
    # Model and tokenizer initialization
    ################
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=model_args.torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )

    # Create the model
    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    valid_image_text_architectures = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values()
    if config.architectures and any(arch in valid_image_text_architectures for arch in config.architectures):
        from transformers import AutoModelForImageTextToText

        model_kwargs.pop("use_cache", None)
        model = AutoModelForImageTextToText.from_pretrained(model_args.model_name_or_path, **model_kwargs)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)

    # Create the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # set the padding token

    # Set the chat template
    original_chat_template = tokenizer.chat_template
    if tokenizer.chat_template is None:
        print("Applying default chat template...")
        model, tokenizer = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B")

    print("\n===== Chat template info =====")
    print(f"Chat template: {tokenizer.chat_template[:200]}...")  # show part of the template

    ################
    # Dataset processing and tracking
    ################
    # 1. Load the raw dataset
    print("\n" + "=" * 50)
    print("Stage 1: Load the raw dataset")
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
    train_split = script_args.dataset_train_split
    print(f"Dataset splits: {list(dataset.keys())}")
    print(f"Number of training samples: {len(dataset[train_split])}")
    print("Raw sample example:")
    pprint.pprint(dataset[train_split][0])  # print the first raw sample
    print("Raw sample fields: ", dataset[train_split].column_names)
    print("=" * 50)

    # 2. Parse the conversation content
    print("\nStage 2: Parse the conversation content")
    # Extract turn information
    sample = dataset[train_split][0]
    print(f"Number of turns: {sample['num_turns']}")
    print("Role sequence: " + " → ".join([msg['role'] for msg in sample['messages']]))
    # Show the first 2 turns
    print("\nFirst 2 turns:")
    for i, msg in enumerate(sample['messages'][:2]):
        print(f"Turn {i + 1} ({msg['role']}): {msg['content'][:100]}...")
    print("=" * 50)

    # 3. Format with the chat template
    print("\nStage 3: Format conversations with the chat template")
    # Take the first 2 samples for the demo
    demo_samples = [dataset[train_split][i] for i in range(min(2, len(dataset[train_split])))]
    # Format the conversations
    formatted_texts = []
    for sample in demo_samples:
        # Use the tokenizer's chat template to format the multi-turn conversation
        formatted = tokenizer.apply_chat_template(
            sample['messages'],
            tokenize=False,
            add_generation_prompt=False,  # no generation prompt; this is training data
        )
        formatted_texts.append(formatted)
        print(f"\nFormatted conversation sample (first 300 chars):\n{formatted[:300]}...")
    print("=" * 50)

    # 4. Tokenization
    print("\nStage 4: Tokenize the text")
    # Note: SFTConfig's parameter is max_length, not max_seq_length
    print(f"Parameters: max_length={training_args.max_length}, padding='max_length', truncation=True")

    def preprocess_function(examples):
        # Format all conversations with the chat template
        texts = [
            tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False)
            for msgs in examples['messages']
        ]
        # Tokenize, again using max_length
        tokenized = tokenizer(
            texts,
            padding="max_length",
            truncation=True,
            max_length=training_args.max_length,
            return_overflowing_tokens=False,
        )
        # Build labels (padding positions become -100 and are excluded from the loss)
        tokenized["labels"] = [
            [label if mask == 1 else -100 for label, mask in zip(input_ids, attention_mask)]
            for input_ids, attention_mask in zip(tokenized["input_ids"], tokenized["attention_mask"])
        ]
        return tokenized

    # Process a few samples for the demo
    demo_dataset = dataset[train_split].select(range(min(2, len(dataset[train_split]))))
    processed_dataset = demo_dataset.map(preprocess_function, batched=True, remove_columns=demo_dataset.column_names)

    # Show the processed data format
    print("\nProcessed sample structure:")
    pprint.pprint(processed_dataset[0].keys())  # fields: input_ids, attention_mask, labels
    print("\nProcessed sample examples:")
    for i in range(len(processed_dataset)):
        print(f"\nSample {i + 1} details:")
        print(f"input_ids (first 20): {processed_dataset[i]['input_ids'][:20]}")
        print(f"attention_mask (first 20): {processed_dataset[i]['attention_mask'][:20]}")
        print(f"labels (first 20): {processed_dataset[i]['labels'][:20]}")
        print(f"Total sequence length: {len(processed_dataset[i]['input_ids'])}")  # should equal training_args.max_length
    print("=" * 50)

    # 5. Convert to model input format (tensors)
    print("\nStage 5: Convert to model input format")
    # Convert to PyTorch tensors
    processed_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
    print(f"Data type after conversion: {type(processed_dataset[0]['input_ids'])}")
    print(f"Tensor shape: {processed_dataset[0]['input_ids'].shape}")  # per-sample shape: (max_length,)
    print("=" * 50)

    ################
    # Training
    ################
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=processed_dataset,  # use the processed dataset
        eval_dataset=dataset[script_args.dataset_test_split].select(
            range(min(2, len(dataset[script_args.dataset_test_split])))
        )
        if training_args.eval_strategy != "no"
        else None,
        processing_class=tokenizer,
        peft_config=get_peft_config(model_args),
    )
    print("\nStarting training (demo only; uncomment trainer.train() to actually train)")
    # trainer.train()

    # Save the model
    # trainer.save_model(training_args.output_dir)
    # if training_args.push_to_hub:
    #     trainer.push_to_hub(dataset_name=script_args.dataset_name)


def make_parser(subparsers: argparse._SubParsersAction = None):
    dataclass_types = (ScriptArguments, SFTConfig, ModelConfig)
    if subparsers is not None:
        parser = subparsers.add_parser("sft", help="Run the SFT training script", dataclass_types=dataclass_types)
    else:
        parser = TrlParser(dataclass_types)
    return parser


if __name__ == "__main__":
    parser = make_parser()
    script_args, training_args, model_args, _ = parser.parse_args_and_config(return_remaining_strings=True)
    main(script_args, training_args, model_args)
```
Command

```bash
python sft.py \
    --model_name_or_path Qwen/Qwen2-0.5B \
    --dataset_name trl-lib/Capybara \
    --learning_rate 2.0e-4 \
    --num_train_epochs 1 \
    --packing \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --gradient_checkpointing \
    --eos_token '<|im_end|>' \
    --eval_strategy steps \
    --eval_steps 100 \
    --use_peft \
    --lora_r 32 \
    --lora_alpha 16 \
    --output_dir Qwen2-0.5B-SFT
```
Output

```
===== Chat template info =====
Chat template: {% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system
You are a helpful assistant<|im_end|>
' }}{% endif %}{{'<|im_start|>' + message['role'] + '
'...
==================================================
Stage 1: Load the raw dataset
Dataset splits: ['train', 'test']
Number of training samples: 15806
Raw sample example:
{"messages": [{"content": "Recommend a movie to watch.\n", "role": "user"},
              {"content": "I would recommend the movie, \"The Shawshank Redemption\" which is a classic drama film starring Tim Robbins and Morgan Freeman. This film tells a powerful story about hope and resilience, as it follows the story of a young man who is wrongfully convicted of murder and sent to prison. Amidst the harsh realities of prison life, the protagonist forms a bond with a fellow inmate, and together they navigate the challenges of incarceration, while holding on to the hope of eventual freedom. This timeless movie is a must-watch for its moving performances, uplifting message, and unforgettable storytelling.", "role": "assistant"},
              {"content": "Describe the character development of Tim Robbins' character in \"The Shawshank Redemption\".", "role": "user"},
              {"content": "In \"The Shawshank Redemption\", Tim Robbins plays the character of Andy Dufresne, a banker who is wrongfully convicted of murdering his wife and her lover. Over the course of the film, we see a significant transformation in Andy's character.\n\nAt the beginning of the movie, Andy is a quiet, reserved man who maintains his innocence but is largely passive in the face of his unjust conviction. He seems almost aloof, detached from the harsh reality of his situation. However, beneath this seemingly impassive exterior, Andy is a man of deep intelligence and inner strength.\n\nAs the story progresses, Andy begins to adapt to his new environment. He uses his banking skills to gain favor with the prison guards and the warden, providing them with financial advice and even helping them with tax evasion. This not only earns him certain privileges but also gives him a measure of protection within the prison walls.\n\nDespite the grim circumstances, Andy never loses hope. He befriends a fellow inmate, Red, and shares with him his dream of living a quiet life in Zihuatanejo, a small town on the Pacific coast of Mexico. This dream keeps him going and becomes a symbol of hope for other inmates as well.\n\nPerhaps the most significant development in Andy's character comes towards the end of the film, when it is revealed that he has been digging a tunnel through his cell wall for the past 19 years. This act of defiance, carried out with nothing but a small rock hammer, is a testament to Andy's resilience and determination. It also shows his cunning, as he manages to keep his plan a secret from everyone, including Red.\n\nIn the end, Andy's character arc is one of hope, resilience, and redemption. Despite being placed in a situation that seems hopeless, he never gives up. He maintains his dignity, uses his intelligence to his advantage, and ultimately, manages to escape from Shawshank, proving his innocence in the process. His character serves as a powerful symbol of the human spirit's ability to endure and overcome even the most challenging circumstances.", "role": "assistant"},
              {"content": "Explain the significance of the friendship between Andy and Red in shaping Andy's character development.", "role": "user"},
              {"content": "The friendship between Andy Dufresne (Tim Robbins) and Ellis \"Red\" Redding (Morgan Freeman) is central to the narrative of \"The Shawshank Redemption\", and plays a significant role in shaping Andy's character development.\n\nWhen Andy first arrives at Shawshank, he is a quiet, introverted man, largely keeping to himself. Red, who is the long-term inmate, takes an interest in him and they slowly develop a friendship. Red's wisdom, street-smart attitude, and his ability to procure items become instrumental in Andy's survival and eventual escape.\n\nRed is initially skeptical of Andy's innocence and his hopeful outlook on life. However, as their friendship grows, Red becomes more receptive to Andy's perspective. This friendship provides Andy with a confidant, a sounding board, and a supportive ally amidst the harsh realities of prison life. \n\nAndy's influence on Red is equally profound. Andy's unyielding hope and resilience slowly chip away at Red's hardened cynicism. Andy shares his dreams of freedom and his plans for the future with Red, which initially seem unrealistic to Red, but over time, Andy's unwavering belief in hope begins to influence Red's outlook on life.\n\nIn many ways, their friendship is a beacon of hope and humanity in an otherwise oppressive environment. It's through this friendship that Andy finds the strength to maintain his dignity, persevere, and ultimately, to engineer his daring escape. It's also through this friendship that Red finds hope for redemption and a life beyond the prison walls.\n\nIn conclusion, the friendship between Andy and Red is a pivotal element in shaping Andy's character development. It's through this bond that Andy finds the strength to endure his unjust imprisonment and to hold onto hope, ultimately leading to his redemption.", "role": "assistant"}],
 "num_turns": 6,
 "source": "GPT4LLM"}
Raw sample fields:  ['source', 'messages', 'num_turns']
==================================================

Stage 2: Parse the conversation content
Number of turns: 6
Role sequence: user → assistant → user → assistant → user → assistant

First 2 turns:
Turn 1 (user): Recommend a movie to watch.
...
Turn 2 (assistant): I would recommend the movie, "The Shawshank Redemption" which is a classic drama film starring Tim R...
==================================================

Stage 3: Format conversations with the chat template

Formatted conversation sample (first 300 chars):
<|im_start|>system
You are a helpful assistant<|im_end|>
<|im_start|>user
Recommend a movie to watch.
<|im_end|>
<|im_start|>assistant
I would recommend the movie, "The Shawshank Redemption" which is a classic drama film starring Tim Robbins and Morgan Freeman. This film tells a powerful story about...

Formatted conversation sample (first 300 chars):
<|im_start|>system
You are a helpful assistant<|im_end|>
<|im_start|>user
Determine the result obtained by evaluating 5338245-50629795848152. Numbers and symbols only, please.<|im_end|>
<|im_start|>assistant
5338245 - 50629795848152 = -50629790509907<|im_end|>
...
==================================================

Stage 4: Tokenize the text
Parameters: max_length=1024, padding='max_length', truncation=True
Map: 100%|██████████| 2/2 [00:00<00:00, 77.97 examples/s]

Processed sample structure:
dict_keys(['input_ids', 'attention_mask', 'labels'])

Processed sample examples:

Sample 1 details:
input_ids (first 20): [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 151645, 198, 151644, 872, 198, 67644, 264, 5700, 311, 3736, 624, 151645]
attention_mask (first 20): [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
labels (first 20): [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 151645, 198, 151644, 872, 198, 67644, 264, 5700, 311, 3736, 624, 151645]
Total sequence length: 1024

Sample 2 details:
input_ids (first 20): [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 151645, 198, 151644, 872, 198, 35, 24308, 279, 1102, 12180, 553, 37563]
attention_mask (first 20): [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
labels (first 20): [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 151645, 198, 151644, 872, 198, 35, 24308, 279, 1102, 12180, 553, 37563]
Total sequence length: 1024
==================================================

Stage 5: Convert to model input format
Data type after conversion: <class 'torch.Tensor'>
Tensor shape: torch.Size([1024])
==================================================
```
From human-readable conversation text to model-ready tensors, step by step
Stage 1: Raw dataset loading (preserve the original interaction structure)

- Input: the dataset file, loaded via `load_dataset`.
- Processing: read the raw data from the source with no format changes, keeping every field and all of the original interaction information.
- Output format: a `Dataset` object in which each sample is a multi-field dict. The core fields:
  - `messages`: a list whose elements are single turns (`{"role": "user"/"assistant", "content": "..."}`).
  - `num_turns`: an integer, the total number of turns (e.g. 6).
  - `source`: a string identifying the data source (e.g. "GPT4LLM").
- Example (data):

```python
{
    "source": "GPT4LLM",
    "num_turns": 6,
    "messages": [
        {"role": "user", "content": "Recommend a movie to watch.\n"},
        {"role": "assistant", "content": "I would recommend the movie..."},
        # more turns...
    ],
}
```
Stage 2: Conversation parsing (extract the core interaction information)

- Input: a raw dataset sample (the output of Stage 1).
- Processing: pull the key interaction information out of the raw data and understand the conversation structure (in preparation for formatting); a small sketch follows this list.
- Extracted content:
  - Turn count: read from `num_turns` (e.g. 6 turns).
  - Role sequence: the order of `role` values in `messages` (e.g. `user → assistant → user → assistant...`).
  - Raw turn content: the `content` text of each turn (the user's movie request, the assistant's reply, and so on).
- Purpose: make the interaction logic explicit (alternating user question → assistant answer), laying the groundwork for the unified format that follows.
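Here is the promised sketch. It assumes `dataset` is the Capybara `DatasetDict` loaded by the script above, that `num_turns` counts the messages, and that conversations alternate user/assistant exactly as in the sample shown; real datasets may deviate from this layout.

```python
# Sanity-check the structure Stage 2 relies on
sample = dataset['train'][0]
roles = [msg['role'] for msg in sample['messages']]

assert len(roles) == sample['num_turns']                  # num_turns counts the messages
assert roles[0::2] == ['user'] * (len(roles) // 2)        # odd turns: user
assert roles[1::2] == ['assistant'] * (len(roles) // 2)   # even turns: assistant
print(" → ".join(roles))
```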
Stage 3: Chat-template formatting (unify the model input format)

- Input: the multi-turn `messages` (roles plus contents) parsed in Stage 2.
- Processing: call the tokenizer's `apply_chat_template` method to turn the loose `messages` into a single string in the format the model was trained to recognize.
  - Core logic: following the model's predefined template (e.g. Qwen's), add role markers (such as `<|user|>`, `<|assistant|>`) and separators (newlines, special tokens) around each turn so the context stays coherent.
- Output format: one string containing the full conversation history, with role markers and structural separators.
- Example (based on the data above):

```
"<|user|>Recommend a movie to watch.\n<|assistant|>I would recommend the movie, \"The Shawshank Redemption\"...<|user|>Describe the character development of Tim Robbins' character...<|assistant|>In \"The Shawshank Redemption\", Tim Robbins plays..."
```

- Purpose: give the model a uniform way to recognize "who is speaking" and where role boundaries lie (without this, the model cannot distinguish the user's words from the assistant's). The call itself appears in the sketch below.
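A minimal sketch of the formatting call, assuming the `tokenizer` and `sample` from the script above:

```python
# Format one parsed conversation with the tokenizer's chat template
formatted = tokenizer.apply_chat_template(
    sample['messages'],
    tokenize=False,               # return a plain string rather than token IDs
    add_generation_prompt=False,  # training data already ends with assistant replies
)
print(formatted[:300])
```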
Stage 4: Tokenization (text → numeric sequences)

- Input: the formatted conversation string produced in Stage 3.
- Processing: the tokenizer converts the text into number sequences the model can consume. The core steps:
  - Tokenize: split the string into the smallest semantic units (tokens) in the model's vocabulary, e.g. "Shawshank" → 3456 (illustrative ID).
  - Unify lengths: according to `training_args.max_length` (1024 in the demo run), truncate over-long sequences and pad short ones with `pad_token` (usually the `eos_token`).
  - Build the auxiliary fields:
    - `input_ids`: the token ID sequence (the model's core input).
    - `attention_mask`: a 0/1 sequence (1 marks a valid token, 0 marks padding).
    - `labels`: the sequence used for the loss (identical to `input_ids`, except that padding positions are set to `-100` so the model never learns from padding).
- Output format: a dict with three core fields (all lists):

```python
{
    "input_ids": [101, 2345, 678, ..., 0, 0],     # length = max_length; 0s are padding
    "attention_mask": [1, 1, 1, ..., 0, 0],       # marks the valid tokens
    "labels": [101, 2345, 678, ..., -100, -100],  # padding positions ignored by the loss
}
```

- Purpose: turn human-readable text into numbers the model can compute on, while `attention_mask` and `labels` tell the model what to attend to and what to learn from. The `labels` construction is shown in isolation in the sketch below.
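The `labels` construction is the step that is easiest to get wrong, so here it is in isolation; a sketch assuming the `tokenizer` and `formatted` string from Stage 3, with 1024 matching the demo run's `max_length`:

```python
encoded = tokenizer(
    formatted,
    padding="max_length",
    truncation=True,
    max_length=1024,
)

# Copy input_ids into labels, but mask padding positions with -100
# so the cross-entropy loss skips them.
encoded["labels"] = [
    tok if mask == 1 else -100
    for tok, mask in zip(encoded["input_ids"], encoded["attention_mask"])
]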
Stage 5: Conversion to model input format (tensors, ready for batching)

- Input: the `input_ids`, `attention_mask`, and `labels` lists from Stage 4.
- Processing: convert the lists into PyTorch tensors (`torch.Tensor`) with a uniform shape.
- Output format: a dict of tensors, each field shaped `(max_length,)` for a single sample or `(batch_size, max_length)` for a batch:

```python
{
    "input_ids": tensor([101, 2345, 678, ..., 0, 0]),  # shape: (max_length,)
    "attention_mask": tensor([1, 1, 1, ..., 0, 0]),
    "labels": tensor([101, 2345, 678, ..., -100, -100]),
}
```

- Purpose: satisfy the model's input requirements (models only accept tensors) and enable batched computation, i.e. processing several samples in parallel; see the sketch below.
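And the promised Stage 5 sketch, assuming the `processed_dataset` built in Stage 4:

```python
# Ask the dataset to return PyTorch tensors instead of Python lists
processed_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

print(type(processed_dataset[0]["input_ids"]))  # <class 'torch.Tensor'>
print(processed_dataset[0]["input_ids"].shape)  # torch.Size([max_length])

batch = processed_dataset[:2]                   # slicing now yields stacked tensors
print(batch["input_ids"].shape)                 # torch.Size([2, max_length])
```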
| Stage | How the data changes | Core goal |
|---|---|---|
| Raw loading | Multi-field dict (original interaction preserved) | Keep the data exactly as it arrived |
| Conversation parsing | Turns, roles, and contents extracted | Understand the structure; prepare for formatting |
| Template formatting | Multi-turn `messages` → one string with role markers | Let the model recognize roles and context boundaries |
| Tokenization | String → `input_ids` / `attention_mask` / `labels` | Turn text into numeric sequences the model understands |
| Tensor conversion | Lists → PyTorch tensors | Match the model's input format; support batching and backpropagation |
The load_dataset function in detail

`load_dataset` is the core function of the Hugging Face `datasets` library. It loads datasets from the Hugging Face Hub (remote repositories), local files, or custom sources, and understands many data formats (CSV, JSON, Parquet, images, audio, and more). Its job is to simplify fetching, processing, and caching datasets, returning a `Dataset` or `DatasetDict` object that can be fed directly into model training or evaluation.
Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| `path` | `str` | (required) | Source path or name of the dataset, which determines how it is loaded: a Hub repo name (e.g. `'cornell-movie-review-data/rotten_tomatoes'`) loads from the Hub; a local directory (e.g. `'./data'`) loads from that folder; a data format (e.g. `'csv'`) loads the files given in `data_files`. |
| `name` | `str` (optional) | `None` | Dataset configuration name (some datasets have multiple configs, e.g. the `'sst2'` subtask of `'nyu-mll/glue'`). |
| `data_dir` | `str` (optional) | `None` | Local data directory. If set while `data_files` is `None`, every file in the directory is loaded (equivalent to `data_files=os.path.join(data_dir, **)`). |
| `data_files` | `str` / `Sequence[str]` / `Mapping[str, Union[str, Sequence[str]]]` (optional) | `None` | Concrete data file paths: a single file (e.g. `'train.csv'`); a list of files (e.g. `['train1.csv', 'train2.csv']`); or a split mapping (e.g. `{'train': 'train.csv', 'test': 'test.csv'}`) assigning files to splits. |
| `split` | `str` / `Split` (optional) | `None` | Which split to load (`'train'`, `'test'`, `'train+test'`, ...). `None` returns a `DatasetDict` with all splits; a specific value returns a single `Dataset`. |
| `cache_dir` | `str` (optional) | `~/.cache/huggingface/datasets` | Cache directory for downloaded and processed datasets (avoids repeated downloads). |
| `features` | `Features` (optional) | `None` | Custom feature schema (e.g. declare a field as text or integer), overriding the automatically inferred features. |
| `download_config` | `DownloadConfig` (optional) | `None` | Download configuration (timeouts, proxies, etc.) controlling the download details. |
| `download_mode` | `DownloadMode` / `str` (optional) | `REUSE_DATASET_IF_EXISTS` | Download mode: `REUSE_DATASET_IF_EXISTS` (default) reuses a local cache if present; `FORCE_REDOWNLOAD` always re-downloads; `REUSE_CACHE_IF_EXISTS` reuses the cache and re-downloads if it is corrupted. |
| `verification_mode` | `VerificationMode` / `str` (optional) | `BASIC_CHECKS` | Verification mode: `BASIC_CHECKS` (default) checks file existence and sizes; `ALL_CHECKS` adds hash verification; `NO_CHECKS` skips verification. With `save_infos=True` the default is upgraded to `ALL_CHECKS`. |
| `keep_in_memory` | `bool` (optional) | `None` | Whether to load the dataset into memory: `None` (default) decides automatically (small datasets in memory, large ones read from disk); `True`/`False` forces the behavior. |
| `save_infos` | `bool` (optional) | `False` | Whether to save dataset metadata (checksums, sizes, split info) to the cache directory. |
| `revision` | `str` / `Version` (optional) | `None` | Version of a Hub dataset (branch name such as `'main'`, commit SHA, or tag), for loading a specific version. |
| `token` | `bool` / `str` (optional) | `None` | Token for private Hub datasets: `True` reads the token from `~/.huggingface`; a string is used directly. |
| `streaming` | `bool` (optional) | `False` | Whether to stream: `False` (default) downloads and caches the full dataset; `True` streams samples on the fly without downloading (suited to very large datasets; returns an `IterableDataset`). |
| `num_proc` | `int` (optional) | `None` | Number of processes for parallel download and preprocessing of local datasets; multiprocessing is disabled by default. |
| `storage_options` | `dict` (optional) | `None` | Experimental: options passed to the file-system backend (e.g. access keys for cloud storage). |
| `trust_remote_code` | `bool` (optional) | `False` | Whether to trust a dataset script hosted on the Hub: `True` executes the remote script (only for repositories you trust); `False` (default) refuses to run it. |
| `**config_kwargs` | extra keyword arguments | (none) | Additional configuration passed to the dataset builder (`DatasetBuilder`). |
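To make the parameter interplay concrete, here is a hypothetical call combining several of the options above; the repository name `"my-org/my-dataset"` and config name `"default"` are placeholders, and `token=True` assumes you are logged in to the Hub locally:

```python
from datasets import load_dataset

# A private Hub dataset, pinned to a revision, cached in a
# project-local directory, prepared with 4 worker processes.
dataset = load_dataset(
    "my-org/my-dataset",   # placeholder repo name
    name="default",
    revision="main",
    cache_dir="./hf_cache",
    token=True,            # read the access token from the local Hugging Face login
    num_proc=4,
)
```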
Return value

- If `streaming=False` (default):
  - `split` is `None`: returns a `DatasetDict` containing all splits (e.g. `{'train': Dataset, 'test': Dataset}`);
  - `split` is given: returns a single `Dataset` for that split.
- If `streaming=True`:
  - `split` is `None`: returns an `IterableDatasetDict`;
  - `split` is given: returns an `IterableDataset` (a streaming iterator that never caches the full data). A short demonstration follows this list.
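A minimal sketch of the two non-streaming return types, reusing the `apple_data.json` file from the start of this post:

```python
from datasets import load_dataset

ds_dict = load_dataset('json', data_files='apple_data.json')                  # all splits
ds_train = load_dataset('json', data_files='apple_data.json', split='train')  # one split

print(type(ds_dict).__name__)   # DatasetDict
print(type(ds_train).__name__)  # Dataset
```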
Key characteristics

- Multiple sources: Hub repositories, local files, and custom formats (images, audio, etc.).
- Automation: parses the data format, infers feature types, and caches processed results (avoiding repeated work).
- Flexible splits: load a single split (e.g. only the training set) or a combination (e.g. `'train+validation'`).
- Streaming: suited to very large datasets (no local storage needed; samples are read on the fly).
- Version control: the `revision` parameter pins a Hub dataset to a specific version, ensuring reproducibility.
Example scenarios

- Load a public dataset from the Hub: `load_dataset('rotten_tomatoes', split='train')`
- Load a local CSV file: `load_dataset('csv', data_files='./train.csv')`
- Stream a very large dataset: `load_dataset('large_dataset', streaming=True)`
- Load a specific dataset version: `load_dataset('my_dataset', revision='v1.0', split='test')`
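Finally, a sketch of the streaming path, reusing the `trl-lib/Capybara` dataset from the SFT example above (its field layout matches the raw sample shown earlier):

```python
from itertools import islice
from datasets import load_dataset

# Nothing is downloaded up front; samples arrive lazily as you iterate
# over the returned IterableDataset.
stream = load_dataset("trl-lib/Capybara", split="train", streaming=True)
for sample in islice(stream, 2):
    print(sample["messages"][0]["content"][:80])
```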