
Train a VLM with nanoVLM

Environment setup

Clone the GitHub repository

git clone https://github.com/huggingface/nanoVLM.git

Using uv

# Install the uv package manager
curl -LsSf https://astral.sh/uv/install.sh | sh

# Create the environment
uv init --bare --python 3.12
uv sync --python 3.12
source .venv/bin/activate
uv add torch numpy torchvision pillow datasets huggingface-hub transformers wandb
# Optional: for lmms-eval integration you have to install it from source, see section 'Evaluation with lmms-eval'

Or with conda

conda create --name nanovlm python=3.12
conda activate nanovlm
pip install torch numpy torchvision pillow datasets huggingface-hub transformers wandb
pip install matplotlib ipywidgets -i https://mirrors.aliyun.com/pypi/simple/

# Evaluation with lmms-eval
pip install git+https://github.com/EvolvingLMMs-Lab/lmms-eval.git

Coding

Set a Hugging Face download mirror

# In the shell (PYTHONPATH=./ makes the repo's modules importable):
# export PYTHONPATH=./
export HF_ENDPOINT=https://hf-mirror.com

Or from Python, before any Hugging Face library is imported:

import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

Log in to Hugging Face

from huggingface_hub import notebook_login
notebook_login()
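notebook_login() is aimed at Jupyter/Colab; when running as a plain script, huggingface_hub's login() works the same way (it prompts for a token, or you can pass one explicitly):

# Script-friendly alternative to notebook_login()
from huggingface_hub import login
login()  # prompts for a token; login(token="hf_...") also works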

Set the model name

hf_model_name = "shizidushu/nanoVLM"
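This is the Hub repository that model.push_to_hub() targets at the end of the training loop below; replace the user name with your own.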

Import the required libraries

# nanoVLM Imports (please check out the implementations in detail, that's where all the interesting stuff is!)
from data.collators import VQACollator, MMStarCollator
from data.datasets import MMStarDataset, VQADataset
from data.processors import get_image_processor, get_tokenizer
from models.vision_language_model import VisionLanguageModel
import models.utils as utils

# Libraries
import math
import time
import torch
from tqdm import tqdm
import torch.optim as optim
import matplotlib.pyplot as plt
from dataclasses import dataclass, field
from torch.utils.data import DataLoader
from datasets import load_dataset, concatenate_datasets

# Otherwise, the tokenizer will throw a warning
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")

torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

Define a function that builds the dataloaders

def get_dataloaders(train_cfg, vlm_cfg):
    # Create datasets
    image_processor = get_image_processor(vlm_cfg.vit_img_size)
    tokenizer = get_tokenizer(vlm_cfg.lm_tokenizer, vlm_cfg.vlm_extra_tokens, vlm_cfg.lm_chat_template)

    # Load and combine all training datasets
    combined_train_data = []
    for dataset_name in train_cfg.train_dataset_name:
        train_ds = load_dataset(train_cfg.train_dataset_path, dataset_name)
        combined_train_data.append(train_ds['train'])
    train_ds = concatenate_datasets(combined_train_data)
    test_ds = load_dataset(train_cfg.test_dataset_path)

    train_ds = train_ds.shuffle(seed=0)  # Shuffle the training dataset, so train and val get equal contributions from all concatenated datasets

    # Apply cutoff if specified
    if train_cfg.data_cutoff_idx is None:
        total_samples = len(train_ds)  # Use the entire dataset
    else:
        total_samples = min(len(train_ds), train_cfg.data_cutoff_idx)

    val_size = int(total_samples * train_cfg.val_ratio)
    train_size = total_samples - val_size

    train_dataset = VQADataset(train_ds.select(range(train_size)), tokenizer, image_processor, vlm_cfg.mp_image_token_length)
    val_dataset = VQADataset(train_ds.select(range(train_size, total_samples)), tokenizer, image_processor, vlm_cfg.mp_image_token_length)
    test_dataset = MMStarDataset(test_ds['val'], tokenizer, image_processor, vlm_cfg.mp_image_token_length)

    # Create collators
    vqa_collator = VQACollator(tokenizer, vlm_cfg.lm_max_length)
    mmstar_collator = MMStarCollator(tokenizer)

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=train_cfg.batch_size,
        shuffle=True,
        collate_fn=vqa_collator,
        num_workers=2,
        pin_memory=True,
        drop_last=True,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=train_cfg.batch_size,
        shuffle=False,
        collate_fn=vqa_collator,
        num_workers=2,
        pin_memory=True,
        drop_last=True,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=train_cfg.mmstar_batch_size,
        shuffle=False,
        collate_fn=mmstar_collator,
        pin_memory=True,
    )

    return train_loader, val_loader, test_loader

Here:

  • data_cutoff_idx caps the total number of samples used (train + val); the worked example below shows the resulting split
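For example, with the default TrainConfig defined further below (data_cutoff_idx = 1024, val_ratio = 0.2), the split logic in get_dataloaders works out like this:

data_cutoff_idx = 1024
val_ratio = 0.2

total_samples = data_cutoff_idx            # assuming the dataset has more than 1024 rows
val_size = int(total_samples * val_ratio)  # 204
train_size = total_samples - val_size      # 820
print(train_size, val_size)                # 820 204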

Write the test function

def test_mmstar(model, tokenizer, test_loader, device):
    # Go through MMStar and count how many answers we get right
    model.eval()
    total_examples = 0
    correct_predictions = 0
    with torch.no_grad():
        for batch in test_loader:
            image = batch['images']
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            correct_answer = tokenizer.batch_decode(labels, skip_special_tokens=True)

            gen = model.generate(input_ids, image, attention_mask, greedy=True)
            model_output = tokenizer.batch_decode(gen, skip_special_tokens=True)

            is_correct = utils.check_multiple_choice_with_regex(model_output, correct_answer)

            total_examples += len(is_correct)
            if is_correct:
                correct_predictions += sum(is_correct)

    accuracy = correct_predictions / total_examples if total_examples > 0 else 0
    model.train()
    return accuracy
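utils.check_multiple_choice_with_regex returns one boolean per example. For intuition, a hypothetical standalone version of such a check (not the repo's actual implementation) could look like this:

import re

def check_multiple_choice(model_outputs, correct_answers):
    # Hypothetical sketch: count an output as correct if the expected
    # choice (e.g. "B") appears as a standalone word in the generation.
    results = []
    for output, answer in zip(model_outputs, correct_answers):
        pattern = re.compile(rf"\b{re.escape(answer.strip())}\b", re.IGNORECASE)
        results.append(bool(pattern.search(output)))
    return results

print(check_multiple_choice(["The answer is B."], ["B"]))  # [True]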

Prepare the training loop

def get_lr(it, max_lr, max_steps):
    min_lr = max_lr * 0.1
    warmup_steps = max_steps * 0.03

    # 1) linear warmup for warmup_iters steps
    if it < warmup_steps:
        return max_lr * (it + 1) / warmup_steps
    # 2) if it > lr_decay_iters, return min learning rate
    if it > max_steps:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff starts at 1 and goes to 0
    return min_lr + coeff * (max_lr - min_lr)


def train(train_cfg, vlm_cfg):
    train_loader, val_loader, test_loader = get_dataloaders(train_cfg, vlm_cfg)
    tokenizer = get_tokenizer(vlm_cfg.lm_tokenizer, vlm_cfg.vlm_extra_tokens, vlm_cfg.lm_chat_template)

    # Initialize model
    if train_cfg.resume_from_vlm_checkpoint:
        model = VisionLanguageModel.from_pretrained(vlm_cfg.vlm_checkpoint_path)
    else:
        model = VisionLanguageModel(vlm_cfg)

    print(f"nanoVLM initialized with {sum(p.numel() for p in model.parameters()):,} parameters")
    print(f"Training summary: {len(train_loader.dataset)} samples, {len(train_loader)} batches/epoch, batch size {train_cfg.batch_size}")

    # Define optimizer groups
    param_groups = [
        {'params': model.MP.parameters(), 'lr': train_cfg.lr_mp},
        {'params': list(model.decoder.parameters()) + list(model.vision_encoder.parameters()), 'lr': train_cfg.lr_backbones},
    ]
    optimizer = optim.AdamW(param_groups)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    if train_cfg.compile:
        model = torch.compile(model)

    epoch_times = []
    batch_losses = []
    val_losses = []
    val_plot_steps = []
    best_accuracy = 0
    global_step = 0
    for epoch in range(train_cfg.epochs):
        epoch_start_time = time.time()
        model.train()
        total_train_loss = 0
        total_tokens_processed = 0

        for batch in tqdm(train_loader):
            batch_start_time = time.time()
            images = batch["images"]
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            optimizer.zero_grad()

            with torch.autocast(device_type='cuda', dtype=torch.float16):  # Mixed precision training
                _, loss = model(input_ids, images, attention_mask=attention_mask, targets=labels)

            loss.backward()

            adj_lr_mp = get_lr(global_step, train_cfg.lr_mp, len(train_loader) * train_cfg.epochs)
            adj_lr_backbones = get_lr(global_step, train_cfg.lr_backbones, len(train_loader) * train_cfg.epochs)
            optimizer.param_groups[0]['lr'] = adj_lr_mp
            optimizer.param_groups[1]['lr'] = adj_lr_backbones

            optimizer.step()

            batch_loss = loss.item()
            total_train_loss += batch_loss
            batch_losses.append(batch_loss)

            num_tokens = torch.sum(attention_mask).item()  # Sum of attention mask gives number of tokens
            total_tokens_processed += num_tokens

            batch_end_time = time.time()
            batch_duration = batch_end_time - batch_start_time
            tokens_per_second = num_tokens / batch_duration

            if global_step % 5 == 0:
                model.eval()
                torch.cuda.empty_cache()  # Clear GPU memory
                with torch.no_grad():
                    total_val_loss = 0
                    for batch in val_loader:
                        images = batch["images"]
                        input_ids = batch["input_ids"].to(device)
                        labels = batch["labels"].to(device)
                        attention_mask = batch["attention_mask"].to(device)

                        with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                            _, loss = model(input_ids, images, attention_mask=attention_mask, targets=labels)

                        total_val_loss += loss.item()
                    avg_val_loss = total_val_loss / len(val_loader)
                    val_losses.append(avg_val_loss)
                    val_plot_steps.append(global_step)

                epoch_accuracy = 0
                if train_cfg.eval_in_epochs:
                    epoch_accuracy = test_mmstar(model, tokenizer, test_loader, device)
                    if epoch_accuracy > best_accuracy:
                        best_accuracy = epoch_accuracy
                        model.save_pretrained(save_directory=vlm_cfg.vlm_checkpoint_path)
                print(f"\nStep: {global_step}, Loss: {batch_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Tokens/s: {tokens_per_second:.2f}, Accuracy: {epoch_accuracy:.4f}")
                model.train()

            global_step += 1

        avg_train_loss = total_train_loss / len(train_loader)
        epoch_end_time = time.time()
        epoch_duration = epoch_end_time - epoch_start_time
        epoch_times.append(epoch_duration)

        epoch_tokens_per_second = total_tokens_processed / epoch_duration
        print(f"Epoch {epoch+1}/{train_cfg.epochs} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Time: {epoch_duration:.2f}s | T/s: {epoch_tokens_per_second:.2f}")

    # Summary Statistics
    if not train_cfg.eval_in_epochs:
        model.save_pretrained(save_directory=vlm_cfg.vlm_checkpoint_path)
    model.push_to_hub(hf_model_name)

    avg_epoch_time = sum(epoch_times) / len(epoch_times)
    total_training_time = sum(epoch_times)
    total_samples_processed = len(train_loader.dataset) * train_cfg.epochs
    avg_time_per_sample = total_training_time / total_samples_processed
    print(f"Average time per epoch: {avg_epoch_time:.2f}s")
    print(f"Average time per sample: {avg_time_per_sample:.4f}s")

    plt.plot(batch_losses, label='Train Loss')
    plt.plot(val_plot_steps, val_losses, label='Val Loss')
    plt.xlabel('Batch')
    plt.ylabel('Loss')
    plt.title('Loss Curve')
    plt.grid(True)
    plt.legend()
    plt.show()

    # With this code you can test the accuracy of the model on the MMStar dataset
    # But if you only train with few samples, the accuracy will be very low
    # print("Testing MMStar Accuracy:")
    # accuracy = test_mmstar(model, tokenizer, test_loader, device)
    # print(f"MMStar Accuracy: {accuracy:.4f}")
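Reusing get_lr from above, a quick printout shows the warmup-then-cosine shape of the schedule (400 total steps chosen purely for illustration):

max_lr, max_steps = 1e-3, 400   # warmup covers the first 3% = 12 steps
for step in [0, 5, 11, 100, 200, 399]:
    print(f"step {step:3d} -> lr {get_lr(step, max_lr, max_steps):.2e}")
# lr climbs linearly to 1e-3 during warmup, then decays cosine-style
# toward min_lr = 1e-4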

Prepare the configs

@dataclass
class VLMConfig:
    vit_hidden_dim: int = 768
    vit_inter_dim: int = 4 * vit_hidden_dim
    vit_patch_size: int = 16
    vit_img_size: int = 224
    vit_n_heads: int = 12
    vit_dropout: float = 0.0
    vit_n_blocks: int = 12
    vit_ln_eps: float = 1e-6
    vit_cls_flag: bool = False
    vit_model_type: str = 'google/siglip-base-patch16-224'

    lm_hidden_dim: int = 576
    lm_inter_dim: int = 1536
    lm_rms_eps: float = 1e-5
    lm_re_base: int = 100000
    lm_max_position_embeddings: int = 8192
    lm_base_vocab_size: int = 49152
    extra_token_amount: int = 1  # Number of extra tokens for the VLM (image start, image end, image token)
    lm_vocab_size: int = lm_base_vocab_size + extra_token_amount  # Not a great way to do this, but it works for now (vlm_extra_tokens cannot be a dict, since this is mutable, and a Field has no len() function)
    lm_n_heads: int = 9
    lm_n_kv_heads: int = 3
    lm_dropout: float = 0.0
    lm_n_blocks: int = 30
    lm_attn_scaling: float = 1.0
    lm_eos_token_id: int = 0
    lm_max_length: int = 128
    lm_use_tokens: bool = False  # Decide if the LM expects tokens or embeddings as input (if using as a backbone for the VLM, set to False)
    lm_tie_weights: bool = True  # Decide if you want to tie the LM Head weight to the token embedding weights
    lm_model_type: str = 'HuggingFaceTB/SmolLM2-135M'
    lm_tokenizer: str = 'HuggingFaceTB/cosmo2-tokenizer'
    lm_chat_template: str = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

    mp_pixel_shuffle_factor: int = 2
    mp_image_token_length: int = 49

    vlm_extra_tokens: dict[str, str] = field(default_factory=lambda: {"image_token": "<|image|>"})  # , "boi_token": "<|image_start|>", "eoi_token": "<|image_end|>"
    vlm_load_backbone_weights: bool = True
    vlm_checkpoint_path: str = 'checkpoints'
    hf_repo_name: str = 'nanoVLM'


@dataclass
class TrainConfig:
    lr_mp: float = 1e-3
    lr_backbones: float = 5e-5
    val_ratio: float = 0.2
    compile: bool = False
    data_cutoff_idx: int = 1024  # Let's only use a small subset of the data at first, otherwise it takes very long to see anything :D
    batch_size: int = 12
    mmstar_batch_size: int = 12
    epochs: int = 5
    eval_in_epochs: bool = False  # Deactivating this in colab, because it would evaluate 1500 samples of MMStar every time otherwise
    resume_from_vlm_checkpoint: bool = False  # Indicate if the training should be resumed from a checkpoint of the whole VLM or you want to start from scratch
    train_dataset_path: str = 'HuggingFaceM4/the_cauldron'
    train_dataset_name: tuple[str, ...] = ("tqa", "vsr")  # All options: ("ai2d", "aokvqa", "chart2text", "chartqa", "clevr", "cocoqa", "datikz", "diagram_image_to_text", "docvqa", "dvqa", "figureqa", "finqa", "geomverse", "hateful_memes", "hitab", "iam", "iconqa", "infographic_vqa", "intergps", "localized_narratives", "mapqa", "multihiertt", "ocrvqa", "plotqa", "raven", "rendered_text", "robut_sqa", "robut_wikisql", "robut_wtq", "scienceqa", "screen2words", "st_vqa", "tabmwp", "tallyqa", "tat_qa", "textcaps", "textvqa", "tqa", "vistext", "visual7w", "visualmrc", "vqarad", "vqav2", "vsr", "websight"), plus "clevr_math", "okvqa", "spot_the_diff", "nlvr2", "mimic_cgd"
    test_dataset_path: str = "Lin-Chen/MMStar"
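Both configs are plain dataclasses, so any field can be overridden at construction time. A hypothetical full-data run with in-epoch MMStar evaluation, for example:

# Hypothetical overrides: use the whole concatenated dataset and
# evaluate on MMStar during training (much slower per step)
train_cfg = TrainConfig(data_cutoff_idx=None, eval_in_epochs=True)
print(train_cfg.data_cutoff_idx, train_cfg.eval_in_epochs)  # None True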

Run the training

vlm_cfg = VLMConfig()
train_cfg = TrainConfig()
train(train_cfg, vlm_cfg)
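Once training finishes, the saved checkpoint can be loaded back for inference with the same generate call used in test_mmstar above. The sketch below is a rough illustration under stated assumptions: example.jpg is a placeholder image, the image processor is assumed to return a CHW tensor, and the prompt convention (repeating the image token mp_image_token_length times) is inferred from the config fields rather than taken from a verified repo API:

from PIL import Image

# Reload the trained weights saved under vlm_cfg.vlm_checkpoint_path
model = VisionLanguageModel.from_pretrained(vlm_cfg.vlm_checkpoint_path).to(device)
model.eval()

tokenizer = get_tokenizer(vlm_cfg.lm_tokenizer, vlm_cfg.vlm_extra_tokens, vlm_cfg.lm_chat_template)
image_processor = get_image_processor(vlm_cfg.vit_img_size)

# "example.jpg" is a placeholder; assumed: the processor yields a CHW tensor
image = image_processor(Image.open("example.jpg").convert("RGB")).unsqueeze(0).to(device)

# Assumed prompt convention: one image token per image patch embedding
image_token = vlm_cfg.vlm_extra_tokens["image_token"]
messages = [{"role": "user", "content": image_token * vlm_cfg.mp_image_token_length + "What is in this image?"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

enc = tokenizer(prompt, return_tensors="pt")
gen = model.generate(enc["input_ids"].to(device), image, enc["attention_mask"].to(device), greedy=True)
print(tokenizer.batch_decode(gen, skip_special_tokens=True)[0])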
