当前位置: 首页 > news >正文

unsloth微调gemma3图文代码简析

代码使用了unsloth gemma3-4B的微调示例。
加载本地已经下载好的模型,使用了bnb 4bit量化,加载方便。

# User code: load the locally downloaded, bnb 4-bit quantized checkpoint.
model, processor = FastVisionModel.from_pretrained(
    model_name="/data/……/……/unsloth/gemma-3-4b-it-bnb-4bit",
    load_in_4bit=True,  # 4 bit quantization to reduce memory
)

unsloth 加载模型FastVisionModel.from_pretrained函数逻辑:

        # unsloth FastVisionModel.from_pretrained函数检查模型是否包含vision模块model_config = AutoConfig.from_pretrained(model_name,token = token,trust_remote_code = trust_remote_code,)……# Check if VLMis_vlm = any(x.endswith("ForConditionalGeneration") for x in model_config.architectures)is_vlm = is_vlm or hasattr(model_config, "vision_config")if auto_model is None:auto_model = AutoModelForVision2Seq if is_vlm else AutoModelForCausalLMmodel, tokenizer = FastBaseModel.from_pretrained(model_name        = model_name,max_seq_length    = max_seq_length,dtype             = _get_dtype(dtype),load_in_4bit      = load_in_4bit,load_in_8bit      = load_in_8bit,full_finetuning   = full_finetuning,token             = token,device_map        = device_map,trust_remote_code = trust_remote_code,revision          = revision if not is_peft else None,model_types       = model_types,tokenizer_name    = tokenizer_name,auto_model        = auto_model,use_gradient_checkpointing = use_gradient_checkpointing,supports_sdpa     = supports_sdpa,whisper_language  = whisper_language,whisper_task      = whisper_task,*args, **kwargs,)

在内层FastBaseModel.from_pretrained判断和加载模型,识别出是否需要适配视觉模型处理器

        # unsloth FastBaseModel.from_pretrained# 这行判断是否为为vlm模型,使用对应的处理器加载函数is_vlm = (auto_model is AutoModelForVision2Seq)is_whisper = (whisper_language is not None and whisper_task is not None)auto_processor = AutoProcessor if (is_vlm or is_whisper) else AutoTokenizerif (whisper_language and whisper_task) or auto_model.__name__.endswith("ForConditionalGeneration"):tokenizer = auto_processor.from_pretrained(tokenizer_name,padding_side = "right",token        = token,language     = whisper_language,task         = whisper_task,)else:tokenizer = auto_processor.from_pretrained(tokenizer_name,padding_side = "right",token        = token,)

用户配置模型lora参数

# User code: wrap the loaded model with LoRA adapters.
model = FastVisionModel.get_peft_model(
    model,
    finetune_vision_layers=True,  # Turn off for just text!
    finetune_language_layers=True,  # Should leave on!
    finetune_attention_modules=True,  # Attention good for GRPO
    finetune_mlp_modules=True,  # Should leave on always!
    r=16,  # Larger = higher accuracy, but might overfit
    lora_alpha=16,  # Recommended alpha == r at least
    lora_dropout=0,
    bias="none",
    random_state=3407,
    use_rslora=False,  # We support rank stabilized LoRA
    loftq_config=None,  # And LoftQ
    target_modules="all-linear",  # Optional now! Can specify a list if needed
    modules_to_save=[
        "lm_head",
        "embed_tokens",
    ],
)

模型lora网络配置加载内部逻辑:

# unsloth FastVisionModel.get_peft_model 内部调用的函数,用于选出需要开启训练的模块名称
def get_peft_regex(model,finetune_vision_layers     : bool = True,finetune_language_layers   : bool = True,finetune_attention_modules : bool = True,finetune_mlp_modules       : bool = True,target_modules             : List[str] = None,vision_tags                : List[str] = ["vision", "image", "visual", "patch",],language_tags              : List[str] = ["language", "text",],attention_tags             : List[str] = ["self_attn", "attention", "attn",],mlp_tags                   : List[str] = ["mlp", "feed_forward", "ffn", "dense",],
) -> str:……# 在选出一些与开启module训练相关的模块名称regex_model_parts = []if finetune_vision_layers:     regex_model_parts += vision_tagsif finetune_language_layers:   regex_model_parts += language_tagsregex_components  = []if finetune_attention_modules: regex_components  += attention_tagsif finetune_mlp_modules:       regex_components  += mlp_tagsregex_model_parts = "|".join(regex_model_parts)regex_components  = "|".join(regex_components)

之后,被选出的训练模块名称会被放入 LoRA 配置:

   # unsloth 内部函数lora_config_dict = {"r"                 : r,"lora_alpha"        : lora_alpha,"target_modules"    : target_modules, # get_peft_regex 函数的返回"target_parameters" : kwargs.get("target_parameters", None),"lora_dropout"      : lora_dropout,"bias"              : bias,"task_type"         : task_type,"use_rslora"        : use_rslora,"init_lora_weights" : init_lora_weights,"loftq_config"      : loftq_config,}lora_config = LoraConfig(**{k:v for k,v in lora_config_dict.items() if k in LoraConfig.__doc__},)model = prepare_model_for_kbit_training(model,use_gradient_checkpointing = use_gradient_checkpointing,)model = _get_peft_model(model, lora_config)

配置训练module层,给对应的层打开梯度更新,关闭不需要的层

# unsloth prepare_model_for_kbit_training 内部调用函数
def prepare_model_for_training(model                      : Any,use_gradient_checkpointing : Optional = "unsloth",use_reentrant              : Optional[bool] = True,full_finetuning            : Optional[bool] = False,train_layernorms           : Optional[bool] = False,train_embedding            : Optional[bool] = False,train_lm_head              : Optional[bool] = False,float32_mixed_precision    : Optional[bool] = True,
) -> Any:……for name, param in model.named_parameters():upcast = Falserequires_grad = Falseif not full_finetuning:if ".lora_A." in name or ".lora_B." in name or ".lora_magnitude_vector" in name:upcast = Truerequires_grad = Trueelse:requires_grad = Falseelse:if train_layernorms and ("norm." in name or "_layernorm" in name):requires_grad = Trueupcast = True # Must upcast layernorms to float32if train_embedding and ("embed_tokens" in name or "embedding" in name):requires_grad = Trueupcast = False # Can leave in bfloat16if train_lm_head and ("lm_head" in name):requires_grad = Trueupcast = False # Can leave in bfloat16else:requires_grad = Trueupcast = False # Can leave in bfloat16pass# Set training or notif requires_grad:param.requires_grad_(True)else:param.requires_grad_(False)# Upcast to float32 if neededif requires_grad:name = name.replace("base_model", "model", 1)while re.search(r'\.(\d+)\.', name) is not None:name = re.sub(r'\.(\d+)\.', r'[\1].', name)name = name.replace(".weight", "", 1)dtype = torch.float32 if upcast else mixed_precision_dtypetry:# Try original nameexec(f"{name}.to({str(dtype)})")except:# Maybe model.modelexec(f"model.{name}.to({str(dtype)})")passif ('norm.' in name or '_layernorm' in name) and os.environ.get("UNSLOTH_UPCAST_LAYERNORM", "0") == "1":try:name = name.replace("base_model", "model", 1)while re.search(r'\.(\d+)\.', name) is not None:name = re.sub(r'\.(\d+)\.', r'[\1].', name)name = name.replace(".weight", "", 1)# Try original nameexec(f"{name}.to({str(torch.float32)})")except:# Maybe model.modelexec(f"model.{name}.to({str(torch.float32)})")

加载huggingface的图文训练数据集

# User code: helpers that turn raw dataset rows into chat-formatted samples.
def formatting_prompts_func(examples):
    """Render every conversation of a batch to chat-template text.

    Args:
        examples: Mapping with a "conversations" entry holding one
            conversation per row.

    Returns:
        ``{"text": [...]}`` with one rendered string per conversation, each
        with a leading ``<bos>`` removed.
    """
    # NOTE(review): relies on the module-level `processor`. Nothing visible in
    # this snippet calls this helper, and it reads "conversations" whereas
    # convert_to_conversation emits "messages" -- confirm it is still needed.
    convos = examples["conversations"]
    texts = [
        processor.apply_chat_template(
            convo, tokenize=False, add_generation_prompt=False
        ).removeprefix('<bos>')
        for convo in convos
    ]
    return {"text": texts}


def convert_to_conversation(
    sample,
    instruction="Write the LaTeX representation for this image.",
):
    """Wrap one dataset row as a two-turn user/assistant conversation.

    Args:
        sample: Mapping providing "image" (placed in the user turn) and
            "text" (used as the assistant reply).
        instruction: Text prompt placed before the image in the user turn.
            Defaults to the LaTeX-OCR instruction.

    Returns:
        ``{"messages": [user_turn, assistant_turn]}``.
    """
    user_turn = {
        "role": "user",
        "content": [
            {"type": "text", "text": instruction},
            {"type": "image", "image": sample["image"]},
        ],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [{"type": "text", "text": sample["text"]}],
    }
    return {"messages": [user_turn, assistant_turn]}
……
……
# Load the train split, wrap each row as a chat sample, then switch the
# processor to the "gemma-3" chat template.
dataset = load_dataset("unsloth/LaTeX_OCR", split="train")
converted_dataset = [convert_to_conversation(sample) for sample in dataset]
processor = get_chat_template(processor, "gemma-3")

根据模板名称 "gemma-3" 从 CHAT_TEMPLATES 中查找并返回对应的对话模板

# unsloth 内部函数
gemma3_ollama = 
'''
FROM {__FILE_LOCATION__}
TEMPLATE """{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 }}
{{- if or (eq .Role "user") (eq .Role "system") }}<start_of_turn>user
{{ .Content }}<end_of_turn>
{{ if $last }}<start_of_turn>model
{{ end }}
{{- else if eq .Role "assistant" }}<start_of_turn>model
{{ .Content }}{{ if not $last }}<end_of_turn>
{{ end }}
{{- end }}
{{- end }}"""
PARAMETER stop "<end_of_turn>"
PARAMETER stop "<eos>"
PARAMETER temperature 0.1
PARAMETER min_p 0.0
PARAMETER top_k 64
PARAMETER top_p 0.95
PARAMETER num_predict 32768
'''
gemma3_template_eos_token = "<end_of_turn>"
CHAT_TEMPLATES["gemma-3"] = (gemma3_template, gemma3_template_eos_token, False, gemma3_ollama,)# get_chat_template 内部
def get_chat_template()chat_template, stop_word, yes_map_eos_token, ollama_modelfile = CHAT_TEMPLATES[chat_template]…… # 还有其余pad相关处理,主要获取模板

开启训练模式,并配置训练参数

# User code: switch the model to training mode and run supervised fine-tuning.
FastVisionModel.for_training(model)  # Enable for training!

trainer = SFTTrainer(
    model=model,
    train_dataset=converted_dataset,
    processing_class=processor.tokenizer,
    data_collator=UnslothVisionDataCollator(model, processor),
    args=SFTConfig(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        # Non-reentrant gradient checkpointing (use_reentrant is False).
        gradient_checkpointing_kwargs={"use_reentrant": False},
        max_grad_norm=0.3,  # max gradient norm based on QLoRA paper
        warmup_ratio=0.03,
        max_steps=3,
        # num_train_epochs=2,  # Set this instead of max_steps for full training runs
        learning_rate=2e-4,
        logging_steps=1,
        save_strategy="steps",
        optim="adamw_torch_fused",
        weight_decay=0.01,
        lr_scheduler_type="cosine",
        seed=3407,
        output_dir="outputs",
        report_to="none",  # For Weights and Biases
        # You MUST put the below items for vision finetuning:
        remove_unused_columns=False,
        dataset_text_field="",
        dataset_kwargs={"skip_prepare_dataset": True},
        max_length=2048,
    ),
)
trainer_stats = trainer.train()

保存训练好的lora模型

# User code: save the trained LoRA model and the processor locally.
# NOTE(review): the second positional argument is presumably not used as an
# output path -- PEFT's save_pretrained appears to take `safe_serialization`
# in that position; confirm the intended target directory.
model.save_pretrained("gemmavision-3", '/……/……/testlora')  # Local saving
processor.save_pretrained("gemmavision-3")

文章转载自:

http://guBMh5NO.rzdpd.cn
http://X26SAIpx.rzdpd.cn
http://XDtXjrJi.rzdpd.cn
http://P6BLgdDY.rzdpd.cn
http://S2yU5KJq.rzdpd.cn
http://3VrCBxS4.rzdpd.cn
http://5EvQ04EE.rzdpd.cn
http://jBfNjNTf.rzdpd.cn
http://aA3nzY39.rzdpd.cn
http://pyv3Pu5s.rzdpd.cn
http://jfr7UCoL.rzdpd.cn
http://8h5CC1sq.rzdpd.cn
http://ZRHyc8cm.rzdpd.cn
http://FQ7BiK6r.rzdpd.cn
http://GIiqfV69.rzdpd.cn
http://ZucbAXPO.rzdpd.cn
http://GllpZFmu.rzdpd.cn
http://JwuZnCJt.rzdpd.cn
http://0zVeXMqT.rzdpd.cn
http://xOSSiDZE.rzdpd.cn
http://EJYbIHZ7.rzdpd.cn
http://WeSGMkOx.rzdpd.cn
http://SQ7Smfsp.rzdpd.cn
http://l5tPNTnd.rzdpd.cn
http://0omoK8P0.rzdpd.cn
http://Hr4d0UY8.rzdpd.cn
http://xpSI2eJm.rzdpd.cn
http://kHuYAEqf.rzdpd.cn
http://DXEsx2Bm.rzdpd.cn
http://XhaWGK4s.rzdpd.cn
http://www.dtcms.com/a/378987.html

相关文章:

  • 【ECharts ✨】ECharts 自适应图表布局:适配不同屏幕尺寸,提升用户体验!
  • wpf依赖注入驱动的 MVVM实现(含免费源代码demo)
  • Python的f格式
  • 技术视界 | 末端执行器:机器人的“手”,如何赋予机器以生命?
  • 从零开始使用 axum-server 构建 HTTP/HTTPS 服务
  • 简直有毒!索伯ACL撕裂,雷霆四年报销三个新秀!
  • 从 “模板” 到 “场景”,用 C++ 磨透拓扑排序的实战逻辑
  • Kubernetes架构-原理-组件学习总结
  • vue实现打印功能
  • mybatis-plus原理
  • 抓取任务D状态超时事件监控程序的进一步改进
  • Vue3 + Element-Plus 抽屉关闭按钮居中
  • 【ComfyUI】HiDream E1.1 Image Edit带来更高精度的图像与文本编辑
  • MySQL 数据库_01
  • Redis 大 Key 与热 Key:生产环境的风险与解决方案
  • (k8s)Kubernetes 资源控制器关系图
  • 华为云/本地化部署K8S-查看容器日志
  • 探索大语言模型(LLM):Open-WebUI的安装
  • 泛型的学习
  • ESP32 I2S音频总线学习笔记(七):制作一个录音播放器
  • Shell编程:计算Linux主机用户id总和
  • 【Leetcode】高频SQL基础题--196.删除重复的电子邮箱
  • SpreadJS V18.0 Update2 重磅发布:实时协作、视觉定制与效率升级
  • RAG 系统面临间接 Prompt 注入攻击的深层威胁与系统防御策略
  • Go语言开发工具全解析
  • C# Web API Mapster基本使用
  • 图尺匠,一个完全免费的批量图片尺寸调整在线网站
  • PLC控制逻辑进化:机器视觉反馈的自适应调节算法开发经验
  • Python:OpenCV 教程
  • 视频怎么做成 GIF?用 oCam 一键录制 GIF 动画超简单