A Brief Walkthrough of Unsloth's Gemma 3 Vision-Language Fine-Tuning Code
This post walks through Unsloth's fine-tuning example for gemma-3-4b (vision + text).
First, load the locally downloaded model. bnb 4-bit quantization is used, which keeps loading simple and memory-friendly.
# User code
model, processor = FastVisionModel.from_pretrained(
    model_name = "/data/……/……/unsloth/gemma-3-4b-it-bnb-4bit",
    load_in_4bit = True,  # 4 bit quantization to reduce memory
)
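For reference, load_in_4bit = True corresponds roughly to passing a BitsAndBytesConfig in plain transformers. The sketch below is a hedged equivalent; the quantization settings (nf4, bf16 compute, double quantization) are typical values assumed here, not read from unsloth's source:

import torch
from transformers import AutoModelForVision2Seq, BitsAndBytesConfig

# Rough transformers equivalent of load_in_4bit = True (illustrative settings)
bnb_config = BitsAndBytesConfig(
    load_in_4bit = True,
    bnb_4bit_quant_type = "nf4",              # NormalFloat4 weights
    bnb_4bit_compute_dtype = torch.bfloat16,  # matmuls run in bf16
    bnb_4bit_use_double_quant = True,         # also quantize the quant constants
)
model = AutoModelForVision2Seq.from_pretrained(
    "/data/……/……/unsloth/gemma-3-4b-it-bnb-4bit",
    quantization_config = bnb_config,
    device_map = "auto",
)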
The model-loading logic inside unsloth's FastVisionModel.from_pretrained:
# unsloth FastVisionModel.from_pretrained: check whether the model has a vision module
model_config = AutoConfig.from_pretrained(
    model_name,
    token = token,
    trust_remote_code = trust_remote_code,
)
……
# Check if VLM
is_vlm = any(x.endswith("ForConditionalGeneration") for x in model_config.architectures)
is_vlm = is_vlm or hasattr(model_config, "vision_config")
if auto_model is None:
    auto_model = AutoModelForVision2Seq if is_vlm else AutoModelForCausalLM

model, tokenizer = FastBaseModel.from_pretrained(
    model_name                 = model_name,
    max_seq_length             = max_seq_length,
    dtype                      = _get_dtype(dtype),
    load_in_4bit               = load_in_4bit,
    load_in_8bit               = load_in_8bit,
    full_finetuning            = full_finetuning,
    token                      = token,
    device_map                 = device_map,
    trust_remote_code          = trust_remote_code,
    revision                   = revision if not is_peft else None,
    model_types                = model_types,
    tokenizer_name             = tokenizer_name,
    auto_model                 = auto_model,
    use_gradient_checkpointing = use_gradient_checkpointing,
    supports_sdpa              = supports_sdpa,
    whisper_language           = whisper_language,
    whisper_task               = whisper_task,
    *args, **kwargs,
)
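To see the branch concretely, here is a self-contained toy run of the same check. The config object is faked with SimpleNamespace (hypothetical values); a real AutoConfig for gemma-3 behaves the same way:

from types import SimpleNamespace

# Fake config mimicking gemma-3's multimodal AutoConfig (hypothetical values)
model_config = SimpleNamespace(
    architectures = ["Gemma3ForConditionalGeneration"],
    vision_config = SimpleNamespace(),  # present on VLM configs
)
is_vlm = any(x.endswith("ForConditionalGeneration") for x in model_config.architectures)
is_vlm = is_vlm or hasattr(model_config, "vision_config")
print(is_vlm)  # True -> auto_model becomes AutoModelForVision2Seq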
The inner FastBaseModel.from_pretrained then loads the model and decides whether a vision-model processor is needed:
# unsloth FastBaseModel.from_pretrained
# Decide whether this is a VLM and pick the matching processor/tokenizer class
is_vlm = (auto_model is AutoModelForVision2Seq)
is_whisper = (whisper_language is not None and whisper_task is not None)
auto_processor = AutoProcessor if (is_vlm or is_whisper) else AutoTokenizer
if (whisper_language and whisper_task) or auto_model.__name__.endswith("ForConditionalGeneration"):
    tokenizer = auto_processor.from_pretrained(
        tokenizer_name,
        padding_side = "right",
        token = token,
        language = whisper_language,
        task = whisper_task,
    )
else:
    tokenizer = auto_processor.from_pretrained(
        tokenizer_name,
        padding_side = "right",
        token = token,
    )
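For gemma-3 the first branch is taken, so the returned "tokenizer" is actually an AutoProcessor bundling the text tokenizer and the image processor. A hedged standalone equivalent (the hub id is an assumption; a local path works the same):

from transformers import AutoProcessor

# For a VLM, AutoProcessor wraps both the tokenizer and the image processor
processor = AutoProcessor.from_pretrained(
    "unsloth/gemma-3-4b-it-bnb-4bit",  # assumed model id
    padding_side = "right",
)
print(type(processor).__name__)            # e.g. "Gemma3Processor"
print(type(processor.tokenizer).__name__)  # the underlying text tokenizer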
Next, the user configures the model's LoRA parameters.
# User code
model = FastVisionModel.get_peft_model(
    model,
    finetune_vision_layers     = True,  # Turn off for just text!
    finetune_language_layers   = True,  # Should leave on!
    finetune_attention_modules = True,  # Attention good for GRPO
    finetune_mlp_modules       = True,  # Should leave on always!
    r = 16,           # Larger = higher accuracy, but might overfit
    lora_alpha = 16,  # Recommended alpha == r at least
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,
    use_rslora = False,   # We support rank stabilized LoRA
    loftq_config = None,  # And LoftQ
    target_modules = "all-linear",  # Optional now! Can specify a list if needed
    modules_to_save = [
        "lm_head",
        "embed_tokens",
    ],
)
Internal logic for building the model's LoRA network configuration:
# unsloth: helper called inside FastVisionModel.get_peft_model,
# which collects the module names enabled for training into a regex
def get_peft_regex(
    model,
    finetune_vision_layers     : bool = True,
    finetune_language_layers   : bool = True,
    finetune_attention_modules : bool = True,
    finetune_mlp_modules       : bool = True,
    target_modules : List[str] = None,
    vision_tags    : List[str] = ["vision", "image", "visual", "patch",],
    language_tags  : List[str] = ["language", "text",],
    attention_tags : List[str] = ["self_attn", "attention", "attn",],
    mlp_tags       : List[str] = ["mlp", "feed_forward", "ffn", "dense",],
) -> str:
    ……
    # Collect the tag names of the module groups enabled for training
    regex_model_parts = []
    if finetune_vision_layers:     regex_model_parts += vision_tags
    if finetune_language_layers:   regex_model_parts += language_tags
    regex_components = []
    if finetune_attention_modules: regex_components += attention_tags
    if finetune_mlp_modules:       regex_components += mlp_tags
    regex_model_parts = "|".join(regex_model_parts)
    regex_components  = "|".join(regex_components)
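The elided part assembles these tag alternations into one pattern over parameter names. The snippet below uses a simplified, illustrative pattern (not unsloth's real one) just to show how tag matching picks out module names:

import re

# Simplified stand-in for the regex get_peft_regex would build
regex_model_parts = "|".join(["vision", "image", "visual", "patch", "language", "text"])
regex_components  = "|".join(["self_attn", "attention", "attn", "mlp", "feed_forward", "ffn", "dense"])
pattern = rf".*(?:{regex_model_parts}).*(?:{regex_components}).*proj"

for name in [
    "model.vision_tower.encoder.layers.0.self_attn.q_proj",
    "model.language_model.layers.0.mlp.gate_proj",
    "model.language_model.layers.0.input_layernorm",
]:
    print(name, "->", bool(re.fullmatch(pattern, name)))
# The two projection layers match (LoRA targets); the layernorm does not.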
The selected module names then go into the LoRA config:
# unsloth internal
lora_config_dict = {
    "r"                 : r,
    "lora_alpha"        : lora_alpha,
    "target_modules"    : target_modules,  # return value of get_peft_regex
    "target_parameters" : kwargs.get("target_parameters", None),
    "lora_dropout"      : lora_dropout,
    "bias"              : bias,
    "task_type"         : task_type,
    "use_rslora"        : use_rslora,
    "init_lora_weights" : init_lora_weights,
    "loftq_config"      : loftq_config,
}
lora_config = LoraConfig(
    **{k:v for k,v in lora_config_dict.items() if k in LoraConfig.__doc__},
)
model = prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing = use_gradient_checkpointing,
)
model = _get_peft_model(model, lora_config)
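Note the dict comprehension: a key is kept only if its name appears in LoraConfig.__doc__, a cheap way to drop kwargs the installed peft version does not document. A small standalone sketch of the same trick (made_up_option is a deliberately fake key):

from peft import LoraConfig

lora_config_dict = {
    "r"              : 16,
    "lora_alpha"     : 16,
    "lora_dropout"   : 0.0,
    "made_up_option" : 123,  # not in LoraConfig's docstring, so filtered out
}
lora_config = LoraConfig(
    **{k: v for k, v in lora_config_dict.items() if k in LoraConfig.__doc__},
)
print(lora_config.r, lora_config.lora_alpha)  # 16 16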
Next, configure the trainable module layers: gradient updates are switched on for the selected layers and off for everything else.
# unsloth: helper called inside prepare_model_for_kbit_training
def prepare_model_for_training(
    model                      : Any,
    use_gradient_checkpointing : Optional       = "unsloth",
    use_reentrant              : Optional[bool] = True,
    full_finetuning            : Optional[bool] = False,
    train_layernorms           : Optional[bool] = False,
    train_embedding            : Optional[bool] = False,
    train_lm_head              : Optional[bool] = False,
    float32_mixed_precision    : Optional[bool] = True,
) -> Any:
    ……
    for name, param in model.named_parameters():
        upcast = False
        requires_grad = False
        if not full_finetuning:
            if ".lora_A." in name or ".lora_B." in name or ".lora_magnitude_vector" in name:
                upcast = True
                requires_grad = True
            else:
                requires_grad = False
        else:
            if train_layernorms and ("norm." in name or "_layernorm" in name):
                requires_grad = True
                upcast = True   # Must upcast layernorms to float32
            if train_embedding and ("embed_tokens" in name or "embedding" in name):
                requires_grad = True
                upcast = False  # Can leave in bfloat16
            if train_lm_head and ("lm_head" in name):
                requires_grad = True
                upcast = False  # Can leave in bfloat16
            else:
                requires_grad = True
                upcast = False  # Can leave in bfloat16
        pass

        # Set training or not
        if requires_grad:
            param.requires_grad_(True)
        else:
            param.requires_grad_(False)

        # Upcast to float32 if needed
        if requires_grad:
            name = name.replace("base_model", "model", 1)
            while re.search(r'\.(\d+)\.', name) is not None:
                name = re.sub(r'\.(\d+)\.', r'[\1].', name)
            name = name.replace(".weight", "", 1)
            dtype = torch.float32 if upcast else mixed_precision_dtype
            try:
                # Try original name
                exec(f"{name}.to({str(dtype)})")
            except:
                # Maybe model.model
                exec(f"model.{name}.to({str(dtype)})")
            pass

        if ('norm.' in name or '_layernorm' in name) and os.environ.get("UNSLOTH_UPCAST_LAYERNORM", "0") == "1":
            try:
                name = name.replace("base_model", "model", 1)
                while re.search(r'\.(\d+)\.', name) is not None:
                    name = re.sub(r'\.(\d+)\.', r'[\1].', name)
                name = name.replace(".weight", "", 1)
                # Try original name
                exec(f"{name}.to({str(torch.float32)})")
            except:
                # Maybe model.model
                exec(f"model.{name}.to({str(torch.float32)})")
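A quick way to verify the effect on the PEFT-wrapped model from get_peft_model above is to count which parameters ended up trainable:

# Sanity check: with LoRA, only lora_A/lora_B (plus modules_to_save) should train
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total     = sum(p.numel() for p in model.parameters())
print(f"trainable params: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")

for name, param in model.named_parameters():
    if param.requires_grad and ".lora_A." in name:
        print(name, param.dtype)  # LoRA matrices, upcast to float32 above
        break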
Next, load the image-text training dataset from Hugging Face.
# User code
def formatting_prompts_func(examples):
    convos = examples["conversations"]
    texts = [
        processor.apply_chat_template(
            convo, tokenize = False, add_generation_prompt = False
        ).removeprefix('<bos>')
        for convo in convos
    ]
    return { "text" : texts, }

def convert_to_conversation(sample):
    instruction = "Write the LaTeX representation for this image."
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "text",  "text": instruction},
                {"type": "image", "image": sample["image"]},
            ],
        },
        {"role": "assistant", "content": [{"type": "text", "text": sample["text"]}]},
    ]
    return {"messages": conversation}
……
……
dataset = load_dataset("unsloth/LaTeX_OCR", split = "train")
converted_dataset = [convert_to_conversation(sample) for sample in dataset]
processor = get_chat_template(processor, "gemma-3")
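To sanity-check the conversion, inspect one converted sample (the "image" and "text" fields come from the unsloth/LaTeX_OCR dataset, as used in convert_to_conversation above):

example = converted_dataset[0]
print(example["messages"][0]["role"])                     # "user"
print(example["messages"][0]["content"][0]["text"])       # the OCR instruction
print(example["messages"][1]["content"][0]["text"][:80])  # start of the LaTeX target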
get_chat_template looks up the conversation template registered under the "gemma-3" key:
# unsloth internal
gemma3_ollama = \
'''
FROM {__FILE_LOCATION__}
TEMPLATE """{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 }}
{{- if or (eq .Role "user") (eq .Role "system") }}<start_of_turn>user
{{ .Content }}<end_of_turn>
{{ if $last }}<start_of_turn>model
{{ end }}
{{- else if eq .Role "assistant" }}<start_of_turn>model
{{ .Content }}{{ if not $last }}<end_of_turn>
{{ end }}
{{- end }}
{{- end }}"""
PARAMETER stop "<end_of_turn>"
PARAMETER stop "<eos>"
PARAMETER temperature 0.1
PARAMETER min_p 0.0
PARAMETER top_k 64
PARAMETER top_p 0.95
PARAMETER num_predict 32768
'''
gemma3_template_eos_token = "<end_of_turn>"
CHAT_TEMPLATES["gemma-3"] = (gemma3_template, gemma3_template_eos_token, False, gemma3_ollama,)

# Inside get_chat_template
def get_chat_template(……):
    ……
    chat_template, stop_word, yes_map_eos_token, ollama_modelfile = CHAT_TEMPLATES[chat_template]
    ……  # plus some pad-token handling; the main job here is fetching the template
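Once the template is attached, the processor renders conversations with gemma-3's turn markers. A hedged usage sketch on the first converted sample (the rendered string starts with <bos><start_of_turn>user, which is why formatting_prompts_func above strips the <bos> prefix):

text = processor.apply_chat_template(
    converted_dataset[0]["messages"],
    tokenize = False,
    add_generation_prompt = False,
)
print(text[:120])  # "<bos><start_of_turn>user\n..." per the template above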
Now enable training mode and set up the training configuration.
# User code
FastVisionModel.for_training(model)  # Enable for training!

trainer = SFTTrainer(
    model = model,
    train_dataset = converted_dataset,
    processing_class = processor.tokenizer,
    data_collator = UnslothVisionDataCollator(model, processor),
    args = SFTConfig(
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 4,
        gradient_checkpointing = True,
        # use reentrant checkpointing
        gradient_checkpointing_kwargs = {"use_reentrant": False},
        max_grad_norm = 0.3,  # max gradient norm based on QLoRA paper
        warmup_ratio = 0.03,
        max_steps = 3,
        #num_train_epochs = 2,  # Set this instead of max_steps for full training runs
        learning_rate = 2e-4,
        logging_steps = 1,
        save_strategy = "steps",
        optim = "adamw_torch_fused",
        weight_decay = 0.01,
        lr_scheduler_type = "cosine",
        seed = 3407,
        output_dir = "outputs",
        report_to = "none",  # For Weights and Biases
        # You MUST put the below items for vision finetuning:
        remove_unused_columns = False,
        dataset_text_field = "",
        dataset_kwargs = {"skip_prepare_dataset": True},
        max_length = 2048,
    ),
)
trainer_stats = trainer.train()
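trainer.train() returns a TrainOutput, which is handy for a quick look at the run:

print(trainer_stats.global_step)    # 3, matching max_steps
print(trainer_stats.training_loss)  # mean training loss over those steps
print(trainer_stats.metrics.get("train_runtime"))  # wall-clock seconds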
Finally, save the trained LoRA model.
# User code
model.save_pretrained("gemmavision-3")      # Local saving (LoRA adapters only)
processor.save_pretrained("gemmavision-3")
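The adapters can later be loaded back the same way they were created. A hedged sketch (unsloth resolves the adapter directory and re-attaches it to the base model):

# Reload the saved LoRA adapters for inference
model, processor = FastVisionModel.from_pretrained(
    model_name = "gemmavision-3",  # directory written by save_pretrained above
    load_in_4bit = True,
)
FastVisionModel.for_inference(model)  # switch back from training to inference mode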