Train a VLM with nanoVLM
Environment setup
Clone the GitHub repository
git clone https://github.com/huggingface/nanoVLM.git
Using uv
# Install the uv package manager
curl -LsSf https://astral.sh/uv/install.sh | sh
# Create the environment
uv init --bare --python 3.12
uv sync --python 3.12
source .venv/bin/activate
uv add torch numpy torchvision pillow datasets huggingface-hub transformers wandb
# Optional: for lmms-eval integration you have to install it from source, see section 'Evaluation with lmms-eval'
Or with conda
conda create --name nanovlm python=3.12
conda activate nanovlm
pip install torch numpy torchvision pillow datasets huggingface-hub transformers wandb
pip install matplotlib ipywidgets -i https://mirrors.aliyun.com/pypi/simple/
Evaluation with lmms-eval
pip install git+https://github.com/EvolvingLMMs-Lab/lmms-eval.git
Coding
Set a Hugging Face download mirror (either via a shell export or from within Python)
# export PYTHONPATH=./
export HF_ENDPOINT=https://hf-mirror.com
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
Log in to Hugging Face
from huggingface_hub import notebook_login
notebook_login()
Set the model name (the Hub repo used by push_to_hub)
hf_model_name = "shizidushu/nanoVLM"
Import the required libraries
# nanoVLM Imports (please check out the implementations in detail, that's where all the interesting stuff is!)
from data.collators import VQACollator, MMStarCollator
from data.datasets import MMStarDataset, VQADataset
from data.processors import get_image_processor, get_tokenizer
from models.vision_language_model import VisionLanguageModel
import models.utils as utils

# Libraries
import math
import time
import torch
from tqdm import tqdm
import torch.optim as optim
import matplotlib.pyplot as plt
from dataclasses import dataclass, field
from torch.utils.data import DataLoader
from datasets import load_dataset, concatenate_datasets

# Otherwise, the tokenizer will throw a warning
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"if torch.cuda.is_available():device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():device = "mps"
else:device = "cpu"
print(f"Using device: {device}")torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
Define the dataloader helper function
def get_dataloaders(train_cfg, vlm_cfg):
    # Create datasets
    image_processor = get_image_processor(vlm_cfg.vit_img_size)
    tokenizer = get_tokenizer(vlm_cfg.lm_tokenizer, vlm_cfg.vlm_extra_tokens, vlm_cfg.lm_chat_template)

    # Load and combine all training datasets
    combined_train_data = []
    for dataset_name in train_cfg.train_dataset_name:
        train_ds = load_dataset(train_cfg.train_dataset_path, dataset_name)
        combined_train_data.append(train_ds['train'])
    train_ds = concatenate_datasets(combined_train_data)

    test_ds = load_dataset(train_cfg.test_dataset_path)
    train_ds = train_ds.shuffle(seed=0)  # Shuffle the training dataset, so train and val get equal contributions from all concatenated datasets

    # Apply cutoff if specified
    if train_cfg.data_cutoff_idx is None:
        total_samples = len(train_ds)  # Use the entire dataset
    else:
        total_samples = min(len(train_ds), train_cfg.data_cutoff_idx)

    val_size = int(total_samples * train_cfg.val_ratio)
    train_size = total_samples - val_size

    train_dataset = VQADataset(train_ds.select(range(train_size)), tokenizer, image_processor, vlm_cfg.mp_image_token_length)
    val_dataset = VQADataset(train_ds.select(range(train_size, total_samples)), tokenizer, image_processor, vlm_cfg.mp_image_token_length)
    test_dataset = MMStarDataset(test_ds['val'], tokenizer, image_processor, vlm_cfg.mp_image_token_length)

    # Create collators
    vqa_collator = VQACollator(tokenizer, vlm_cfg.lm_max_length)
    mmstar_collator = MMStarCollator(tokenizer)

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=train_cfg.batch_size,
        shuffle=True,
        collate_fn=vqa_collator,
        num_workers=2,
        pin_memory=True,
        drop_last=True,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=train_cfg.batch_size,
        shuffle=False,
        collate_fn=vqa_collator,
        num_workers=2,
        pin_memory=True,
        drop_last=True,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=train_cfg.mmstar_batch_size,
        shuffle=False,
        collate_fn=mmstar_collator,
        pin_memory=True,
    )

    return train_loader, val_loader, test_loader
Here, data_cutoff_idx caps the total number of samples used (train + val combined).
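To make the split concrete, here is a small arithmetic sketch mirroring the logic inside get_dataloaders; the dataset length of 100_000 is just an assumed placeholder:

# Assumed numbers, illustrating how data_cutoff_idx and val_ratio produce the split sizes
dataset_len = 100_000        # placeholder for len(train_ds)
data_cutoff_idx = 1024       # cap on the combined train + val pool
val_ratio = 0.2

total_samples = min(dataset_len, data_cutoff_idx)   # 1024
val_size = int(total_samples * val_ratio)           # 204 validation samples
train_size = total_samples - val_size               # 820 training samples
print(train_size, val_size)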
Write the test function
def test_mmstar(model, tokenizer, test_loader, device):
    # Go through MMStar and count how many answers we get right
    model.eval()
    total_examples = 0
    correct_predictions = 0
    with torch.no_grad():
        for batch in test_loader:
            image = batch['images']
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            correct_answer = tokenizer.batch_decode(labels, skip_special_tokens=True)

            gen = model.generate(input_ids, image, attention_mask, greedy=True)
            model_output = tokenizer.batch_decode(gen, skip_special_tokens=True)

            is_correct = utils.check_multiple_choice_with_regex(model_output, correct_answer)

            total_examples += len(is_correct)
            if is_correct:
                correct_predictions += sum(is_correct)

    accuracy = correct_predictions / total_examples if total_examples > 0 else 0
    model.train()
    return accuracy
Prepare the training loop
def get_lr(it, max_lr, max_steps):
    min_lr = max_lr * 0.1
    warmup_steps = max_steps * 0.03

    # 1) linear warmup for warmup_iters steps
    if it < warmup_steps:
        return max_lr * (it + 1) / warmup_steps
    # 2) if it > lr_decay_iters, return min learning rate
    if it > max_steps:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff starts at 1 and goes to 0
    return min_lr + coeff * (max_lr - min_lr)


def train(train_cfg, vlm_cfg):
    train_loader, val_loader, test_loader = get_dataloaders(train_cfg, vlm_cfg)
    tokenizer = get_tokenizer(vlm_cfg.lm_tokenizer, vlm_cfg.vlm_extra_tokens, vlm_cfg.lm_chat_template)

    # Initialize model
    if train_cfg.resume_from_vlm_checkpoint:
        model = VisionLanguageModel.from_pretrained(vlm_cfg.vlm_checkpoint_path)
    else:
        model = VisionLanguageModel(vlm_cfg)

    print(f"nanoVLM initialized with {sum(p.numel() for p in model.parameters()):,} parameters")
    print(f"Training summary: {len(train_loader.dataset)} samples, {len(train_loader)} batches/epoch, batch size {train_cfg.batch_size}")

    # Define optimizer groups
    param_groups = [
        {'params': model.MP.parameters(), 'lr': train_cfg.lr_mp},
        {'params': list(model.decoder.parameters()) + list(model.vision_encoder.parameters()), 'lr': train_cfg.lr_backbones},
    ]
    optimizer = optim.AdamW(param_groups)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    if train_cfg.compile:
        model = torch.compile(model)

    epoch_times = []
    batch_losses = []
    val_losses = []
    val_plot_steps = []
    best_accuracy = 0
    global_step = 0
    for epoch in range(train_cfg.epochs):
        epoch_start_time = time.time()
        model.train()
        total_train_loss = 0
        total_tokens_processed = 0

        for batch in tqdm(train_loader):
            batch_start_time = time.time()
            images = batch["images"]
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            optimizer.zero_grad()

            with torch.autocast(device_type='cuda', dtype=torch.float16):  # Mixed precision training
                _, loss = model(input_ids, images, attention_mask=attention_mask, targets=labels)

            loss.backward()

            adj_lr_mp = get_lr(global_step, train_cfg.lr_mp, len(train_loader) * train_cfg.epochs)
            adj_lr_backbones = get_lr(global_step, train_cfg.lr_backbones, len(train_loader) * train_cfg.epochs)
            optimizer.param_groups[0]['lr'] = adj_lr_mp
            optimizer.param_groups[1]['lr'] = adj_lr_backbones

            optimizer.step()

            batch_loss = loss.item()
            total_train_loss += batch_loss
            batch_losses.append(batch_loss)

            num_tokens = torch.sum(attention_mask).item()  # Sum of attention mask gives number of tokens
            total_tokens_processed += num_tokens

            batch_end_time = time.time()
            batch_duration = batch_end_time - batch_start_time
            tokens_per_second = num_tokens / batch_duration

            if global_step % 5 == 0:
                model.eval()
                torch.cuda.empty_cache()  # Clear GPU memory
                with torch.no_grad():
                    total_val_loss = 0
                    for batch in val_loader:
                        images = batch["images"]
                        input_ids = batch["input_ids"].to(device)
                        labels = batch["labels"].to(device)
                        attention_mask = batch["attention_mask"].to(device)

                        with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                            _, loss = model(input_ids, images, attention_mask=attention_mask, targets=labels)

                        total_val_loss += loss.item()
                    avg_val_loss = total_val_loss / len(val_loader)
                    val_losses.append(avg_val_loss)
                    val_plot_steps.append(global_step)

                epoch_accuracy = 0
                if train_cfg.eval_in_epochs:
                    epoch_accuracy = test_mmstar(model, tokenizer, test_loader, device)
                    if epoch_accuracy > best_accuracy:
                        best_accuracy = epoch_accuracy
                        model.save_pretrained(save_directory=vlm_cfg.vlm_checkpoint_path)

                print(f"\nStep: {global_step}, Loss: {batch_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Tokens/s: {tokens_per_second:.2f}, Accuracy: {epoch_accuracy:.4f}")

                model.train()

            global_step += 1

        avg_train_loss = total_train_loss / len(train_loader)
        epoch_end_time = time.time()
        epoch_duration = epoch_end_time - epoch_start_time
        epoch_times.append(epoch_duration)

        epoch_tokens_per_second = total_tokens_processed / epoch_duration

        print(f"Epoch {epoch+1}/{train_cfg.epochs} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Time: {epoch_duration:.2f}s | T/s: {epoch_tokens_per_second:.2f}")

    # Summary Statistics
    if not train_cfg.eval_in_epochs:
        model.save_pretrained(save_directory=vlm_cfg.vlm_checkpoint_path)
    model.push_to_hub(hf_model_name)

    avg_epoch_time = sum(epoch_times) / len(epoch_times)
    total_training_time = sum(epoch_times)
    total_samples_processed = len(train_loader.dataset) * train_cfg.epochs
    avg_time_per_sample = total_training_time / total_samples_processed
    print(f"Average time per epoch: {avg_epoch_time:.2f}s")
    print(f"Average time per sample: {avg_time_per_sample:.4f}s")

    plt.plot(batch_losses, label='Train Loss')
    plt.plot(val_plot_steps, val_losses, label='Val Loss')
    plt.xlabel('Batch')
    plt.ylabel('Loss')
    plt.title('Loss Curve')
    plt.grid(True)
    plt.legend()
    plt.show()

    # With this code you can test the accuracy of the model on the MMStar dataset
    # But if you only train with few samples, the accuracy will be very low
    # print("Testing MMStar Accuracy:")
    # accuracy = test_mmstar(model, tokenizer, test_loader, device)
    # print(f"MMStar Accuracy: {accuracy:.4f}")
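To see what schedule get_lr actually produces, you can evaluate it over the expected number of optimizer steps and plot it. A minimal sketch follows; the step count of 340 is an assumption based on the small demo config below (roughly 820 training samples, batch size 12, 5 epochs):

# Sketch: visualize the warmup + cosine-decay schedule from get_lr
# max_steps would be len(train_loader) * epochs; 340 is assumed here for illustration
max_steps = 340
lrs = [get_lr(step, 1e-3, max_steps) for step in range(max_steps)]

plt.plot(lrs)
plt.xlabel('Step')
plt.ylabel('Learning rate (MP group)')
plt.title('Warmup + cosine decay')
plt.show()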
Prepare the configs
@dataclass
class VLMConfig:
    vit_hidden_dim: int = 768
    vit_inter_dim: int = 4 * vit_hidden_dim
    vit_patch_size: int = 16
    vit_img_size: int = 224
    vit_n_heads: int = 12
    vit_dropout: float = 0.0
    vit_n_blocks: int = 12
    vit_ln_eps: float = 1e-6
    vit_cls_flag: bool = False
    vit_model_type: str = 'google/siglip-base-patch16-224'

    lm_hidden_dim: int = 576
    lm_inter_dim: int = 1536
    lm_rms_eps: float = 1e-5
    lm_re_base: int = 100000
    lm_max_position_embeddings: int = 8192
    lm_base_vocab_size: int = 49152
    extra_token_amount: int = 1  # Number of extra tokens for the VLM (image start, image end, image token)
    lm_vocab_size: int = lm_base_vocab_size + extra_token_amount  # Not a great way to do this, but it works for now (vlm_extra_tokens cannot be a dict, since this is mutable, and a Field has no len() function)
    lm_n_heads: int = 9
    lm_n_kv_heads: int = 3
    lm_dropout: float = 0.0
    lm_n_blocks: int = 30
    lm_attn_scaling: float = 1.0
    lm_eos_token_id: int = 0
    lm_max_length: int = 128
    lm_use_tokens: bool = False  # Decide if the LM expects tokens or embeddings as input (if using as a backbone for the VLM, set to False)
    lm_tie_weights: bool = True  # Decide if you want to tie the LM Head weight to the token embedding weights
    lm_model_type: str = 'HuggingFaceTB/SmolLM2-135M'
    lm_tokenizer: str = 'HuggingFaceTB/cosmo2-tokenizer'
    lm_chat_template: str = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

    mp_pixel_shuffle_factor: int = 2
    mp_image_token_length: int = 49

    vlm_extra_tokens: dict[str, str] = field(default_factory=lambda: {"image_token": "<|image|>"})  # , "boi_token": "<|image_start|>", "eoi_token": "<|image_end|>"
    vlm_load_backbone_weights: bool = True
    vlm_checkpoint_path: str = 'checkpoints'
    hf_repo_name: str = 'nanoVLM'


@dataclass
class TrainConfig:
    lr_mp: float = 1e-3
    lr_backbones: float = 5e-5
    val_ratio: float = 0.2
    compile: bool = False
    data_cutoff_idx: int = 1024  # Let's only use a small subset of the data at first, otherwise it takes very long to see anything :D
    batch_size: int = 12
    mmstar_batch_size: int = 12
    epochs: int = 5
    eval_in_epochs: bool = False  # Deactivating this in colab, because it would evaluate 1500 samples of MMStar every time otherwise
    resume_from_vlm_checkpoint: bool = False  # Indicate if the training should be resumed from a checkpoint of the whole VLM or you want to start from scratch
    train_dataset_path: str = 'HuggingFaceM4/the_cauldron'
    train_dataset_name: tuple[str, ...] = ("tqa", "vsr")  # All options: ("ai2d", "aokvqa", "chart2text", "chartqa", "clevr", "cocoqa", "datikz", "diagram_image_to_text", "docvqa", "dvqa", "figureqa", "finqa", "geomverse", "hateful_memes", "hitab", "iam", "iconqa", "infographic_vqa", "intergps", "localized_narratives", "mapqa", "multihiertt", "ocrvqa", "plotqa", "raven", "rendered_text", "robut_sqa", "robut_wikisql", "robut_wtq", "scienceqa", "screen2words", "st_vqa", "tabmwp", "tallyqa", "tat_qa", "textcaps", "textvqa", "tqa", "vistext", "visual7w", "visualmrc", "vqarad", "vqav2", "vsr", "websight")  # plus "clevr_math", "okvqa", "spot_the_diff", "nlvr2", "mimic_cgd"
    test_dataset_path: str = "Lin-Chen/MMStar"
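Because TrainConfig is a plain dataclass, every field can be overridden via keyword arguments. Before launching the full run below, you could do a quick smoke test; the values here are assumptions chosen only to verify that the pipeline runs end to end:

# Hypothetical smoke-test config: even fewer samples and a single epoch
smoke_train_cfg = TrainConfig(
    data_cutoff_idx=256,
    batch_size=4,
    epochs=1,
    eval_in_epochs=False,
)
# train(smoke_train_cfg, VLMConfig())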
Run the training
vlm_cfg = VLMConfig()
train_cfg = TrainConfig()
train(train_cfg, vlm_cfg)
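Once train() has finished, the checkpoint it wrote to vlm_cfg.vlm_checkpoint_path can be reloaded for a quick sanity check or pushed to the Hub again; this minimal sketch only reuses the from_pretrained and push_to_hub calls already shown above:

# Reload the saved checkpoint and confirm it loads correctly
trained_model = VisionLanguageModel.from_pretrained(vlm_cfg.vlm_checkpoint_path)
trained_model.to(device)
trained_model.eval()
print(f"Reloaded nanoVLM with {sum(p.numel() for p in trained_model.parameters()):,} parameters")
# trained_model.push_to_hub(hf_model_name)  # optional: push again under your own repo name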