当前位置: 首页 > news >正文

通义千问模型微调——swift框架

1.创建环境

服务器CUDA Version: 12.2

conda create -n lora_qwen python=3.10 -y 
conda activate lora_qwen 
conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=12.1 -c pytorch -c nvidia -y

1.1环境搭建

本文使用swift进行微调,所以先下载swift,以及一些必要的packages

git clone https://github.com/modelscope/ms-swift.git
pip install transformers==4.49.0 
pip install pyav qwen_vl_utils 
pip install numpy==1.22.4 
pip install modelscope

1.2模型下载

使用modelscope下载指定模型,其中:

--model表示模型名称,可在modelscope官网找到

--local_dir代表模型下载地址

运行下面的命令,模型会下载到:./Qwen/Qwen2.5-VL-7B-Instruct目录下

modelscope download --model Qwen/Qwen2.5-VL-7B-Instruct --local_dir ./

下面脚本用于和模型进行对话,可以简单测试一下模型是否能够使用

CUDA_VISIBLE_DEVICES=1 swift infer --model_type qwen2_5_vl --ckpt_dir ./Qwen/Qwen2.5-VL-7B-Instruct

1.3数据集准备

下方是数据集格式,保存类型为.jsonl

[
    {
        "query": "OCR一下<image>",
        "response": "朵拉童衣",
        "images": [
            "datasets/lora_qwen/train/billboard_00001_010_朵拉童衣.jpg"
        ]
    },
    {
        "query": "OCR一下<image>",
        "response": "童衣雜貨舖",
        "images": [
            "datasets/lora_qwen/train/billboard_00002_010_童衣雜貨舖.jpg"
        ]
    },...
]

2.微调

2.1采用LoRA进行微调

对文件夹中之前下载的ms-swift-main/examples/train/multimodal/ocr.sh进行修改

# 20GB
CUDA_VISIBLE_DEVICES=0,1 \
MAX_PIXELS=1003520 \
swift sft \
    --model ./Qwen/Qwen2.5-VL-7B-Instruct \
    --model_type qwen2_5_vl \
    --dataset ./datatsets/train.jsonl \
    --val_dataset ./datatsets/val.jsonl \
    --train_type lora \
    --torch_dtype bfloat16 \
    --num_train_epochs 100 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --learning_rate 1e-4 \
    --lora_rank 64 \
    --lora_alpha 64 \
    --target_modules all-linear \
    --freeze_vit true \
    --gradient_accumulation_steps 16 \
    --eval_steps 50 \
    --save_steps 50 \
    --save_total_limit 10 \
    --logging_steps 5 \
    --max_length 2048 \
    --output_dir output \
    --warmup_ratio 0.05 \
    --dataloader_num_workers 4
  • 常用参数解释:

--model:原模型的权重地址

--dataset:训练集的数据地址

--val_dataset:验证集的数据地址

--train_type:全参数训练(full) 或 LoRA微调训练(lora)

--num_train_epochs:总共要训练的轮数

--per_device_train_batch_size:训练阶段batchsize大小,根据显存大小来设置

--per_device_eval_batch_size:验证阶段batchsize大小,根据显存大小来设置

--learning_rate:学习率,一般设为0.0001或0.00001

--target_modules:需要做微调的目标模块,all-linear表示所有的线形层,也就是Attention和FeedForward层

--freeze_vit:一般设为true,不微调视觉编码器,只微调LLM部分

2.2使用Transformer进行推理

import os
import re
import torch
from PIL import Image

from datasets import Dataset
from modelscope import AutoTokenizer
from peft import LoraConfig, TaskType, get_peft_model, PeftModel
from transformers import (
    AutoProcessor,
    Qwen2_5_VLForConditionalGeneration,
    Trainer, TrainingArguments,
    Seq2SeqTrainer, Seq2SeqTrainingArguments, 
    DataCollatorForSeq2Seq,
)
from qwen_vl_utils import process_vision_info

rewrite_print = print
def print(save_txt, *arg, **kwargs):
    rewrite_print(*arg, **kwargs)
    rewrite_print(*arg, **kwargs, file=open(save_txt, "a+", encoding="utf-8"))

def process_func(model, img_path, input_content):
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": img_path},
                {"type": "text", "text": input_content},
            ],
        }
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)

    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )

    generated_ids = model.generate(**inputs, max_new_tokens=512)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    print(save_txt_path, img_path)
    print(save_txt_path, output_text[0])
    print(save_txt_path, '\n')

def get_lora_model(model_path, lora_model_path):
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, device_map="auto", torch_dtype=torch.bfloat16)
    model.enable_input_require_grads()

    config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        target_modules="model\..*layers\.\d+\.(self_attn\.(q_proj|k_proj|v_proj|o_proj)|mlp\.(gate_proj|up_proj|down_proj))",
        inference_mode=True,
        r=64,
        lora_alpha=64,
        lora_dropout=0.05,
        bias="none",
    )

    peft_model = PeftModel.from_pretrained(model, model_id=lora_model_path, config=config)
    return peft_model

if __name__ == '__main__':
    save_txt_path = 'log.txt'

    model_path = "./Qwen2.5-VL-7B-Instruct"
    lora_model_path = "./output/v2-20250228-202446/checkpoint-900"
    lora_model = get_lora_model(model_path, lora_model_path)
    processor = AutoProcessor.from_pretrained(model_path)
    
    img_path = "图片路径"
    prompt = "OCR一下"
    process_func(lora_model, img_path, prompt)

3.实验参数情况

模型微调显存:30G左右(主要看数据集,图片越大,prompt,answer越多,占用显存越多);

模型微调后推理:20G左右;

相关文章:

  • Python第六章02:列表操作——下标索引
  • JVM-JAVA编译到执行全过程
  • SQL Server性能分析利器:SET STATISTICS TIME ON 详解与实战案例
  • Unity导出WebGL,无法显示中文
  • 在 Vue.js 中,使用 proxy.$refs.waybillNumberRef.focus() 获取焦点不生效
  • 实验5:Vuex状态管理
  • 学习C2CRS Ⅴ (Conversational Recommender System)
  • 30天学习Java第六天——super关键字
  • MySQL实现全量同步和增量同步到SQL Server或其他关系型库
  • vue3计算当前日期往前推一个月的日期,当前日期往前推7天
  • JVAV面试-静态代理动态代理
  • 大模型知识蒸馏:技术演进与未来展望
  • 借助vite来优化前端性能
  • 2025年Postman的五大替代工具
  • Linux生成自签名证书
  • ThreadLocal底层原理,内存泄露问题,以及如何在项目中使用这个关键字(总结)
  • 互功率谱 cpsd
  • HTTP 失败重试(重发)方案
  • 【小白向】Word|Word怎么给公式标号、调整公式字体和花括号对齐
  • 使用 OpenAI 的 Node.js 通过 Ollama 在本地运行 DeepSeek R1
  • 彩票网站开发系统/域名注册 万网
  • 企业网站 联系我们/百度seo网站优化
  • 濮阳建站公司流程/数据指数
  • 网站建设需要哪些工作室/重庆百度推广
  • 护肤网站的功能设计/长沙网站定制
  • 微网站建设讯息/关键词优化价格表