当前位置: 首页 > news >正文

TrOCR模型微调

参考链接【Transformers-Tutorials/TrOCR/Fine_tune_TrOCR_on_IAM_Handwriting_Database_using_native_PyTorch.ipynb】

1.根据任务类型构建数据集

import os

import torch
from torch.utils.data import Dataset
from PIL import Image


class ORCDataset(Dataset):
    """Image/text pairs prepared for a vision-encoder / text-decoder OCR model.

    Each item is a dict with:
      * ``"pixel_values"`` - the processor's image tensor with the leading
        batch dimension removed;
      * ``"labels"`` - token ids of the target text, padded/truncated to
        ``max_target_length``, with padding positions replaced by ``-100``.

    Args:
        root_dir: Directory prefix joined in front of every ``file_name``.
            An empty string leaves ``file_name`` unchanged.
        df: Table with ``file_name`` and ``text`` columns. It must support
            ``len()`` and positional ``.iloc`` access (e.g. a pandas
            DataFrame).
        processor: Callable on a PIL image with ``return_tensors="pt"`` that
            returns an object with ``.pixel_values``; it must also expose a
            ``.tokenizer`` with ``pad_token_id``.
        max_target_length: Fixed length of every returned label sequence.
    """

    def __init__(self, root_dir, df, processor, max_target_length=256):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        """Return the number of rows in the underlying table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, preprocess and encode the sample at position ``idx``."""
        # Positional lookup: label-based ``df[col][idx]`` fails (or returns
        # the wrong row) when the index was not reset after a split/filter.
        file_name = self.df['file_name'].iloc[idx]
        text = self.df['text'].iloc[idx]

        # Prepare the image (the processor resizes + normalizes). The file
        # handle is closed as soon as the RGB copy has been made.
        image_path = os.path.join(self.root_dir, file_name)
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values

        # Encode the text. Truncation keeps every label sequence exactly
        # ``max_target_length`` long so samples can be stacked into a batch.
        tokenizer = self.processor.tokenizer
        labels = tokenizer(
            text,
            padding="max_length",
            max_length=self.max_target_length,
            truncation=True,
        ).input_ids

        # Important: PAD tokens are mapped to -100 so that the loss function
        # ignores them.
        pad_token_id = tokenizer.pad_token_id
        labels = [
            label if label != pad_token_id else -100 for label in labels
        ]

        # Drop only the leading batch dimension added by the processor.
        return {
            "pixel_values": pixel_values.squeeze(0),
            "labels": torch.tensor(labels),
        }


# Local folder that holds the pretrained processor/model files.
cache_dir = "./pretrain"
# You can cache the pretrained weights locally first, like this:
# from transformers import VisionEncoderDecoderModel
# model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed",cache_dir='./trocr-base-printed')
# Alternatively, download the files from the official site by hand and put
# them all into a single folder.
from torch.utils.data import DataLoader
from transformers import TrOCRProcessor

# The processor bundles the image preprocessor and the tokenizer; it is
# loaded from the same local folder as the model weights.
processor = TrOCRProcessor.from_pretrained(cache_dir)

# NOTE(review): train_df / test_df are not defined in the visible code. They
# must exist before this point and provide 'file_name' and 'text' columns.
train_dataset = ORCDataset(root_dir='', df=train_df, processor=processor)
eval_dataset = ORCDataset(root_dir='', df=test_df, processor=processor)
print("Number of training examples:", len(train_dataset))
print("Number of validation examples:", len(eval_dataset))

# Only the training data is shuffled; evaluation keeps dataset order.
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
eval_dataloader = DataLoader(eval_dataset, batch_size=16)

2.加载模型

from transformers import VisionEncoderDecoderModel
import torch

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the encoder-decoder weights from the local folder and move them to
# the selected device.
model = VisionEncoderDecoderModel.from_pretrained(cache_dir)
model.to(device)
print(model.encoder)

# Layer-wise learning rates could be configured here, for example:
# decoder_param_id = list(map(id,model.decoder.parameters()))
# encoder_params = filter(lambda p: id(p) not in decoder_param_id, model.parameters())
# encoder_params = model.encoder.parameters()
# decoder_params = model.decoder.parameters()

# Special tokens: used to build decoder_input_ids from the labels and to
# mark padding / end of sequence.
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.eos_token_id = processor.tokenizer.sep_token_id

# Keep the top-level vocab size in sync with the decoder's.
model.config.vocab_size = model.config.decoder.vocab_size

# Beam-search settings used by generate().
model.config.num_beams = 4
model.config.max_length = 128
model.config.length_penalty = 2.0
model.config.no_repeat_ngram_size = 3
model.config.early_stopping = True

3.加载评价指标

from datasets import load_metric

# NOTE(review): `load_metric` appears to have been removed from recent
# `datasets` releases in favour of the separate `evaluate` package - confirm
# against the installed version.
cer_metric = load_metric("cer")
# To load a local copy of the CER metric script instead:
# import datasets
# cer_metric = datasets.load_metric("./cer.py")


def compute_cer(pred_ids, label_ids):
    """Return the character error rate between predictions and references.

    Args:
        pred_ids: Generated token ids, decoded with ``processor.batch_decode``.
        label_ids: Reference token ids in which ignored positions are marked
            with ``-100``. The argument is NOT modified; the substitution
            below happens on a copy.

    Returns:
        Whatever ``cer_metric.compute`` returns for the decoded strings.
    """
    # Work on a copy so the caller's labels keep their -100 markers
    # (torch tensors expose .clone(), numpy arrays .copy()).
    if hasattr(label_ids, "clone"):
        labels = label_ids.clone()
    else:
        labels = label_ids.copy()

    # -100 is not a valid token id; swap it for PAD so decoding works and
    # skip_special_tokens can drop it.
    labels[labels == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(labels, skip_special_tokens=True)

    return cer_metric.compute(predictions=pred_str, references=label_str)

4.原生PyTorch训练流程

from tqdm import tqdm

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)


def _train_one_epoch():
    """Run one optimisation pass over train_dataloader; return the summed loss."""
    model.train()
    running_loss = 0.0
    for batch in tqdm(train_dataloader):
        # Move every tensor of the batch to the model's device.
        inputs = {key: tensor.to(device) for key, tensor in batch.items()}

        # Forward, backward, parameter update, then clear the gradients.
        loss = model(**inputs).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        running_loss += loss.item()
    return running_loss


def _validation_cer_total():
    """Generate text for every eval batch; return the sum of per-batch CERs."""
    model.eval()
    cer_total = 0.0
    with torch.no_grad():
        for batch in tqdm(eval_dataloader):
            # Batched generation from the images only.
            generated_ids = model.generate(batch["pixel_values"].to(device))
            cer_total += compute_cer(pred_ids=generated_ids, label_ids=batch["labels"])
    return cer_total


# Loop over the dataset multiple times: train, report the mean batch loss,
# then evaluate and report the mean batch CER.
for epoch in range(100):
    train_loss = _train_one_epoch()
    print(f"Loss after epoch {epoch}:", train_loss/len(train_dataloader))

    valid_cer = _validation_cer_total()
    print("Validation CER:", valid_cer / len(eval_dataloader))

# Write the fine-tuned weights and config to the current directory.
model.save_pretrained(".")

相关文章:

  • LDStega论文阅读笔记
  • 阿里云可观测 2025 年 5 月产品动态
  • 【每日likou】704. 二分查找 27. 移除元素 977.有序数组的平方
  • docker-compose搭建eureka-server和zipkin
  • asio之静态互斥量
  • ubuntu22 arm 编译安装input leap
  • 20250611让NanoPi NEO core开发板在Ubuntu core16.04系统下开机自启动的时候拉高GPIOG8
  • NumPy 2.x 完全指南【二十五】记录数组
  • 建站新手:我与SiteServerCMS的爱恨情仇(三)
  • 【c++八股文】Day2:虚函数表和虚函数表指针
  • RPC启动机制及注解实现
  • day 50
  • 0:0 error Parsing error: Cannot read properties of undefined (reading ‘map‘)
  • Rust 学习笔记:通过异步实现并发
  • C语言学习20250611
  • 亮数据抓取浏览器,亚马逊数据采集实战
  • Flask 报错修复实战:send_file() got an unexpected keyword argument ‘etag‘
  • vite原理
  • MFC 第1章:适配 WIndows 编程的软件界面调整
  • 创建和运行线程
  • app制作器软件下载/seo搜索引擎优化介绍
  • 电脑做视频的网站吗/网站排名英文
  • 淘客网站建设要求/上海百度提升优化
  • 微信官方网站网址/企业为何选择网站推广外包?
  • 网站建设公司导航/长沙官网网站推广优化
  • 济南哪家公司做网站好/企业邮箱入口