当前位置: 首页 > news >正文

Swift实战(微调多模态模型Qwen2.5 vl 7B)

在这里插入图片描述
本教程利用Swift框架微调Qwen2.5 vl 7B模型,是用的数据集是OCR识别数据集,一共10万张图片。

一. 安装环境

尤其注意cuda版本,否则有些包安装不了
在这里插入图片描述

conda create -n swift3 python==3.10
# flash-attn对cuda版本有要求
pip install flash-attn
pip install auto_gptq optimum bitsandbytes timm
git clone https://github.com/modelscope/ms-swift.git
cd ms-swift
pip install -e .

# 如果有需要,安装vllm ,对cuda版本有要求
pip install vllm

## 如果是qwen2.5-vl
pip install git+https://github.com/huggingface/transformers.git@9985d06add07a4cc691dc54a7e34f54205c04d40
pip install qwen_vl_utils

二. 数据准备

处理数据:

import os
import json


# 写入jsonl文件
def write_jsonl(data_list, filename):
    with open(filename, 'w', encoding='utf-8') as f:
        for item in data_list:
            # 将Python对象转换为JSON格式的字符串
            json_str = json.dumps(item, ensure_ascii=False)  
            f.write(json_str + '\n')


if __name__ == "__main__":
    img_dir = "/home/xxx/xxx/dataset/ocr_reg_small_dataset/data/TrainImages"
    with open("LabelTrain.txt", "r") as f:
        data_list = []
        lines = f.readlines()
        for line in lines[:90000]:
            img_name, text = line.rstrip().split("\t")
            img_path = os.path.join(img_dir, img_name)
            data = {}
            data["query"] = "请识别图片中的文字"
            data["response"] = text
            data["image_path"] = img_path
            data_list.append(data)
        write_jsonl(data_list, "train.jsonl")

        data_list = []
        for line in lines[90000:]:
            img_name, text = line.rstrip().split("\t")
            img_path = os.path.join(img_dir, img_name)
            data = {}
            data["query"] = "请识别图片中的文字"
            data["response"] = text
            data["image_path"] = img_path
            data_list.append(data)
        write_jsonl(data_list, "val.jsonl")
        print("done")

处理后的数据如下,示例:

{"query": "请识别图片中的文字", "response": "在2日内到有效", "image_path": "/home/xxx/xxx/dataset/ocr_reg_small_dataset/data/TrainImages/Train_090008.jpg"}
{"query": "请识别图片中的文字", "response": "车服务公司", "image_path": "/home/xxx/xxx/dataset/ocr_reg_small_dataset/data/TrainImages/Train_090009.jpg"}
{"query": "请识别图片中的文字", "response": "宗派排次", "image_path": "/home/xxx/xxx/dataset/ocr_reg_small_dataset/data/TrainImages/Train_090010.jpg"}
{"query": "请识别图片中的文字", "response": "增加金属蛋白酶,有助于异位组织的侵蚀", "image_path": "/home/xxx/xxx/dataset/ocr_reg_small_dataset/data/TrainImages/Train_090011.jpg"}
{"query": "请识别图片中的文字", "response": "学历要求", "image_path": "/home/xxx/xxx/dataset/ocr_reg_small_dataset/data/TrainImages/Train_090012.jpg"}
{"query": "请识别图片中的文字", "response": "防御", "image_path": "/home/xxx/xxx/dataset/ocr_reg_small_dataset/data/TrainImages/Train_090013.jpg"}
{"query": "请识别图片中的文字", "response": "等:¥476.0", "image_path": "/home/xxx/xxx/dataset/ocr_reg_small_dataset/data/TrainImages/Train_090014.jpg"}
{"query": "请识别图片中的文字", "response": "余443张", "image_path": "/home/xxx/xxx/dataset/ocr_reg_small_dataset/data/TrainImages/Train_090015.jpg"}
{"query": "请识别图片中的文字", "response": "中国", "image_path": "/home/xxx/xxx/dataset/ocr_reg_small_dataset/data/TrainImages/Train_090016.jpg"}
{"query": "请识别图片中的文字", "response": "般10%以下", "image_path": "/home/xxx/xxx/dataset/ocr_reg_small_dataset/data/TrainImages/Train_090017.jpg"}

三. 微调模型

MAX_PIXELS=1003520 \
CUDA_VISIBLE_DEVICES=0 \
swift sft \
    --model Qwen/Qwen2.5-VL-7B-Instruct \
    --dataset /home/xxx/xxx/dataset/ocr_reg_small_dataset/data/train.jsonl \
    --train_type lora \
    --torch_dtype bfloat16 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --learning_rate 1e-4 \
    --lora_rank 8 \
    --lora_alpha 32 \
    --target_modules all-linear \
    --freeze_vit true \
    --gradient_accumulation_steps 16 \
    --eval_steps 50 \
    --save_steps 50 \
    --save_total_limit 5 \
    --logging_steps 5 \
    --max_length 2048 \
    --output_dir output_ocr \
    --warmup_ratio 0.05 \
    --dataloader_num_workers 4

模型训练中,正常收敛,如下图所示,内存占用18G左右:

{'loss': 4.36318054, 'token_acc': 0.32325581, 'grad_norm': 5.01398468, 'learning_rate': 9.725e-05, 'memory(GiB)': 16.51, 'train_speed(iter/s)': 0.178871, 'epoch': 0.15, 'global_step/max_steps': '840/5568', 'percentage': '15.09%', 'elapsed_time': '1h 18m 15s', 'remaining_time': '7h 20m 31s'}
{'loss': 4.03473396, 'token_acc': 0.34009009, 'grad_norm': 4.07742596, 'learning_rate': 9.72e-05, 'memory(GiB)': 16.51, 'train_speed(iter/s)': 0.179055, 'epoch': 0.15, 'global_step/max_steps': '845/5568', 'percentage': '15.18%', 'elapsed_time': '1h 18m 38s', 'remaining_time': '7h 19m 35s'}
{'loss': 4.13988152, 'token_acc': 0.3490566, 'grad_norm': 3.48686051, 'learning_rate': 9.715e-05, 'memory(GiB)': 16.51, 'train_speed(iter/s)': 0.179242, 'epoch': 0.15, 'global_step/max_steps': '850/5568', 'percentage': '15.27%', 'elapsed_time': '1h 19m 1s', 'remaining_time': '7h 18m 40s'}
Train:  15%|███████████████████████▏                                                                                                                                | 850/5568 [1:19:01<5:59:45,  4.58s/it]

将训练好的模型与loar融合

# checkpoint-5568-merged  融合会生成这样一个文件夹,和Qwen2.5-vl-7b的使用方式完全相同 
# 这里`--adapters`需要替换生成训练生成的最后一个检查点文件夹。 由于adapters文件夹中包含了训练的参数文件因此,不需要额外指定`--model`:
CUDA_VISIBLE_DEVICES=0 swift export \
      --adapters  ./output_ocr/vx-xxx/checkpoint-5568 \
      --merge_lora true             

四. 模型测试

# pt推理
NPROC_PER_NODE=1 MAX_PIXELS=1003520 swift infer \
    --ckpt_dir ./output_ocr/vx-xxx/checkpoint-5568-merged \
    --max_new_tokens 300 \
    --temperature 0 \
    --val_dataset val_dataset.jsonl \
    --result_path output_5568.jsonl \
    --max_batch_size 1 \
    --infer_backend pt

参考:
ms-swift
多模态模型实践——swift3框架使用
Qwen2.5 VL! 重要的模型说三遍!

相关文章:

  • 基于香橙派 KunpengPro学习CANN(3)——pytorch 模型迁移
  • JavaScript基础-获取元素
  • Shell脚本中的弱治简写
  • 平衡树的模拟实现
  • Golang开发
  • ROS合集(一)ROS常见命令及其用途
  • springboot多种生产打包方式教程
  • 循环神经网络中用到的概率论知识
  • YOLOv8 OBB 旋转目标检测模型详解与实践
  • 59. 螺旋矩阵 II
  • 深度洞察:特种设备作业考试的核心要点与备考策略
  • 蓝桥杯 修剪灌木
  • opencv初步学习——图像处理3
  • LeetCode BFS层序遍历树
  • 工作记录 2017-02-04
  • 【css酷炫效果】纯CSS实现照片堆叠效果
  • 2025年通信安全员考试题库及答案
  • xxl-job 执行器端服务器的简单搭建
  • OneCyber 平台
  • 杨校老师课堂之编程入门与软件安装【图文笔记】
  • 上海高院与上海妇联签协议,建立反家暴常态化联动协作机制
  • 证监会强化上市公司募资监管七要点:超募资金不得补流、还贷
  • 李强:把做强国内大循环作为推动经济行稳致远的战略之举
  • 魔都眼|锦江乐园摩天轮“换代”开拆,新摩天轮暂定118米
  • 刘晓庆被实名举报涉嫌偷税漏税,税务部门启动调查
  • 习近平复信中国丹麦商会负责人