Exporting PP-OCRv5 from Paddle to ONNX and Running Inference
1. Download the model
First, download the recognition model.
PP-OCRv5_mobile_rec_infer.tar
https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0.0/PP-OCRv5_mobile_rec_infer.tar
or PP-OCRv5_server_rec_infer.tar
https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0.0/PP-OCRv5_server_rec_infer.tar
2. Convert to ONNX
It is best to clone the PaddleOCR repository first, into a working directory of your choice:
git clone https://github.com/PaddlePaddle/PaddleOCR
Then create a matching Python virtual environment, activate it, and install the dependencies:
python -m pip install paddle2onnx==2.0.2rc1
Now run the conversion, replacing the path below with the directory where you extracted the downloaded model; the exported model is named model.onnx:
paddle2onnx --model_dir <your PP-OCRv5_mobile_rec path> --model_filename inference.json --params_filename inference.pdiparams --save_file model.onnx --opset_version 14 --enable_onnx_checker True
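Before wiring the exported model into anything, it helps to sanity-check the input layout the recognizer expects: float32 NCHW, pixels scaled to [0, 1] and then normalized with mean 0.5 / std 0.5 (i.e. mapped to [-1, 1]). This is exactly what the inference script in the next section does; here is a minimal numpy sketch of that transform, using a random array in place of a real text crop (the 48x320 size is an assumption matching the script below):

```python
import numpy as np

# Random stand-in for a 48x320 RGB text crop (HWC, uint8).
rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(48, 320, 3), dtype=np.uint8)

# Scale to [0, 1], then normalize with mean 0.5 / std 0.5 -> [-1, 1].
x = img.astype(np.float32) / 255.0
x = (x - 0.5) / 0.5

# HWC -> CHW, add a batch dimension: final shape NCHW = (1, 3, 48, 320).
x = np.expand_dims(x.transpose(2, 0, 1), axis=0)

print(x.shape, x.dtype, x.min(), x.max())
```

If the shape or dtype printed here does not match what `session.get_inputs()[0]` reports for your exported model, the conversion settings are worth revisiting.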
3. Run a test
Download the PP-OCRv5 dictionary; if the link does not work for you, it is also available in the resources section:
ppocrv5-onnx/dict/ppocrv5_dict.txt at master · HoVDuc/ppocrv5-onnx
Make sure the dependencies are installed and the paths are set correctly.
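The dictionary file is one character per line, and the recognizer's output index 0 is the CTC blank, so output index i corresponds to line i-1 of the file. A small sketch of that mapping with a made-up three-character dictionary (the real ppocrv5_dict.txt has thousands of entries):

```python
import io

# Made-up stand-in for ppocrv5_dict.txt: one character per line.
dict_text = "a\nb\nc\n"

# Load it the same way the inference script does.
with io.StringIO(dict_text) as f:
    char_list = [line.strip() for line in f]

# Model output index 0 is the CTC blank; index i (i >= 1) -> char_list[i - 1].
for idx in (1, 2, 3):
    print(idx, '->', char_list[idx - 1])  # 1 -> a, 2 -> b, 3 -> c
```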
import os
import time

import cv2
import numpy as np
import onnxruntime as ort

# 1. Load the ONNX model
onnx_path = 'model.onnx'
session = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])

# 2. Load the character dictionary (one character per line)
dict_path = 'ppocrv5_dict.txt'
with open(dict_path, 'r', encoding='utf-8') as f:
    char_list = [line.strip() for line in f]
char_list.append('')  # guard against some models emitting the maximum index


def preprocess(img_path):
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (320, 48))  # fixed size: width 320, height 48
    img = img.astype(np.float32) / 255.0
    mean = np.array([0.5, 0.5, 0.5]).reshape((1, 1, 3))
    std = np.array([0.5, 0.5, 0.5]).reshape((1, 1, 3))
    img = (img - mean) / std
    img = img.transpose(2, 0, 1)  # HWC -> CHW
    img = np.expand_dims(img, axis=0)
    return img.astype(np.float32)  # the model expects float32 input


def ctc_decode(preds, char_list):
    # Greedy CTC decoding: drop blanks (index 0) and collapse repeats.
    text = ''
    last_index = -1
    for i in preds:
        if i != 0 and i != last_index:
            if 0 < i <= len(char_list):
                text += char_list[i - 1]
            else:
                print(f"Warning: index {i} out of range for char_list (len={len(char_list)})")
        last_index = i
    return text


# 3. Run inference on every .png image in the test folder
img_dir = 'test'
img_files = [f for f in os.listdir(img_dir) if f.lower().endswith('.png')]

start_time = time.time()
for img_file in img_files:
    img_path = os.path.join(img_dir, img_file)
    img = preprocess(img_path)
    ort_inputs = {session.get_inputs()[0].name: img}
    preds = session.run(None, ort_inputs)[0]
    preds_idx = preds.argmax(axis=2)[0]
    text = ctc_decode(preds_idx, char_list)
    print(f'{img_file}: {text}')
end_time = time.time()

total_time = end_time - start_time
num_imgs = len(img_files)
if num_imgs > 0:
    avg_time = total_time / num_imgs
    print(f'Total time: {total_time:.4f} s')
    print(f'Average time per image: {avg_time:.4f} s')
else:
    print('No images found; cannot compute average time.')
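To see what the greedy CTC decoding step does without loading a model, you can feed the same logic a hand-made index sequence: blanks (index 0) are dropped, consecutive repeats are collapsed, and a blank between two identical indices keeps them as separate characters. A self-contained sketch with a toy dictionary (the simplified decoder below mirrors the script's ctc_decode without the range check):

```python
def ctc_decode(preds, char_list):
    # Greedy CTC decoding: skip blanks (0) and collapse consecutive repeats.
    text = ''
    last_index = -1
    for i in preds:
        if i != 0 and i != last_index:
            text += char_list[i - 1]
        last_index = i
    return text


char_list = ['h', 'e', 'l', 'o']           # toy dictionary, blank = 0
preds = [1, 1, 0, 2, 3, 3, 0, 3, 4, 0, 0]  # "h h _ e l l _ l o _ _"
print(ctc_decode(preds, char_list))        # -> hello
```

Note how the blank between the two runs of index 3 is what lets the decoded text contain a double "l"; without it, the repeats would collapse into one character.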