当前位置: 首页 > news >正文

【TrOCR】用Transformer和torch库实现TrOCR模型

项目结构:

TrOCR/
├── config.py               # 所有配置参数(路径、超参数等)
├── dataset.py              # 数据集类 + 数据增强(合并 data_augmentation)
├── model.py                # 模型加载与配置
├── utils.py                # 通用工具(含日志功能,合并 logger.py)
├── train.py                # 训练逻辑 + 入口(合并 trainer.py)
├── predict.py              # 推理接口 + 入口(合并 inference.py)
├── evaluate.py             # 评估指标 + 入口(合并 metrics.py)
├── requirements.txt        # 依赖库
├── README.md               # 项目说明
├── data/                   # 数据集
│   ├── train/(images + labels.json)
│   ├── val/(images + labels.json)
│   └── test/(images + labels.json)
├── models/                 # 保存训练好的模型
└── logs/                   # 训练日志、评估报告

数据集的数据结构

我的数据集路径:C:\Users\Virgil\Desktop\dataetOCR\ChineseOcr2k,目录下有train和val两个文件夹,分别是images和labels.json。

标签JSON文件的数据结构:labels.json内容是这样的:

    [{"file_name": "20587062_124836763.jpg","text": "设施一流的绿色、舒适"},{"file_name": "20487921_757563219.jpg","text": "瘦削,二○○六年五月"},{"file_name": "20567468_1494490742.jpg","text": "分置改革方案》在法规"...

损失函数:
TrOCR 官方使用的损失函数是交叉熵损失(Cross-Entropy Loss),
主要用于计算解码器生成文本与真实标签之间的差异,具体是通过 标签移位(label shifting) 策略实现的序列到序列(Seq2Seq)损失计算。
TrOCR 是典型的编码器 - 解码器架构(图像编码器 + 文本解码器),
其损失计算逻辑与大多数 Seq2Seq 模型一致:

  • 输入与标签设计:解码器的输入是 “真实文本标签左移一位 + 起始符号(如 [CLS])”,标签是 “真实文本标签 + 终止符号(如 [SEP])”。
  • 损失计算:对解码器每个时间步的输出 logits 计算交叉熵损失,忽略 padding 位置(通过将 pad token 替换为 -100 实现,PyTorch 会自动忽略 -100 标签的损失)。
http://www.dtcms.com/a/342300.html

相关文章:

  • yggjs_rlayout 科技风主题布局使用教程
  • StarRocks不能启动 ,StarRocksFe节点不能启动问题 处理
  • macos使用FFmpeg与SDL解码并播放H.265视频
  • 【TrOCR】模型预训练权重各个文件说明
  • 从800米到2000米:耐达讯自动化Profibus转光纤如何让软启动器效率翻倍?
  • 表达式(CSP-J 2021-Expr)题目详解
  • Django的生命周期
  • 如何在DHTMLX Scheduler中实现带拖拽的任务待办区(Backlog)
  • 非常飘逸的 Qt 菜单控件
  • logger级别及大小
  • 如何安装和配置W3 Total Cache以提升WordPress网站性能
  • C++设计模式--策略模式与观察者模式
  • 小红书AI落地与前端开发技术全解析(From AI)
  • Python 正则表达式(更长的正则表达式示例)
  • 【基础排序】CF - 赌场游戏Playing in a Casino
  • 机器学习4
  • 精算中的提升曲线(Lift Curve)与机器学习中的差别
  • 网络打印机安装操作指南
  • 健康常识查询系统|基于java和小程序的健康常识查询系统设计与实现(源码+数据库+文档)
  • CentOS7安装部署PostgreSQL
  • 《PostgreSQL内核学习:slot_deform_heap_tuple 的分支消除与特化路径优化》
  • ES_文档
  • 2025-08-21 Python进阶6——迭代器生成器与with
  • Python项目开发- 动态设置工作目录与模块搜索路径
  • strerror和perror函数的使用及其联系和区别
  • 43-Python基础语法-3
  • QWidget/QMainWindow与QLayout的布局
  • CSDN使用技巧
  • Pandas中数据分组进阶以及数据透视表
  • 链表-143.重排链表-力扣(LeetCode)