当前位置: 首页 > news >正文

文字检测到文字识别

paint_rec文件夹下
文字检测输出16个点，根据这16个点对文本框进行矫正。
text_align.py

import numpy as np
from shapely.geometry import *
from scipy.special import comb as n_over_k
import math
import json
import os
import cv2
import torch
from torch import nn
import copy


def Mtk(n, t, k):
    """Return the Bernstein basis value C(n, k) * t**k * (1 - t)**(n - k)."""
    return t ** k * (1 - t) ** (n - k) * n_over_k(n, k)


def BezierCoeff(ts):
    """Return the cubic Bernstein coefficient rows (one row of 4 per t in *ts*)."""
    return [[Mtk(3, t, k) for k in range(4)] for t in ts]


class Bezier(nn.Module):
    """Cubic Bezier curve with fixed end points and two learnable inner control points.

    Args:
        ps: array of shape (N, 2) holding the (x, y) points to fit; the first
            and last rows are used as the fixed end points of the curve.
        ctps: the inner control points in the order [x1, y1, x2, y2]
            (the layout returned by ``bezier_fit``).
    """

    def __init__(self, ps, ctps):
        super(Bezier, self).__init__()
        # ctps is interleaved as [x1, y1, x2, y2].
        self.x1 = nn.Parameter(torch.as_tensor(ctps[0], dtype=torch.float64))
        self.x2 = nn.Parameter(torch.as_tensor(ctps[2], dtype=torch.float64))
        self.y1 = nn.Parameter(torch.as_tensor(ctps[1], dtype=torch.float64))
        self.y2 = nn.Parameter(torch.as_tensor(ctps[3], dtype=torch.float64))
        # End points are taken from the data and are not optimized.
        self.x0 = ps[0, 0]
        self.x3 = ps[-1, 0]
        self.y0 = ps[0, 1]
        self.y3 = ps[-1, 1]
        self.inner_ps = torch.as_tensor(ps[1:-1, :], dtype=torch.float64)
        # 81 evenly spaced curve parameters in [0, 1].
        self.t = torch.as_tensor(np.linspace(0, 1, 81))

    def forward(self):
        """Return the sum, over the inner data points, of the distance to the nearest curve sample."""
        x0, x1, x2, x3, y0, y1, y2, y3 = self.control_points()
        t = self.t
        # Nested linear interpolation (de Casteljau form) of the cubic curve.
        bezier_x = (1-t)*((1-t)*((1-t)*x0+t*x1)+t*((1-t)*x1+t*x2))+t*((1-t)*((1-t)*x1+t*x2)+t*((1-t)*x2+t*x3))
        bezier_y = (1-t)*((1-t)*((1-t)*y0+t*y1)+t*((1-t)*y1+t*y2))+t*((1-t)*((1-t)*y1+t*y2)+t*((1-t)*y2+t*y3))
        bezier = torch.stack((bezier_x, bezier_y), dim=1)
        # Pairwise differences: (num_inner_points, num_samples, 2).
        diffs = bezier.unsqueeze(0) - self.inner_ps.unsqueeze(1)
        dists = (diffs ** 2).sum(dim=2).sqrt()
        min_dists, _ = dists.min(dim=1)
        return min_dists.sum()

    def control_points(self):
        """Return (x0, x1, x2, x3, y0, y1, y2, y3); x1, x2, y1, y2 are Parameters."""
        return self.x0, self.x1, self.x2, self.x3, self.y0, self.y1, self.y2, self.y3

    def control_points_f(self):
        """Return the control points with the learnable ones converted to Python floats."""
        return self.x0, self.x1.item(), self.x2.item(), self.x3, self.y0, self.y1.item(), self.y2.item(), self.y3


def bezier_fit(x, y):
    """Least-squares fit of a cubic Bezier curve to the points (x, y).

    The curve parameter of each point is its normalized cumulative chord
    length. Only the two inner control points are returned.

    Args:
        x: 1-D numpy array of x coordinates.
        y: 1-D numpy array of y coordinates (same length as *x*).

    Returns:
        list: ``[x1, y1, x2, y2]``, the two inner control points.
    """
    dy = y[1:] - y[:-1]
    dx = x[1:] - x[:-1]
    dt = (dx ** 2 + dy ** 2) ** 0.5
    t = dt / dt.sum()
    t = np.hstack(([0], t))
    t = t.cumsum()

    data = np.column_stack((x, y))
    pseudoinverse = np.linalg.pinv(BezierCoeff(t))  # (N, 4) -> (4, N)
    control_points = pseudoinverse.dot(data)  # (4, N) * (N, 2) -> (4, 2)
    medi_ctp = control_points[1:-1, :].flatten().tolist()
    return medi_ctp


def is_close_to_linev2(xs, ys, size, thres=0.05):
    """Return 0.0 when the polyline (xs, ys) is close to a straight line, else 3.0.

    The score is the summed absolute difference between every segment slope
    and the slope of the chord from the first to the last point, scaled by
    the chord length and divided by ``int(size ** 0.5)``.

    Args:
        xs: numpy array of x coordinates.
        ys: numpy array of y coordinates.
        size: normalization quantity; the caller passes the image's
            element count.
        thres: score threshold below which the polyline counts as a line.
    """
    nor_pixel = int(size ** 0.5)
    coords = [(float(px), float(py)) for px, py in zip(xs, ys)]

    # Slope of every consecutive segment; vertical segments get a signed infinity.
    slopes = []
    for (first_x, first_y), (second_x, second_y) in zip(coords, coords[1:]):
        if second_x - first_x == 0.0:
            slopes.append(math.inf * np.sign(second_y - first_y))
        else:
            slopes.append((second_y - first_y) / (second_x - first_x))

    st_slope = (ys[-1] - ys[0]) / (xs[-1] - xs[0])
    max_dis = ((ys[-1] - ys[0]) ** 2 + (xs[-1] - xs[0]) ** 2) ** (0.5)

    diffs = np.abs(np.asarray(slopes) - st_slope)
    score = diffs.sum() * max_dis / nor_pixel
    if score < thres:
        return 0.0
    return 3.0


def train(x, y, ctps, lr):
    """Refine the inner Bezier control points with SGD and return all control points.

    Args:
        x: x coordinates of the points to fit.
        y: y coordinates of the points to fit.
        ctps: initial inner control points ``[x1, y1, x2, y2]``.
        lr: SGD learning rate; when it equals 0.0 no optimization is run and
            the initial control points are returned.

    Returns:
        tuple: ``(x0, x1, x2, x3, y0, y1, y2, y3)``.
    """
    x, y = np.array(x), np.array(y)
    ps = np.vstack((x, y)).transpose()
    bezier = Bezier(ps, ctps)
    optimizer = torch.optim.SGD(bezier.parameters(), lr=lr)

    # Keep the starting point so it can be returned if the loss diverges.
    initial_pts = bezier.control_points_f()

    if not lr == 0.0:
        for i in range(1000):
            loss = bezier()
            if torch.isnan(loss):
                return initial_pts
            # Halve the learning rate at iterations 400 and 800.
            if i == 400:
                optimizer.param_groups[0]['lr'] *= 0.5
            if i == 800:
                optimizer.param_groups[0]['lr'] *= 0.5
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return bezier.control_points_f()


def bezier_to_poly(bezier, tw):
    """Sample two cubic Bezier curves into a polygon outline.

    Args:
        bezier: numpy array with 16 values, two curves of four (x, y)
            control points each.
        tw: number of samples taken along each curve.

    Returns:
        numpy array of shape (2 * tw, 2): the samples of the first curve
        followed by the samples of the second curve.
    """
    u = np.linspace(0, 1, tw)
    # Rows become: curve-1 x, curve-1 y, curve-2 x, curve-2 y; columns are control points.
    bezier = bezier.reshape(2, 4, 2).transpose(0, 2, 1).reshape(4, 4)
    points = np.outer((1 - u) ** 3, bezier[:, 0]) \
        + np.outer(3 * u * ((1 - u) ** 2), bezier[:, 1]) \
        + np.outer(3 * (u ** 2) * (1 - u), bezier[:, 2]) \
        + np.outer(u ** 3, bezier[:, 3])
    points = np.concatenate((points[:, :2], points[:, 2:]), axis=0)
    return points


def double_linear2(input_signal, output_signal, beziers):
    """Fill *output_signal* by sampling *input_signal* between two curves.

    For every output column ``j`` the source pixels are taken along the
    straight segment from the j-th point of the first half of *beziers* to
    the j-th-from-last point of the second half. Coordinates are converted
    to int32 by ``np.linspace`` (no interpolation between pixels is done,
    despite the function name).

    Sample coordinates are always clamped to the image bounds. Previously
    only coordinates that raised ``IndexError`` were clamped, so negative
    coordinates silently wrapped around to the opposite image edge.

    Args:
        input_signal: source image, indexed as [row, col, channel].
        output_signal: pre-allocated output image; its shape defines the
            output size. It is modified in place.
        beziers: array of (x, y) points, first half for the first curve and
            second half for the second curve.

    Returns:
        The filled *output_signal*.
    """
    srch, srcw, _ = input_signal.shape
    output_row, output_col, _ = output_signal.shape
    half = int(len(beziers) / 2)
    tp_beziers = beziers[:half, :]
    bp_beziers = beziers[half:, :]

    for j in range(output_col):
        top = tp_beziers[j]
        # The second curve is stored in the opposite direction.
        bottom = bp_beziers[len(bp_beziers) - j - 1]
        col_points = np.linspace((top[0], top[1]), (bottom[0], bottom[1]),
                                 output_row, dtype=np.int32)
        rows = np.clip(col_points[:, 1], 0, srch - 1)
        cols = np.clip(col_points[:, 0], 0, srcw - 1)
        output_signal[:, j, :] = input_signal[rows, cols, :]
    return output_signal


def draw_color_circle(predictions, img, th, tw):
    """Rectify one image strip of size (th, tw, 3) per entry of *predictions*.

    Args:
        predictions: nested sequence whose entries each hold 16 Bezier
            control-point values.
        img: source image; it is converted to uint8 before sampling.
        th: output strip height.
        tw: output strip width (also the number of samples per curve).

    Returns:
        list of float arrays of shape (th, tw, 3).
    """
    img = img.astype(np.uint8)
    beziers = np.array(predictions)
    line_imgs = []
    for bezier in beziers:
        polygon = bezier_to_poly(bezier, tw)
        output_img = np.zeros((th, tw, 3))
        line_imgs.append(double_linear2(img, output_img, polygon))
    return line_imgs


def dot_product_angle(v1, v2):
    """Return the angle between *v1* and *v2* in degrees.

    Prints a message and returns 0 when either vector has zero length.
    """
    norm1 = np.linalg.norm(v1)
    norm2 = np.linalg.norm(v2)
    if norm1 == 0 or norm2 == 0:
        print("Zero magnitude vector!")
        return 0
    # Clamp so floating-point rounding cannot push the cosine outside
    # arccos' domain and produce NaN for (anti)parallel vectors.
    cosine = np.clip(np.dot(v1, v2) / (norm1 * norm2), -1.0, 1.0)
    return np.degrees(np.arccos(cosine))


def clockwise_angle(v1, v2):
    """Return the signed-rotation angle from *v1* to *v2* in radians, mapped into (0, 2*pi].

    Non-positive ``arctan2`` results are shifted by ``2 * pi``, so identical
    directions yield ``2 * pi`` rather than 0.
    """
    x1, y1 = v1
    x2, y2 = v2
    dot = x1 * x2 + y1 * y2
    det = x1 * y2 - y1 * x2
    theta = np.arctan2(det, dot)
    theta = theta if theta > 0 else 2 * np.pi + theta
    return theta


def perspect(img: np.ndarray, src_points: np.ndarray, dst_points: np.ndarray, out_w: int, out_h: int):
    """Warp *img* with the perspective transform mapping *src_points* to *dst_points*.

    Args:
        img: source image.
        src_points: four source points (reshaped to (4, 2) float32).
        dst_points: four destination points (reshaped to (4, 2) float32).
        out_w: output width.
        out_h: output height.

    Returns:
        The warped image of size (out_w, out_h) as passed to OpenCV.
    """
    src_points = np.asarray(src_points, dtype=np.float32).reshape(4, 2)
    dst_points = np.asarray(dst_points, dtype=np.float32).reshape(4, 2)
    M = cv2.getPerspectiveTransform(src_points, dst_points)
    plate_img = cv2.warpPerspective(img, M, (out_w, out_h))
    return plate_img


def _fit_edge_control_points(xs, ys):
    """Fit one edge and return its control points as [x0, y0, x1, y1, x2, y2, x3, y3]."""
    init_control_points = bezier_fit(xs, ys)
    # lr=0.0: no optimization, train() just assembles end points + fitted inner points.
    x0, x1, x2, x3, y0, y1, y2, y3 = train(xs, ys, init_control_points, 0.0)
    return [x0, y0, x1, y1, x2, y2, x3, y3]


def get_recognitionimage(srcimage, allpoints):
    """Rectify a text region described by 16 boundary points.

    Args:
        srcimage: source image as a numpy array indexed [row, col, channel].
        allpoints: flat sequence ``[x1, y1, x2, y2, ..., x16, y16]``; the
            first 8 points form one edge and the last 8 the opposite edge.

    Returns:
        The rectified uint8 image, or ``None`` when fewer than 16 points
        are supplied.
    """
    if len(allpoints) < 32:
        return None

    points = np.array([[int(allpoints[i * 2]), int(allpoints[i * 2 + 1])]
                       for i in range(16)])
    line_up_x, line_up_y = points[:8, 0], points[:8, 1]
    line_down_x, line_down_y = points[8:, 0], points[8:, 1]

    # Output width: distance between the first and eighth point;
    # output height: distance between the first and last point.
    tw = int(((points[0][0] - points[7][0]) ** 2 + (points[0][1] - points[7][1]) ** 2) ** 0.5)
    th = int(((points[0][0] - points[-1][0]) ** 2 + (points[0][1] - points[-1][1]) ** 2) ** 0.5)

    up_ctrl = _fit_edge_control_points(np.array(line_up_x), np.array(line_up_y))
    down_ctrl = _fit_edge_control_points(np.array(line_down_x), np.array(line_down_y))

    predictions = [[up_ctrl + down_ctrl]]
    line_imgs = draw_color_circle(predictions, srcimage, th, tw)
    return line_imgs[0].astype(np.uint8)
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
# import paddle
# from paddle.nn import functional as F
import re


class BaseRecLabelDecode(object):
    """Convert between text-label and text-index.

    Args:
        character_dict_path: path to a UTF-8 dictionary file with one
            character per line; when ``None`` the built-in alphabet
            ``0-9a-z`` is used.
        use_space_char: append a space character to the dictionary. Only
            applied when *character_dict_path* is given.
    """

    def __init__(self, character_dict_path=None, use_space_char=False):
        self.beg_str = "sos"
        self.end_str = "eos"

        self.character_str = []
        if character_dict_path is None:
            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
            dict_character = list(self.character_str)
        else:
            with open(character_dict_path, "rb") as fin:
                for line in fin:
                    line = line.decode('utf-8').strip("\n").strip("\r\n")
                    self.character_str.append(line)
            if use_space_char:
                self.character_str.append(" ")
            dict_character = list(self.character_str)
        dict_character = self.add_special_char(dict_character)
        # character -> index lookup; self.character is the inverse (index -> character).
        self.dict = {char: i for i, char in enumerate(dict_character)}
        self.character = dict_character

    def add_special_char(self, dict_character):
        """Hook for subclasses to add special tokens; the base class adds none."""
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert text-index into text-label.

        Args:
            text_index: batch of index sequences (numpy array or nested
                sequence).
            text_prob: optional batch of per-position confidences with the
                same layout as *text_index*.
            is_remove_duplicate: drop positions whose index equals the
                previous position's index.

        Returns:
            list of ``(text, mean_confidence)`` tuples, one per batch entry.
        """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        for batch_idx in range(len(text_index)):
            # asarray lets plain Python lists be compared element-wise and mask-indexed.
            indices = np.asarray(text_index[batch_idx])
            selection = np.ones(len(indices), dtype=bool)
            if is_remove_duplicate:
                selection[1:] = indices[1:] != indices[:-1]
            for ignored_token in ignored_tokens:
                selection &= indices != ignored_token

            char_list = [self.character[text_id] for text_id in indices[selection]]
            if text_prob is not None:
                conf_list = np.asarray(text_prob[batch_idx])[selection]
            else:
                conf_list = [1] * len(selection)
            if not len(conf_list):
                conf_list = [0]

            text = ''.join(char_list)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def get_ignored_tokens(self):
        """Return the indices that are skipped while decoding."""
        return [0]  # for ctc blank


class CTCLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index for CTC outputs.

    Index 0 of the dictionary is the ``'blank'`` token.
    """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        super(CTCLabelDecode, self).__init__(character_dict_path,
                                             use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):
        """Greedy-decode *preds* and, when given, *label*.

        Args:
            preds: array indexed [batch, time, class], or a tuple/list whose
                last element is such an array. It must support
                ``argmax``/``max`` along axis 2.
            label: optional batch of label index sequences.

        Returns:
            The decoded predictions, or ``(predictions, labels)`` when
            *label* is given.
        """
        if isinstance(preds, (tuple, list)):
            preds = preds[-1]
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
        if label is None:
            return text
        label = self.decode(label)
        return text, label

    def decode2(self, preds, topk=1):
        """Decode the top-k candidate characters at every kept position.

        Positions are kept according to the best (top-1) path: repeated
        indices and blanks are dropped. For the lower-ranked candidates a
        blank at a kept position is rendered as ``"~"``.

        Args:
            preds: array indexed [batch, time, class], or a tuple/list whose
                last element is such an array.
            topk: number of candidates per position.

        Returns:
            list with one entry per batch item; each entry is a list of
            tuples, one tuple per kept position, holding the *topk*
            candidate characters in rank order.
        """
        if isinstance(preds, (tuple, list)):
            preds = preds[-1]
        preds_sort = preds.argsort(axis=2)
        ignored_tokens = self.get_ignored_tokens()

        result_list = []
        for batch_idx in range(len(preds)):
            text_index = preds_sort[batch_idx, :, -1]
            selection = np.ones(len(text_index), dtype=bool)
            selection[1:] = text_index[1:] != text_index[:-1]
            for ignored_token in ignored_tokens:
                selection &= text_index != ignored_token

            res = [''.join(self.character[text_id]
                           for text_id in text_index[selection])]
            # rank 2 .. topk candidates, read at the positions kept for rank 1.
            for rank in range(2, topk + 1):
                alt_index = preds_sort[batch_idx, :, -rank]
                res.append(''.join(
                    self.character[text_id] if text_id not in ignored_tokens else "~"
                    for text_id in alt_index[selection]))
            # Transpose the candidate strings into per-position tuples.
            result_list.append(list(zip(*res)))
        return result_list

    def add_special_char(self, dict_character):
        """Prepend the CTC ``'blank'`` token so that it gets index 0."""
        dict_character = ['blank'] + dict_character
        return dict_character


# class DistillationCTCLabelDecode(CTCLabelDecode):
#     """
#     Convert 
#     Convert between text-label and text-index
#     """#     def __init__(self,
#                  character_dict_path=None,
#                  use_space_char=False,
#                  model_name=["student"],
#                  key=None,
#                  multi_head=False,
#                  **kwargs):
#         super(DistillationCTCLabelDecode, self).__init__(character_dict_path,
#                                                          use_space_char)
#         if not isinstance(model_name, list):
#             model_name = [model_name]
#         self.model_name = model_name#         self.key = key
#         self.multi_head = multi_head#     def __call__(self, preds, label=None, *args, **kwargs):
#         output = dict()
#         for name in self.model_name:
#             pred = preds[name]
#             if self.key is not None:
#                 pred = pred[self.key]
#             if self.multi_head and isinstance(pred, dict):
#                 pred = pred['ctc']
#             output[name] = super().__call__(pred, label=label, *args, **kwargs)
#         return output# class NRTRLabelDecode(BaseRecLabelDecode):
#     """ Convert between text-label and text-index """#     def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
#         super(NRTRLabelDecode, self).__init__(character_dict_path,
#                                               use_space_char)#     def __call__(self, preds, label=None, *args, **kwargs):#         if len(preds) == 2:
#             preds_id = preds[0]
#             preds_prob = preds[1]
#             if isinstance(preds_id, paddle.Tensor):
#                 preds_id = preds_id.numpy()
#             if isinstance(preds_prob, paddle.Tensor):
#                 preds_prob = preds_prob.numpy()
#             if preds_id[0][0] == 2:
#                 preds_idx = preds_id[:, 1:]
#                 preds_prob = preds_prob[:, 1:]
#             else:
#                 preds_idx = preds_id
#             text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
#             if label is None:
#                 return text
#             label = self.decode(label[:, 1:])
#         else:
#             if isinstance(preds, paddle.Tensor):
#                 preds = preds.numpy()
#             preds_idx = preds.argmax(axis=2)
#             preds_prob = preds.max(axis=2)
#             text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
#             if label is None:
#                 return text
#             label = self.decode(label[:, 1:])
#         return text, label#     def add_special_char(self, dict_character):
#         dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
#         return dict_character#     def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
#         """ convert text-index into text-label. """
#         result_list = []
#         batch_size = len(text_index)
#         for batch_idx in range(batch_size):
#             char_list = []
#             conf_list = []
#             for idx in range(len(text_index[batch_idx])):
#                 if text_index[batch_idx][idx] == 3:  # end
#                     break
#                 try:
#                     char_list.append(self.character[int(text_index[batch_idx][
#                         idx])])
#                 except:
#                     continue
#                 if text_prob is not None:
#                     conf_list.append(text_prob[batch_idx][idx])
#                 else:
#                     conf_list.append(1)
#             text = ''.join(char_list)
#             result_list.append((text.lower(), np.mean(conf_list).tolist()))
#         return result_list# class AttnLabelDecode(BaseRecLabelDecode):
#     """ Convert between text-label and text-index """#     def __init__(self, character_dict_path=None, use_space_char=False,
#                  **kwargs):
#         super(AttnLabelDecode, self).__init__(character_dict_path,
#                                               use_space_char)#     def add_special_char(self, dict_character):
#         self.beg_str = "sos"
#         self.end_str = "eos"
#         dict_character = dict_character
#         dict_character = [self.beg_str] + dict_character + [self.end_str]
#         return dict_character#     def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
#         """ convert text-index into text-label. """
#         result_list = []
#         ignored_tokens = self.get_ignored_tokens()
#         [beg_idx, end_idx] = self.get_ignored_tokens()
#         batch_size = len(text_index)
#         for batch_idx in range(batch_size):
#             char_list = []
#             conf_list = []
#             for idx in range(len(text_index[batch_idx])):
#                 if text_index[batch_idx][idx] in ignored_tokens:
#                     continue
#                 if int(text_index[batch_idx][idx]) == int(end_idx):
#                     break
#                 if is_remove_duplicate:
#                     # only for predict
#                     if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
#                             batch_idx][idx]:
#                         continue
#                 char_list.append(self.character[int(text_index[batch_idx][
#                     idx])])
#                 if text_prob is not None:
#                     conf_list.append(text_prob[batch_idx][idx])
#                 else:
#                     conf_list.append(1)
#             text = ''.join(char_list)
#             result_list.append((text, np.mean(conf_list).tolist()))
#         return result_list#     def __call__(self, preds, label=None, *args, **kwargs):
#         """
#         text = self.decode(text)
#         if label is None:
#             return text
#         else:
#             label = self.decode(label, is_remove_duplicate=False)
#             return text, label
#         """
#         if isinstance(preds, paddle.Tensor):
#             preds = preds.numpy()#         preds_idx = preds.argmax(axis=2)
#         preds_prob = preds.max(axis=2)
#         text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
#         if label is None:
#             return text
#         label = self.decode(label, is_remove_duplicate=False)
#         return text, label#     def get_ignored_tokens(self):
#         beg_idx = self.get_beg_end_flag_idx("beg")
#         end_idx = self.get_beg_end_flag_idx("end")
#         return [beg_idx, end_idx]#     def get_beg_end_flag_idx(self, beg_or_end):
#         if beg_or_end == "beg":
#             idx = np.array(self.dict[self.beg_str])
#         elif beg_or_end == "end":
#             idx = np.array(self.dict[self.end_str])
#         else:
#             assert False, "unsupport type %s in get_beg_end_flag_idx" \
#                           % beg_or_end
#         return idx# class SEEDLabelDecode(BaseRecLabelDecode):
#     """ Convert between text-label and text-index """#     def __init__(self, character_dict_path=None, use_space_char=False,
#                  **kwargs):
#         super(SEEDLabelDecode, self).__init__(character_dict_path,
#                                               use_space_char)#     def add_special_char(self, dict_character):
#         self.padding_str = "padding"
#         self.end_str = "eos"
#         self.unknown = "unknown"
#         dict_character = dict_character + [
#             self.end_str, self.padding_str, self.unknown
#         ]
#         return dict_character#     def get_ignored_tokens(self):
#         end_idx = self.get_beg_end_flag_idx("eos")
#         return [end_idx]#     def get_beg_end_flag_idx(self, beg_or_end):
#         if beg_or_end == "sos":
#             idx = np.array(self.dict[self.beg_str])
#         elif beg_or_end == "eos":
#             idx = np.array(self.dict[self.end_str])
#         else:
#             assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
#         return idx#     def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
#         """ convert text-index into text-label. """
#         result_list = []
#         [end_idx] = self.get_ignored_tokens()
#         batch_size = len(text_index)
#         for batch_idx in range(batch_size):
#             char_list = []
#             conf_list = []
#             for idx in range(len(text_index[batch_idx])):
#                 if int(text_index[batch_idx][idx]) == int(end_idx):
#                     break
#                 if is_remove_duplicate:
#                     # only for predict
#                     if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
#                             batch_idx][idx]:
#                         continue
#                 char_list.append(self.character[int(text_index[batch_idx][
#                     idx])])
#                 if text_prob is not None:
#                     conf_list.append(text_prob[batch_idx][idx])
#                 else:
#                     conf_list.append(1)
#             text = ''.join(char_list)
#             result_list.append((text, np.mean(conf_list).tolist()))
#         return result_list#     def __call__(self, preds, label=None, *args, **kwargs):
#         """
#         text = self.decode(text)
#         if label is None:
#             return text
#         else:
#             label = self.decode(label, is_remove_duplicate=False)
#             return text, label
#         """
#         preds_idx = preds["rec_pred"]
#         if isinstance(preds_idx, paddle.Tensor):
#             preds_idx = preds_idx.numpy()
#         if "rec_pred_scores" in preds:
#             preds_idx = preds["rec_pred"]
#             preds_prob = preds["rec_pred_scores"]
#         else:
#             preds_idx = preds["rec_pred"].argmax(axis=2)
#             preds_prob = preds["rec_pred"].max(axis=2)
#         text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
#         if label is None:
#             return text
#         label = self.decode(label, is_remove_duplicate=False)
#         return text, label# class SRNLabelDecode(BaseRecLabelDecode):
#     """ Convert between text-label and text-index """#     def __init__(self, character_dict_path=None, use_space_char=False,
#                  **kwargs):
#         super(SRNLabelDecode, self).__init__(character_dict_path,
#                                              use_space_char)
#         self.max_text_length = kwargs.get('max_text_length', 25)#     def __call__(self, preds, label=None, *args, **kwargs):
#         pred = preds['predict']
#         char_num = len(self.character_str) + 2
#         if isinstance(pred, paddle.Tensor):
#             pred = pred.numpy()
#         pred = np.reshape(pred, [-1, char_num])#         preds_idx = np.argmax(pred, axis=1)
#         preds_prob = np.max(pred, axis=1)#         preds_idx = np.reshape(preds_idx, [-1, self.max_text_length])#         preds_prob = np.reshape(preds_prob, [-1, self.max_text_length])#         text = self.decode(preds_idx, preds_prob)#         if label is None:
#             text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
#             return text
#         label = self.decode(label)
#         return text, label#     def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
#         """ convert text-index into text-label. """
#         result_list = []
#         ignored_tokens = self.get_ignored_tokens()
#         batch_size = len(text_index)#         for batch_idx in range(batch_size):
#             char_list = []
#             conf_list = []
#             for idx in range(len(text_index[batch_idx])):
#                 if text_index[batch_idx][idx] in ignored_tokens:
#                     continue
#                 if is_remove_duplicate:
#                     # only for predict
#                     if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
#                             batch_idx][idx]:
#                         continue
#                 char_list.append(self.character[int(text_index[batch_idx][
#                     idx])])
#                 if text_prob is not None:
#                     conf_list.append(text_prob[batch_idx][idx])
#                 else:
#                     conf_list.append(1)#             text = ''.join(char_list)
#             result_list.append((text, np.mean(conf_list).tolist()))
#         return result_list#     def add_special_char(self, dict_character):
#         dict_character = dict_character + [self.beg_str, self.end_str]
#         return dict_character#     def get_ignored_tokens(self):
#         beg_idx = self.get_beg_end_flag_idx("beg")
#         end_idx = self.get_beg_end_flag_idx("end")
#         return [beg_idx, end_idx]#     def get_beg_end_flag_idx(self, beg_or_end):
#         if beg_or_end == "beg":
#             idx = np.array(self.dict[self.beg_str])
#         elif beg_or_end == "end":
#             idx = np.array(self.dict[self.end_str])
#         else:
#             assert False, "unsupport type %s in get_beg_end_flag_idx" \
#                           % beg_or_end
#         return idx# class TableLabelDecode(object):
#     """  """#     def __init__(self, character_dict_path, **kwargs):
#         list_character, list_elem = self.load_char_elem_dict(
#             character_dict_path)
#         list_character = self.add_special_char(list_character)
#         list_elem = self.add_special_char(list_elem)
#         self.dict_character = {}
#         self.dict_idx_character = {}
#         for i, char in enumerate(list_character):
#             self.dict_idx_character[i] = char
#             self.dict_character[char] = i
#         self.dict_elem = {}
#         self.dict_idx_elem = {}
#         for i, elem in enumerate(list_elem):
#             self.dict_idx_elem[i] = elem
#             self.dict_elem[elem] = i#     def load_char_elem_dict(self, character_dict_path):
#         list_character = []
#         list_elem = []
#         with open(character_dict_path, "rb") as fin:
#             lines = fin.readlines()
#             substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split(
#                 "\t")
#             character_num = int(substr[0])
#             elem_num = int(substr[1])
#             for cno in range(1, 1 + character_num):
#                 character = lines[cno].decode('utf-8').strip("\n").strip("\r\n")
#                 list_character.append(character)
#             for eno in range(1 + character_num, 1 + character_num + elem_num):
#                 elem = lines[eno].decode('utf-8').strip("\n").strip("\r\n")
#                 list_elem.append(elem)
#         return list_character, list_elem#     def add_special_char(self, list_character):
#         self.beg_str = "sos"
#         self.end_str = "eos"
#         list_character = [self.beg_str] + list_character + [self.end_str]
#         return list_character#     def __call__(self, preds):
#         structure_probs = preds['structure_probs']
#         loc_preds = preds['loc_preds']
#         if isinstance(structure_probs, paddle.Tensor):
#             structure_probs = structure_probs.numpy()
#         if isinstance(loc_preds, paddle.Tensor):
#             loc_preds = loc_preds.numpy()
#         structure_idx = structure_probs.argmax(axis=2)
#         structure_probs = structure_probs.max(axis=2)
#         structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(
#             structure_idx, structure_probs, 'elem')
#         res_html_code_list = []
#         res_loc_list = []
#         batch_num = len(structure_str)
#         for bno in range(batch_num):
#             res_loc = []
#             for sno in range(len(structure_str[bno])):
#                 text = structure_str[bno][sno]
#                 if text in ['<td>', '<td']:
#                     pos = structure_pos[bno][sno]
#                     res_loc.append(loc_preds[bno, pos])
#             res_html_code = ''.join(structure_str[bno])
#             res_loc = np.array(res_loc)
#             res_html_code_list.append(res_html_code)
#             res_loc_list.append(res_loc)
#         return {
#             'res_html_code': res_html_code_list,
#             'res_loc': res_loc_list,
#             'res_score_list': result_score_list,
#             'res_elem_idx_list': result_elem_idx_list,
#             'structure_str_list': structure_str
#         }#     def decode(self, text_index, structure_probs, char_or_elem):
#         """convert text-label into text-index.
#         """
#         if char_or_elem == "char":
#             current_dict = self.dict_idx_character
#         else:
#             current_dict = self.dict_idx_elem
#             ignored_tokens = self.get_ignored_tokens('elem')
#             beg_idx, end_idx = ignored_tokens#         result_list = []
#         result_pos_list = []
#         result_score_list = []
#         result_elem_idx_list = []
#         batch_size = len(text_index)
#         for batch_idx in range(batch_size):
#             char_list = []
#             elem_pos_list = []
#             elem_idx_list = []
#             score_list = []
#             for idx in range(len(text_index[batch_idx])):
#                 tmp_elem_idx = int(text_index[batch_idx][idx])
#                 if idx > 0 and tmp_elem_idx == end_idx:
#                     break
#                 if tmp_elem_idx in ignored_tokens:
#                     continue#                 char_list.append(current_dict[tmp_elem_idx])
#                 elem_pos_list.append(idx)
#                 score_list.append(structure_probs[batch_idx, idx])
#                 elem_idx_list.append(tmp_elem_idx)
#             result_list.append(char_list)
#             result_pos_list.append(elem_pos_list)
#             result_score_list.append(score_list)
#             result_elem_idx_list.append(elem_idx_list)
#         return result_list, result_pos_list, result_score_list, result_elem_idx_list#     def get_ignored_tokens(self, char_or_elem):
#         beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
#         end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
#         return [beg_idx, end_idx]#     def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
#         if char_or_elem == "char":
#             if beg_or_end == "beg":
#                 idx = self.dict_character[self.beg_str]
#             elif beg_or_end == "end":
#                 idx = self.dict_character[self.end_str]
#             else:
#                 assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
#                               % beg_or_end
#         elif char_or_elem == "elem":
#             if beg_or_end == "beg":
#                 idx = self.dict_elem[self.beg_str]
#             elif beg_or_end == "end":
#                 idx = self.dict_elem[self.end_str]
#             else:
#                 assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
#                               % beg_or_end
#         else:
#             assert False, "Unsupport type %s in char_or_elem" \
#                           % char_or_elem
#         return idx# class SARLabelDecode(BaseRecLabelDecode):
#     """ Convert between text-label and text-index """#     def __init__(self, character_dict_path=None, use_space_char=False,
#                  **kwargs):
#         super(SARLabelDecode, self).__init__(character_dict_path,
#                                              use_space_char)#         self.rm_symbol = kwargs.get('rm_symbol', False)#     def add_special_char(self, dict_character):
#         beg_end_str = "<BOS/EOS>"
#         unknown_str = "<UKN>"
#         padding_str = "<PAD>"
#         dict_character = dict_character + [unknown_str]
#         self.unknown_idx = len(dict_character) - 1
#         dict_character = dict_character + [beg_end_str]
#         self.start_idx = len(dict_character) - 1
#         self.end_idx = len(dict_character) - 1
#         dict_character = dict_character + [padding_str]
#         self.padding_idx = len(dict_character) - 1
#         return dict_character#     def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
#         """ convert text-index into text-label. """
#         result_list = []
#         ignored_tokens = self.get_ignored_tokens()#         batch_size = len(text_index)
#         for batch_idx in range(batch_size):
#             char_list = []
#             conf_list = []
#             for idx in range(len(text_index[batch_idx])):
#                 if text_index[batch_idx][idx] in ignored_tokens:
#                     continue
#                 if int(text_index[batch_idx][idx]) == int(self.end_idx):
#                     if text_prob is None and idx == 0:
#                         continue
#                     else:
#                         break
#                 if is_remove_duplicate:
#                     # only for predict
#                     if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
#                             batch_idx][idx]:
#                         continue
#                 char_list.append(self.character[int(text_index[batch_idx][
#                     idx])])
#                 if text_prob is not None:
#                     conf_list.append(text_prob[batch_idx][idx])
#                 else:
#                     conf_list.append(1)
#             text = ''.join(char_list)
#             if self.rm_symbol:
#                 comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]')
#                 text = text.lower()
#                 text = comp.sub('', text)
#             result_list.append((text, np.mean(conf_list).tolist()))
#         return result_list#     def __call__(self, preds, label=None, *args, **kwargs):
#         if isinstance(preds, paddle.Tensor):
#             preds = preds.numpy()
#         preds_idx = preds.argmax(axis=2)
#         preds_prob = preds.max(axis=2)#         text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)#         if label is None:
#             return text
#         label = self.decode(label, is_remove_duplicate=False)
#         return text, label#     def get_ignored_tokens(self):
#         return [self.padding_idx]# class DistillationSARLabelDecode(SARLabelDecode):
#     """
#     Convert 
#     Convert between text-label and text-index
#     """#     def __init__(self,
#                  character_dict_path=None,
#                  use_space_char=False,
#                  model_name=["student"],
#                  key=None,
#                  multi_head=False,
#                  **kwargs):
#         super(DistillationSARLabelDecode, self).__init__(character_dict_path,
#                                                          use_space_char)
#         if not isinstance(model_name, list):
#             model_name = [model_name]
#         self.model_name = model_name#         self.key = key
#         self.multi_head = multi_head#     def __call__(self, preds, label=None, *args, **kwargs):
#         output = dict()
#         for name in self.model_name:
#             pred = preds[name]
#             if self.key is not None:
#                 pred = pred[self.key]
#             if self.multi_head and isinstance(pred, dict):
#                 pred = pred['sar']
#             output[name] = super().__call__(pred, label=label, *args, **kwargs)
#         return output# class PRENLabelDecode(BaseRecLabelDecode):
#     """ Convert between text-label and text-index """#     def __init__(self, character_dict_path=None, use_space_char=False,
#                  **kwargs):
#         super(PRENLabelDecode, self).__init__(character_dict_path,
#                                               use_space_char)#     def add_special_char(self, dict_character):
#         padding_str = '<PAD>'  # 0 
#         end_str = '<EOS>'  # 1
#         unknown_str = '<UNK>'  # 2#         dict_character = [padding_str, end_str, unknown_str] + dict_character
#         self.padding_idx = 0
#         self.end_idx = 1
#         self.unknown_idx = 2#         return dict_character#     def decode(self, text_index, text_prob=None):
#         """ convert text-index into text-label. """
#         result_list = []
#         batch_size = len(text_index)#         for batch_idx in range(batch_size):
#             char_list = []
#             conf_list = []
#             for idx in range(len(text_index[batch_idx])):
#                 if text_index[batch_idx][idx] == self.end_idx:
#                     break
#                 if text_index[batch_idx][idx] in \
#                     [self.padding_idx, self.unknown_idx]:
#                     continue
#                 char_list.append(self.character[int(text_index[batch_idx][
#                     idx])])
#                 if text_prob is not None:
#                     conf_list.append(text_prob[batch_idx][idx])
#                 else:
#                     conf_list.append(1)#             text = ''.join(char_list)
#             if len(text) > 0:
#                 result_list.append((text, np.mean(conf_list).tolist()))
#             else:
#                 # here confidence of empty recog result is 1
#                 result_list.append(('', 1))
#         return result_list#     def __call__(self, preds, label=None, *args, **kwargs):
#         preds = preds.numpy()
#         preds_idx = preds.argmax(axis=2)
#         preds_prob = preds.max(axis=2)
#         text = self.decode(preds_idx, preds_prob)
#         if label is None:
#             return text
#         label = self.decode(label)
#         return text, label
paint_fmt.py

```python
# coding=utf-8
import cv2
import re
import numpy as np
from PIL import Image, ImageFont, ImageDraw

# from common.log import get_logger
# logger = get_logger()


class PaintFmt:
    """Normalise recognised paint texts and optionally render them on the image."""

    COLORS = [
        (255, 0, 0), (0, 255, 0), (0, 0, 255),
        (255, 255, 0), (0, 255, 255), (255, 0, 255),
        (127, 127, 0), (0, 127, 127), (127, 0, 127),
        (255, 127, 0), (0, 255, 127), (255, 0, 127),
    ]

    def process(self, results):
        """Return the de-duplicated, normalised texts of ``results``.

        Each item must expose a ``text`` attribute; the text is passed through
        ``text_process`` and duplicates are dropped (order is not preserved).
        """
        return list({text_process(each.text) for each in results})

    def process_with_img(self, results, img, name="paint"):
        """Like ``process`` but also draws every kept box and writes a preview.

        Boxes that touch a 50-pixel border of the image are skipped. The
        preview is written to ``show/<name>.jpg``.

        Args:
            results: items exposing ``text`` and ``box`` (reshaped to 16x2 points).
            img: H x W x 3 image array.
            name: file stem of the preview image.

        Returns:
            De-duplicated normalised texts of the boxes that were kept.
        """
        import os  # local import: this module does not import os at file level

        res = set()
        h, w, _ = img.shape
        left, top = 50, 50  # w / 20, h / 20
        right, bottom = w - left, h - top
        texts = []
        for each in results:
            text = each.text
            # OpenCV contour functions require int32 points; a plain ``int``
            # cast yields int64 on most platforms and is rejected.
            box = each.box.reshape(16, 2).astype(np.int32)
            x1, y1 = box.min(axis=0)
            x2, y2 = box.max(axis=0)
            if not (left < x1 and top < y1 and x2 < right and y2 < bottom):
                continue
            mask = np.zeros((h, w), dtype=np.uint8)
            cv2.drawContours(mask, [box], -1, 1, -1)
            mask = mask.astype(bool)
            # Tuple of plain ints: accepted by both OpenCV and Pillow.
            color = tuple(int(np.random.randint(0, 255)) for _ in range(3))
            cv2.drawContours(img, [box], -1, color, 2)
            # Tint the polygon interior; float math avoids uint8 wrap-around.
            img = img.astype(np.float32)
            img[mask, :] += [-30, -30, 255]
            img = np.clip(img, 0, 255).astype(np.uint8)
            # Anchor the label at the top-most polygon vertex.
            texts.append([text, tuple(box[np.argmin(box[:, 1])]), color])
            # The original built ``res`` but never filled it, so the method
            # always returned []. NOTE(review): assumes only drawn boxes
            # should contribute to the result - confirm against callers.
            res.add(text_process(text))
        img = put_text(img, texts)
        # cv2.imwrite fails silently when the target directory is missing.
        os.makedirs("show", exist_ok=True)
        cv2.imwrite(f"show/{name}.jpg", img)
        # logger.info(f"paint fmt {res}")
        return list(res)


def put_text(img, texts,
             fontpath="/usr/share/fonts/truetype/arphic/ukai.ttc",
             fontsize=50):
    """Draw labels onto ``img`` with Pillow and return the new array.

    Args:
        img: image array accepted by ``Image.fromarray``.
        texts: iterable of ``(text, (x, y), color)``; the label is drawn
            ``fontsize`` pixels above ``(x, y)``.
        fontpath: TrueType font file used for rendering.
        fontsize: font size in pixels.
    """
    font = ImageFont.truetype(fontpath, fontsize)
    img_pil = Image.fromarray(img)
    draw = ImageDraw.Draw(img_pil)
    for text, point, color in texts:
        x, y = point
        draw.text((x, y - fontsize), text, font=font, fill=tuple(color))
    return np.array(img_pil)


# Field names treated as mass / traction quantities by text_process.
zhiliang = ['最大允许质量', '最大允许总质量', '总质量', '质量', '最大', '最大允许牵引力', '牵引', '允许牵引力', '牵引力', '引力', '拖挂总质量', '最大拖挂总质量','最大允许拖挂总质量', '核载质量','最大允许牵引质量','准拖挂车总质量','准拖总质量','最大允许重量', '最大允许总重量', '总重量', '重量', '拖挂总重量', '最大拖挂总重量', '最大允许拖挂总重量', '核载重量']
# Field names treated as passenger-count quantities.
renshu = ['准核载人数', '核载人数', '准载', '最大乘坐人', '乘坐人', '最大乘坐人数', '乘坐人数', '人数', '乘坐', '乘座', '乘座人', '座位', '限载']
# Field names treated as tank-volume quantities.
rongji = ['罐体总容积', '总容积', '罐体容积', '容积', '有效罐体容积', '有效容积', '有效总容积']
# Field names treated as side-board height quantities.
labbangaodu = ['栏板高度', '栏板', '高度', '栏板高']

# An integer or decimal number. The previous pattern used ``[.|\d]`` which is
# a character class (it also accepted a literal '|') and required at least
# two characters, so single-digit values were never found.
shuzi_reg = re.compile(r"\d+(?:\.\d+)?")
zhongwen_reg = re.compile(u"[\u4e00-\u9fa5]+")
# A number followed by a metre / millimetre unit. The previous
# ``[m|M|mm|MM]`` matched a single character only, turning "500mm" into "500m".
shuzidanwei_reg = re.compile(r"\d+(?:\.\d+)?[mM]{1,2}")


def text_process(text):
    """Normalise one recognised text into ``<field><number><unit>`` form.

    Spaces are removed, then the first number and the first run of Chinese
    characters are extracted. When the Chinese run is a known field name the
    matching unit is appended; a bare number gets a KG / T / m / mm unit when
    the text hints at one. Anything else is returned with only the spaces
    removed.
    """
    text = text.replace(" ", "")

    num_match = shuzi_reg.search(text)
    m = num_match.group() if num_match is not None else ""
    zh_match = zhongwen_reg.search(text)
    zt = zh_match.group() if zh_match is not None else ""

    # Every rewriting rule below needs a number.
    if not m:
        return text

    upper = text.upper()
    has_kg = 'K' in upper or 'G' in upper
    # Values below 100 marked with T / 吨 are read as tonnes.
    is_tonne = float(m) < 100 and ('T' in upper or '吨' in text)

    if not zt:
        # Bare quantity without a field name.
        if has_kg:
            return m + 'KG'
        if is_tonne:
            return m + 'T'
        ms = shuzidanwei_reg.search(text)
        if ms is not None:
            return ms.group()
        return text

    # Repair two misreadings handled by the original code.
    zt = zt.replace('后量', '质量').replace('高座', '高度')

    if zt in zhiliang:
        if is_tonne:
            danwei = 'T'
        elif has_kg:
            danwei = 'KG'
        else:
            danwei = ''
        return zt + m + danwei
    if zt in renshu:
        return zt + m + '人'
    if zt in rongji:
        return zt + m + '立方米'
    if zt in labbangaodu:
        # A decimal value is taken to be metres, an integer millimetres.
        return zt + m + ('m' if '.' in m else 'mm')
    return text

predict.py


import os
import torch
from PIL import Image
import cv2
import numpy as np
import math
import time
import onnxruntime as ort

from .rec_postprocess import CTCLabelDecode
from .text_align import get_recognitionimage
from detectron2.engine.defaults import DefaultPredictor
from adet.config import get_cfg
from easydict import EasyDict as edict


def setup_cfg(config_file, confidence_threshold, model_data=None):
    """Build a frozen detector config.

    Args:
        config_file: YAML file merged into the default config.
        confidence_threshold: score threshold applied to every supported head.
        model_data: optional weights path overriding ``cfg.MODEL.WEIGHTS``.
    """
    # load config from file and command-line arguments
    cfg = get_cfg()
    cfg.merge_from_file(config_file)
    # Set score_threshold for builtin models
    if model_data:
        cfg.MODEL.WEIGHTS = model_data
    # if device:
    #     cfg.MODEL.DEVICE = device
    print(f'weight:{cfg.MODEL.WEIGHTS}')
    print(f'device:{cfg.MODEL.DEVICE}')
    cfg.MODEL.RETINANET.SCORE_THRESH_TEST = confidence_threshold
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = confidence_threshold
    cfg.MODEL.FCOS.INFERENCE_TH_TEST = confidence_threshold
    cfg.MODEL.MEInst.INFERENCE_TH_TEST = confidence_threshold
    cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = confidence_threshold
    cfg.freeze()
    return cfg


class DP_det():
    """Text detector returning 16-point polygons and rectified crops."""

    def __init__(self, model_data=None, conf_thres = 0.3, device='0'):
        # ``device`` is written to CUDA_VISIBLE_DEVICES, so it must be a string.
        os.environ["CUDA_VISIBLE_DEVICES"] = device
        cfg_file = './config/R_50_poly.yaml'
        cfg = setup_cfg(cfg_file, conf_thres, model_data)
        self.predictor = DefaultPredictor(cfg)

    def predict(self, image):
        """Return one list per instance: polygon coordinates followed by the score."""
        predictions = self.predictor(image)
        instances = predictions['instances']
        points = instances.get('polygons').cpu().numpy().astype(np.int32).tolist()
        scores = [round(x, 2) for x in instances.scores.tolist()]
        for pts, score in zip(points, scores):
            pts.append(score)
        return points

    def process(self, image, is_show=True):
        """Detect text and return ``(crops, image_or_None)``.

        Each crop is an EasyDict with ``crop`` (rectified patch) and ``box``
        (flat polygon array). Polygons whose first and last points are closer
        than 10 pixels are discarded, as are polygons that cannot be rectified.
        """
        res_crops = []
        for pred in self.predict(image):
            pts = np.asarray(pred[:-1])
            # Distance between the first point and the point stored at
            # indices 30:32 - presumably the box height; confirm with the
            # polygon point order of the detector.
            bh = np.sqrt(np.square(pts[0:2] - pts[30:32]).sum())
            if bh < 10:
                continue
            crop = get_recognitionimage(image, pts)
            if crop is None:
                continue
            res_crops.append(edict(crop=crop, box=pts))
        return res_crops, (image if is_show else None)


class PPocrRec:
    """Text recogniser facade choosing the backend from the model file suffix."""

    def __init__(self, *args, **kwargs):
        # ``model_path`` may be given by keyword or as the first positional
        # argument; the original only handled the keyword form.
        if "model_path" in kwargs:
            model_path = kwargs["model_path"]
        elif args:
            model_path = args[0]
        else:
            raise TypeError("PPocrRec requires a 'model_path' argument")
        if model_path.endswith(".jit"):
            self.net = GPUTextRecognizer(*args, **kwargs)
        else:
            self.net = TextRecognizer(*args, **kwargs)

    def process_one_img(self, image):
        """Recognise one crop and return ``(text, prob)``."""
        pred = self.net(image)
        return pred[0][0], pred[0][1]

    def process_batch(self, images: list):
        """Recognise all crops in batches; items with empty text are dropped.

        Falls back to per-image ``process`` when the backend returns a result
        list of unexpected length.
        """
        imgs = [x.crop for x in images]
        results = self.net.predict_batch(imgs)
        if len(results) != len(images):
            return self.process(images)
        res = []
        for each, (text, score) in zip(images, results):
            if not len(text):
                continue
            # logger.info(f"PPocrRec2 ({score:.3f}): {text}")
            each.text = text
            # The score used to be computed and thrown away; store it like
            # ``process`` does.
            each.score = np.mean(score)
            res.append(each)
        return res

    def process(self, images: list):
        """Recognise crops one by one; sets ``text`` (half-width) and ``score``."""
        result = []
        for each in images:
            text, score = self.net(each.crop)[0]
            if not len(text):
                continue
            each.text = full2half(text)
            each.score = np.mean(score)
            result.append(each)
        return result

    def process_ocr_car(self, dets):
        """Recognise every detection, keeping empty results as well."""
        res = []
        for one in dets:
            text, prob = self.process_one_img(one.crop)
            one.text = text
            one.score = prob
            res.append(one)
        return res


def full2half(s):
    """Convert full-width characters in ``s`` to their half-width form."""
    converted = []
    for char in s:
        num = ord(char)
        if num == 0x3000:  # ideographic space
            num = 32
        elif 0xff01 <= num <= 0xff5e:  # full-width ASCII variants
            num -= 0xfee0
        converted.append(chr(num))
    return "".join(converted)


class TextRecognizer(object):
    """CTC text recogniser running an ONNX model through onnxruntime."""

    def __init__(self, model_path, dict_path, image_shape=(3, 32, 320), device=-1):
        self.rec_image_shape = image_shape
        self.postprocess_op = CTCLabelDecode(dict_path, use_space_char=True)
        if not os.path.exists(model_path):
            raise ValueError("not find model file path {}".format(model_path))
        if device is not None and device >= 0:
            providers = ['CUDAExecutionProvider']
        else:
            providers = ['CPUExecutionProvider']
        # providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']
        self.predictor = ort.InferenceSession(model_path, providers=providers)
        self.input_tensor, self.output_tensors = self.predictor.get_inputs()[0], None
        self.rec_batch_num = 6

    def resize_norm_img(self, img, max_wh_ratio):
        """Resize to the model height, scale to [-1, 1] and right-pad with zeros."""
        imgC, imgH, imgW = self.rec_image_shape
        assert imgC == img.shape[2]
        imgW = int((imgH * max_wh_ratio))
        h, w = img.shape[:2]
        ratio = w / float(h)
        resized_w = min(int(math.ceil(imgH * ratio)), imgW)
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def __call__(self, img):
        return self.predict(img)

    def predict(self, img):
        """Recognise a single image; returns a one-element result list."""
        return self.predict_batch([img])

    def predict_batch(self, img_list):
        """Recognise ``img_list`` and return results in input order."""
        img_num = len(img_list)
        # Calculate the aspect ratio of all text bars
        width_list = [img.shape[1] / float(img.shape[0]) for img in img_list]
        # Sorting can speed up the recognition process
        indices = np.argsort(np.array(width_list))
        # Independent placeholders (``[x] * n`` would alias one list).
        rec_res = [['', 0.0] for _ in range(img_num)]
        batch_num = self.rec_batch_num
        imgC, imgH, imgW = self.rec_image_shape
        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            batch_imgs = [img_list[indices[ino]] for ino in range(beg_img_no, end_img_no)]
            max_wh_ratio = imgW / imgH
            for img in batch_imgs:
                h, w = img.shape[0:2]
                max_wh_ratio = max(max_wh_ratio, w * 1.0 / h)
            norm_img_batch = np.concatenate(
                [self.resize_norm_img(img, max_wh_ratio)[np.newaxis, :] for img in batch_imgs]
            )
            input_dict = {self.input_tensor.name: norm_img_batch}
            outputs = self.predictor.run(self.output_tensors, input_dict)
            rec_result = self.postprocess_op(outputs[0])
            for rno, item in enumerate(rec_result):
                rec_res[indices[beg_img_no + rno]] = item
        return rec_res


class GPUTextRecognizer:
    """CTC text recogniser running a TorchScript model."""

    def __init__(self, model_path, dict_path, device='cuda:0', image_shape=(3, 48, 320)):
        # ``image_shape`` used to be ignored in favour of a hard-coded
        # "3, 48, 320"; the default is identical, so existing callers see no
        # change.
        self.rec_image_shape1 = [int(v) for v in image_shape]
        self.character_type = 'ch'
        self.rec_batch_num = 8
        self.rec_algorithm = 'CRNN'
        self.max_text_length = 25
        self.use_space_char = False
        self.use_space_char1 = True
        self.rec_char_dict_path1 = dict_path
        postprocess_params1 = {
            'name': 'CTCLabelDecode',
            "character_type": self.character_type,
            "character_dict_path": self.rec_char_dict_path1,
            "use_space_char": self.use_space_char1,
        }
        self.postprocess_op1 = CTCLabelDecode(**postprocess_params1)
        use_gpu = True
        self.use_gpu = use_gpu
        self.device = device
        self.limited_max_width = 1280
        self.limited_min_width = 16
        self.net1 = torch.jit.load(model_path)
        self.net1.to(self.device)
        self.net1.eval()

    def __call__(self, img):
        return self.predict(img)

    def post_process1(self, preds):
        return self.postprocess_op1(preds)

    def pre_process1(self, img):
        """Return ``img`` as a normalised batch of one."""
        h, w = img.shape[0:2]
        max_wh_ratio = max(0, w * 1.0 / h)
        norm_img = self.resize_norm_img1(img, max_wh_ratio)
        return norm_img[np.newaxis, :].copy()

    def resize_norm_img1(self, img, max_wh_ratio):
        """Resize to the model height, scale to [-1, 1] and right-pad with zeros."""
        imgC, imgH, imgW = self.rec_image_shape1
        assert imgC == img.shape[2]
        max_wh_ratio = max(max_wh_ratio, imgW / imgH)
        # NOTE(review): the padded width is derived from 32 rather than imgH;
        # kept unchanged because the exported model may depend on it.
        imgW = int((32 * max_wh_ratio))
        imgW = max(min(imgW, self.limited_max_width), self.limited_min_width)
        h, w = img.shape[:2]
        ratio = w / float(h)
        ratio_imgH = max(math.ceil(imgH * ratio), self.limited_min_width)
        resized_w = imgW if ratio_imgH > imgW else int(ratio_imgH)
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def predict(self, img):
        """Recognise a single image; returns the decoder's result list."""
        norm_img_batch1 = self.pre_process1(img.copy())
        with torch.no_grad():
            inp1 = torch.from_numpy(norm_img_batch1).to(self.device)
            prob_out1 = self.net1(inp1)
        if isinstance(prob_out1, list):
            preds1 = [v.cpu().numpy() for v in prob_out1]
        else:
            preds1 = prob_out1.cpu().numpy()
        return self.post_process1(preds1)

    def predict_batch(self, img_list):
        """Recognise ``img_list`` and return results in input order."""
        img_num = len(img_list)
        # Calculate the aspect ratio of all text bars
        width_list = [img.shape[1] / float(img.shape[0]) for img in img_list]
        # Sorting can speed up the recognition process
        indices = np.argsort(np.array(width_list))
        # Independent placeholders (``[x] * n`` would alias one list).
        rec_res = [['', 0.0] for _ in range(img_num)]
        batch_num = self.rec_batch_num
        imgC, imgH, imgW = self.rec_image_shape1
        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            batch_imgs = [img_list[indices[ino]] for ino in range(beg_img_no, end_img_no)]
            max_wh_ratio = imgW / imgH
            for img in batch_imgs:
                h, w = img.shape[0:2]
                max_wh_ratio = max(max_wh_ratio, w * 1.0 / h)
            norm_img_batch = np.concatenate(
                [self.resize_norm_img1(img, max_wh_ratio)[np.newaxis, :] for img in batch_imgs]
            )
            with torch.no_grad():
                inp1 = torch.from_numpy(norm_img_batch).to(self.device)
                prob_out1 = self.net1(inp1)
            rec_result = self.postprocess_op1(prob_out1.cpu().numpy())
            for rno, item in enumerate(rec_result):
                rec_res[indices[beg_img_no + rno]] = item
        return rec_res


class PaintRec:
    """Single-image CTC recogniser running a TorchScript model."""

    def __init__(self, model_path: str, alphabet_path: str, device: int = -1):
        rec_image_shape = "3, 32, 320"
        self.rec_image_shape = [int(v) for v in rec_image_shape.split(",")]
        self.character_type = 'ch'
        self.rec_batch_num = 6
        self.rec_algorithm = 'CRNN'
        self.max_text_length = 25
        self.use_space_char = True
        self.rec_char_dict_path = alphabet_path
        postprocess_params = {
            'name': 'CTCLabelDecode',
            "character_type": self.character_type,
            "character_dict_path": self.rec_char_dict_path,
            "use_space_char": self.use_space_char,
        }
        self.postprocess_op = CTCLabelDecode(**postprocess_params)
        use_gpu = True
        self.use_gpu = use_gpu
        # A negative index or a missing CUDA runtime selects the CPU.
        self.device = f"cuda:{device}" if device >= 0 and torch.cuda.is_available() else "cpu"
        self.limited_max_width = 1280
        self.limited_min_width = 16
        self.net = torch.jit.load(model_path)
        self.net.to(self.device)
        self.net.eval()

    def post_process(self,preds):
        return self.postprocess_op(preds)

    def pre_process(self,img):
        """Return ``img`` as a normalised batch of one."""
        h, w = img.shape[0:2]
        max_wh_ratio = max(0, w * 1.0 / h)
        norm_img = self.resize_norm_img(img, max_wh_ratio)
        return norm_img[np.newaxis, :].copy()

    def resize_norm_img(self, img, max_wh_ratio):
        """Resize to the model height, scale to [-1, 1] and right-pad with zeros."""
        imgC, imgH, imgW = self.rec_image_shape
        assert imgC == img.shape[2]
        max_wh_ratio = max(max_wh_ratio, imgW / imgH)
        imgW = int((32 * max_wh_ratio))
        imgW = max(min(imgW, self.limited_max_width), self.limited_min_width)
        h, w = img.shape[:2]
        ratio = w / float(h)
        ratio_imgH = max(math.ceil(imgH * ratio), self.limited_min_width)
        resized_w = imgW if ratio_imgH > imgW else int(ratio_imgH)
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def predict(self, img):
        """Recognise one image and return the first decoded result."""
        norm_img_batch = self.pre_process(img)
        with torch.no_grad():
            inp = torch.from_numpy(norm_img_batch).to(self.device)
            prob_out = self.net(inp)
        if isinstance(prob_out, list):
            preds = [v.cpu().numpy() for v in prob_out]
        else:
            preds = prob_out.cpu().numpy()
        rec_result = self.post_process(preds)
        return rec_result[0]

    def process(self, images):
        """Recognise every crop, storing ``text`` and ``score`` on each item."""
        result = []
        for each in images:
            text, score = self.predict(each.crop)
            # if not len(text) or score < 0.5:
            #     continue
            # logger.info(f"MingpaiRec ({score:.3f}): {text}")
            each.text = text
            each.score = score
            result.append(each)
        return result

predict_test.py

# -*- coding: utf-8 -*-
from glob import glob
import cv2
from paint_rec.predict import PaintRec, PPocrRec, DP_det
from paint_rec.paint_fmt import PaintFmt
import re
import numpy as np
import os
import time
from util import cv2ImgAddText, find_numbers_positions, find_numbers, has_numbers # , DP_det# # 定义文字的字体、大小、颜色和粗细
# font = cv2.FONT_HERSHEY_PLAIN # cv2.FONT_HERSHEY_SIMPLEX
# font_scale = 0.7
# text_color = (255, 0, 0) # (255, 0, 255)  # 白色
# text_thickness = 30


class CarPaintOCR():
    """Detect and recognise painted vehicle text, fixing passenger-count reads."""

    def __init__(self):
        self.text_det = DP_det()
        self.net_rec = PPocrRec(
            model_path="model_data/carPaints_rec_250515_1.jit",
            dict_path="model_data/ppocr_keys_v1_car.txt",
            device="cuda:0",
        )
        self.paint_fmt = PaintFmt()

    def predict(self, img, is_show=False):
        """Return the corrected passenger-count texts found in ``img``.

        ``is_show`` is accepted for compatibility; rendering is disabled.
        """
        dets, _ = self.text_det.process(img, is_show=True)
        recs = self.net_rec.process_ocr_car(dets)
        hezai_recs = []
        for one in recs:
            text = one['text']
            if not ("人" in text and has_numbers(text)):
                continue
            number = find_numbers(text)
            num_pos = find_numbers_positions(text)
            if len(num_pos) >= 2 and (num_pos[0] + 1 == num_pos[1]):
                # Two adjacent digits: the character right after them should
                # be '人'. When nothing follows, the second digit is taken to
                # be a misread '人'. The original tried ``text[id] = '人'`` on
                # a str, which always raised and was swallowed by a bare
                # except, so the first case never took effect.
                pos = num_pos[1] + 1
                if pos >= len(text):
                    pos = num_pos[1]
                text = text[:pos] + "人" + text[pos + 1:]
            # NOTE(review): a count of 1 is rewritten to 7; the scope of this
            # rule was ambiguous in the collapsed source - confirm.
            if number == 1 and len(num_pos):
                pos = num_pos[0]
                text = text[:pos] + "7" + text[pos + 1:]
            # if is_show:
            #     show_img = cv2ImgAddText(show_img, text, text_position, text_color, text_thickness)
            # NOTE(review): assumes only passenger-count texts are collected
            # (the variable name suggests so) - confirm.
            hezai_recs.append(text)
        return hezai_recs


def getFile_names(file_dir, ext=(".jpg", ".png", ".jpeg")):
    """Recursively list files under ``file_dir`` whose suffix is in ``ext``.

    The suffix comparison ignores case, so ``.JPG`` matches ``.jpg``.
    """
    if isinstance(ext, str):
        ext = (ext,)
    wanted = {e.lower() for e in ext}
    found = []
    for root, _dirs, files in os.walk(file_dir):
        for file in files:
            if os.path.splitext(file)[1].lower() in wanted:
                found.append(os.path.join(root, file))
    return found


def main():
    """Run the OCR pipeline over every image below ``./test/``."""
    show_dir = './11_show'  # './核载人数测试图片_test_0526/'
    crop_dir = './11_crop'  # './核载人数测试图片_test_0526_crop/'
    net = CarPaintOCR()
    os.makedirs(show_dir, exist_ok=True)
    os.makedirs(crop_dir, exist_ok=True)
    flist = getFile_names('./test/')  # "/home/data1/smf_data/test_data/000/*.jpg"  "./核载人数测试图片/"
    for fpath in flist:
        img = cv2.imread(fpath)
        if img is None:
            # cv2.imread returns None for unreadable files.
            print(f'skip unreadable image: {fpath}')
            continue
        res = net.predict(img)
        print(res)


if __name__ == "__main__":
    main()
http://www.dtcms.com/a/291200.html

相关文章:

  • 如何用 Z.ai 生成PPT,一句话生成整套演示文档
  • 自反馈机制(Self-Feedback)在大模型中的原理、演进与应用
  • 【PTA数据结构 | C语言版】哥尼斯堡的“七桥问题”
  • 【ROS1】07-话题通信中使用自定义msg
  • (9)机器学习小白入门 YOLOv:YOLOv8-cls 技术解析与代码实现
  • 选择排序 冒泡排序
  • LinkedList与链表(单向)(Java实现)
  • android studio 远程库编译报错无法访问远程库如何解决
  • 算法提升之字符串回文问题-(马拉车算法)
  • Java基础教程(011):面向对象中的构造方法
  • 模拟高负载测试脚本
  • Flink框架:keyBy实现按键逻辑分区
  • 250kHz采样率下多信号参数设置
  • mysql-5.7 Linux安装教程
  • 无人机报警器技术要点与捕捉方式
  • Anaconda 路径精简后暴露 python 及工具到环境变量的配置记录 [二]
  • Linux学习之Linux系统权限
  • scratch音乐会开幕倒计时 2025年6月中国电子学会图形化编程 少儿编程 scratch编程等级考试一级真题和答案解析
  • Git核心功能简要学习
  • 知识 IP 的突围:从 “靠感觉” 到 “系统 + AI” 的变现跃迁
  • 网络编程及原理(八)网络层 IP 协议
  • 关于校准 ARM 开发板时间的步骤和常见问题:我应该是RTC电池没电了才导致我设置了重启开发板又变回去2025年的时间
  • Xilinx FPGA XCKU115‑2FLVA1517I AMD KintexUltraScale
  • 【Java EE】多线程-初阶-Thread 类及常见方法
  • Netty中CompositeByteBuf 的addComponents方法解析
  • PNP加速关断驱动电路
  • [数据结构]#4 用链表实现的栈结构
  • FastAPI 中,数据库模型(通常使用 SQLAlchemy 定义)和接口模型(使用 Pydantic 定义的 schemas)的差异
  • GraphRAG快速入门和原理理解
  • 在线教育如何设置视频问答/视频弹题?——重塑在线教育的互动体验