当前位置: 首页 > news >正文

打造可扩展的大模型训练框架:支持PEFT微调、分布式训练与TensorBoard可视化

本文将手把手带你构建一个强大的大模型训练脚本,具备模块化、可复用、低比特微调(LoRA)、多卡训练和可视化能力,适用于生产环境或竞赛场景。代码基于 HuggingFace Transformers 和 PEFT 框架开发。

项目所需模型获取:https://pan.quark.cn/s/10e99aad42bb

数据集获取:https://pan.quark.cn/s/eb7a9ef0740d

🧱 一、整体框架设计概览

我们希望实现的训练框架具备如下特性:

  • ✅ 模块化,易于维护

  • ✅ 支持多卡分布式训练(兼容 torchrun 启动)

  • ✅ 支持 PEFT 微调(LoRA)与 8bit 低精度加载

  • ✅ 自定义 Trainer 支持 TensorBoard 可视化

  • ✅ 可读取 JSON 格式数据集并处理为 HuggingFace Dataset

代码结构分为以下几个核心部分:

  • 分布式训练初始化 setup_distributed()

  • 数据预处理与构建 convert_json_to_dataset()

  • 模型和 Tokenizer 加载、LoRA 配置

  • 自定义 Trainer 实现 TensorBoard 记录

  • 主函数串联所有流程

🔌 二、初始化分布式训练环境

我们通过 torch.distributedtorchrun 启动分布式训练环境:

import torch.distributed as dist
import osdef setup_distributed():local_rank = int(os.environ.get("LOCAL_RANK", -1))world_size = int(os.environ.get("WORLD_SIZE", 1))if local_rank != -1:dist.init_process_group(backend="nccl")torch.cuda.set_device(local_rank)return local_rank, world_size

🧹 三、数据集读取与预处理

json 文件转化为 HuggingFace Dataset,并构造 prompt:

from datasets import Dataset
import json
import osdef preprocess_function(example, tokenizer):prompt = f"### 用户指令:{example['instruction']}\n"if example.get("input"):prompt += f"### 补充输入:{example['input']}\n"prompt += f"### 输出:{example['output']}"return tokenizer(prompt, truncation=True, max_length=512, padding='max_length')def convert_json_to_dataset(json_file_path, save_path, tokenizer):with open(json_file_path, 'r', encoding='utf-8') as file:data = json.load(file)dataset = Dataset.from_list(data)dataset = dataset.map(lambda x: preprocess_function(x, tokenizer))os.makedirs(save_path, exist_ok=True)dataset.save_to_disk(save_path)return dataset

🤖 四、模型与Tokenizer加载(支持8bit & LoRA)

我们使用 transformers 加载模型,同时支持 bitsandbytes 进行低比特量化加载。

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import prepare_model_for_kbit_trainingdef prepare_model_and_tokenizer(model_path, local_rank):device_map = {"": local_rank} if local_rank != -1 else "auto"quant_config = BitsAndBytesConfig(load_in_8bit=True)tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)tokenizer.pad_token = tokenizer.eos_tokenmodel = AutoModelForCausalLM.from_pretrained(model_path,quantization_config=quant_config,device_map=device_map,trust_remote_code=True,)model = prepare_model_for_kbit_training(model)return model, tokenizer

📊 五、自定义 Trainer 实现训练可视化

我们继承 transformers.Trainer 并重写 compute_loss 方法,在其中集成 TensorBoard 记录:

from transformers import Trainer
from torch.utils.tensorboard import SummaryWriter
import torchclass CustomTrainer(Trainer):def __init__(self, *args, tensorboard_dir=None, **kwargs):super().__init__(*args, **kwargs)self.tensorboard_dir = tensorboard_dirself.writer = SummaryWriter(log_dir=tensorboard_dir) if tensorboard_dir else Nonedef compute_loss(self, model, inputs, return_outputs=False):loss = super().compute_loss(model, inputs, return_outputs)if self.writer and self.state.global_step % self.args.logging_steps == 0:step = self.state.global_stepself.writer.add_scalar("train/loss", loss.item(), step)self.writer.add_scalar("train/lr", self._get_learning_rate(), step)self.writer.add_scalar("train/batch_size", self.args.per_device_train_batch_size, step)self.writer.add_scalar("train/gpu_memory", torch.cuda.max_memory_allocated() / 1e9, step)grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e9)self.writer.add_scalar("train/grad_norm", grad_norm, step)return loss

🧵 六、训练主程序整合

将所有模块串联到主函数中,一键启动训练:

from transformers import TrainingArguments
from peft import get_peft_model, LoraConfig, TaskTypeif __name__ == "__main__":from argparse import ArgumentParserparser = ArgumentParser()parser.add_argument('--model_path', type=str, required=True)parser.add_argument('--json_path', type=str, required=True)parser.add_argument('--data_dir', type=str, default='./dataset')parser.add_argument('--output_dir', type=str, default='./output')parser.add_argument('--tensorboard_dir', type=str, default='./runs')parser.add_argument('--max_steps', type=int, default=1000)parser.add_argument('--per_device_train_batch_size', type=int, default=2)parser.add_argument('--gradient_accumulation_steps', type=int, default=4)parser.add_argument('--logging_steps', type=int, default=10)args = parser.parse_args()local_rank, world_size = setup_distributed()model, tokenizer = prepare_model_and_tokenizer(args.model_path, local_rank)dataset = convert_json_to_dataset(args.json_path, args.data_dir, tokenizer)peft_config = LoraConfig(r=8,lora_alpha=16,target_modules=["q_proj", "v_proj"],lora_dropout=0.1,bias="none",task_type=TaskType.CAUSAL_LM,)model = get_peft_model(model, peft_config)training_args = TrainingArguments(output_dir=args.output_dir,per_device_train_batch_size=args.per_device_train_batch_size,gradient_accumulation_steps=args.gradient_accumulation_steps,logging_steps=args.logging_steps,max_steps=args.max_steps,save_strategy="steps",save_steps=200,report_to="none",remove_unused_columns=False,fp16=True,ddp_find_unused_parameters=False if local_rank != -1 else None,)trainer = CustomTrainer(model=model,args=training_args,train_dataset=dataset,tokenizer=tokenizer,tensorboard_dir=args.tensorboard_dir)trainer.train()

✅ 七、结语与后续建议

通过本项目,你已经拥有一个可快速扩展的大模型训练脚本。

完整代码获取:https://pan.quark.cn/s/4b6fcb9469f1

模型后续训练和部署教程请跳转:从训练到部署:基于 Qwen2.5 和 LoRA 的轻量化中文问答系统全流程实战-CSDN博客

未来可扩展方向:

  • 增加评估与测试模块

  • 支持 DeepSpeed、FSDP 超大模型训练

  • 自动化实验日志收集与模型版本控制

📌 Star 收藏不迷路,转发让更多人受益!

相关文章:

  • go语言学习 第5章:函数
  • 如何选择合适的embedding模型用于非英文语料
  • 【PmHub面试篇】PmHub 整合 TransmittableThreadLocal(TTL)缓存用户数据面试专题解析
  • 基于Gemini 2.5 Pro打造的AI智能体CanvasX上线,绘制常见图表(折线图、柱状图等),国内直接使用
  • [Java 基础]对象,膜具倒出来的
  • 微信小程序实现运动能耗计算
  • 12306高并发计算架构揭秘:Apache Geode 客户端接入与实践
  • webPack基本使用步骤
  • Neo4j 监控全解析:原理、技术、技巧与最佳实践
  • 【Linux系列】rsync命令详解与实践
  • 深入理解C#中的Web API:构建现代化HTTP服务的完整指南
  • BERT:让AI真正“读懂”语言的革命
  • Vue指令修饰符、v-bind对样式控制的增强、computed计算属性、watch监视器
  • 什么是预构建,Vite中如何使用预构建
  • Openlayers从入门到入坟
  • 【conda配置深度学习环境】
  • [Java 基础]抽象类和接口
  • 【C/C++】析构函数好玩的用法:~Derived() override
  • MCP与检索增强生成(RAG):AI应用的强大组合
  • 卫星的“太空陀螺”:反作用轮如何精准控制姿态?
  • 新闻类网站怎么做seo/宁波网站推广方案
  • 创建网站的目的是什么/新乡seo推广
  • 慈利做网站在哪里/今天发生的重大新闻5条
  • 外贸b2b网站建设公司/网上销售方法
  • 榆林市建设局网站/百度平台商家
  • 网站支付链接怎么做的/网络营销实施方案