From 1477b89aa40688210a0af9c6b819fd50ebdcc0fb Mon Sep 17 00:00:00 2001 From: andyjpaddle Date: Mon, 8 Aug 2022 08:59:30 +0000 Subject: [PATCH] fix amp train for re --- tools/infer/predict_det_eval.py | 363 ---------------------- tools/infer/predict_rec_eval.py | 534 -------------------------------- 2 files changed, 897 deletions(-) delete mode 100755 tools/infer/predict_det_eval.py delete mode 100755 tools/infer/predict_rec_eval.py diff --git a/tools/infer/predict_det_eval.py b/tools/infer/predict_det_eval.py deleted file mode 100755 index d1f83203..00000000 --- a/tools/infer/predict_det_eval.py +++ /dev/null @@ -1,363 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -import sys - -__dir__ = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(__dir__) -sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..'))) - -os.environ["FLAGS_allocator_strategy"] = 'auto_growth' - -import cv2 -import numpy as np -import time -import sys - -import tools.infer.utility as utility -from ppocr.utils.logging import get_logger -from ppocr.utils.utility import get_image_file_list, check_and_read_gif -from ppocr.data import create_operators, transform -from ppocr.postprocess import build_post_process -import json -logger = get_logger() - - -class TextDetector(object): - def __init__(self, args): - self.args = args - self.det_algorithm = args.det_algorithm - self.use_onnx = args.use_onnx - pre_process_list = [{ - 'DetResizeForTest': { - 'limit_side_len': args.det_limit_side_len, - 'limit_type': args.det_limit_type, - } - }, { - 'NormalizeImage': { - 'std': [0.229, 0.224, 0.225], - 'mean': [0.485, 0.456, 0.406], - 'scale': '1./255.', - 'order': 'hwc' - } - }, { - 'ToCHWImage': None - }, { - 'KeepKeys': { - 'keep_keys': ['image', 'shape'] - } - }] - postprocess_params = {} - if self.det_algorithm == "DB": - postprocess_params['name'] = 'DBPostProcess' - postprocess_params["thresh"] = args.det_db_thresh - postprocess_params["box_thresh"] = args.det_db_box_thresh - postprocess_params["max_candidates"] = 1000 - postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio - postprocess_params["use_dilation"] = args.use_dilation - postprocess_params["score_mode"] = args.det_db_score_mode - elif self.det_algorithm == "EAST": - postprocess_params['name'] = 'EASTPostProcess' - postprocess_params["score_thresh"] = args.det_east_score_thresh - postprocess_params["cover_thresh"] = args.det_east_cover_thresh - postprocess_params["nms_thresh"] = args.det_east_nms_thresh - elif self.det_algorithm == "SAST": - pre_process_list[0] = { - 'DetResizeForTest': { - 'resize_long': args.det_limit_side_len - } - } - postprocess_params['name'] = 'SASTPostProcess' - postprocess_params["score_thresh"] = args.det_sast_score_thresh - postprocess_params["nms_thresh"] = args.det_sast_nms_thresh - self.det_sast_polygon = args.det_sast_polygon - if self.det_sast_polygon: - postprocess_params["sample_pts_num"] = 6 - postprocess_params["expand_scale"] = 1.2 - postprocess_params["shrink_ratio_of_width"] = 0.2 - else: - postprocess_params["sample_pts_num"] = 2 - postprocess_params["expand_scale"] = 1.0 - postprocess_params["shrink_ratio_of_width"] = 0.3 - elif self.det_algorithm == "PSE": - postprocess_params['name'] = 'PSEPostProcess' - postprocess_params["thresh"] = args.det_pse_thresh - postprocess_params["box_thresh"] = args.det_pse_box_thresh - postprocess_params["min_area"] = args.det_pse_min_area - postprocess_params["box_type"] = args.det_pse_box_type - postprocess_params["scale"] = args.det_pse_scale - self.det_pse_box_type = args.det_pse_box_type - elif self.det_algorithm == "FCE": - pre_process_list[0] = { - 'DetResizeForTest': { - 'rescale_img': [1080, 736] - } - } - postprocess_params['name'] = 'FCEPostProcess' - postprocess_params["scales"] = args.scales - postprocess_params["alpha"] = args.alpha - postprocess_params["beta"] = args.beta - postprocess_params["fourier_degree"] = args.fourier_degree - postprocess_params["box_type"] = args.det_fce_box_type - else: - logger.info("unknown det_algorithm:{}".format(self.det_algorithm)) - sys.exit(0) - - self.preprocess_op = create_operators(pre_process_list) - self.postprocess_op = build_post_process(postprocess_params) - self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor( - args, 'det', logger) - - if self.use_onnx: - img_h, img_w = self.input_tensor.shape[2:] - if img_h is not None and img_w is not None and img_h > 0 and img_w > 0: - pre_process_list[0] = { - 'DetResizeForTest': { - 'image_shape': [img_h, img_w] - } - } - self.preprocess_op = create_operators(pre_process_list) - - if args.benchmark: - import auto_log - pid = os.getpid() - gpu_id = utility.get_infer_gpuid() - self.autolog = auto_log.AutoLogger( - model_name="det", - model_precision=args.precision, - batch_size=1, - data_shape="dynamic", - save_path=None, - inference_config=self.config, - pids=pid, - process_name=None, - gpu_ids=gpu_id if args.use_gpu else None, - time_keys=[ - 'preprocess_time', 'inference_time', 'postprocess_time' - ], - warmup=2, - logger=logger) - - def order_points_clockwise(self, pts): - rect = np.zeros((4, 2), dtype="float32") - s = pts.sum(axis=1) - rect[0] = pts[np.argmin(s)] - rect[2] = pts[np.argmax(s)] - tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0) - diff = np.diff(np.array(tmp), axis=1) - rect[1] = tmp[np.argmin(diff)] - rect[3] = tmp[np.argmax(diff)] - return rect - - def clip_det_res(self, points, img_height, img_width): - for pno in range(points.shape[0]): - points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1)) - points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1)) - return points - - def filter_tag_det_res(self, dt_boxes, image_shape): - img_height, img_width = image_shape[0:2] - dt_boxes_new = [] - for box in dt_boxes: - box = self.order_points_clockwise(box) - box = self.clip_det_res(box, img_height, img_width) - rect_width = int(np.linalg.norm(box[0] - box[1])) - rect_height = int(np.linalg.norm(box[0] - box[3])) - if rect_width <= 3 or rect_height <= 3: - continue - dt_boxes_new.append(box) - dt_boxes = np.array(dt_boxes_new) - return dt_boxes - - def filter_tag_det_res_only_clip(self, dt_boxes, image_shape): - img_height, img_width = image_shape[0:2] - dt_boxes_new = [] - for box in dt_boxes: - box = self.clip_det_res(box, img_height, img_width) - dt_boxes_new.append(box) - dt_boxes = np.array(dt_boxes_new) - return dt_boxes - - def __call__(self, img): - ori_im = img.copy() - data = {'image': img} - - st = time.time() - - if self.args.benchmark: - self.autolog.times.start() - - data = transform(data, self.preprocess_op) - img, shape_list = data - if img is None: - return None, 0 - img = np.expand_dims(img, axis=0) - shape_list = np.expand_dims(shape_list, axis=0) - img = img.copy() - - if self.args.benchmark: - self.autolog.times.stamp() - if self.use_onnx: - input_dict = {} - input_dict[self.input_tensor.name] = img - outputs = self.predictor.run(self.output_tensors, input_dict) - else: - self.input_tensor.copy_from_cpu(img) - self.predictor.run() - outputs = [] - for output_tensor in self.output_tensors: - output = output_tensor.copy_to_cpu() - outputs.append(output) - if self.args.benchmark: - self.autolog.times.stamp() - - preds = {} - if self.det_algorithm == "EAST": - preds['f_geo'] = outputs[0] - preds['f_score'] = outputs[1] - elif self.det_algorithm == 'SAST': - preds['f_border'] = outputs[0] - preds['f_score'] = outputs[1] - preds['f_tco'] = outputs[2] - preds['f_tvo'] = outputs[3] - elif self.det_algorithm in ['DB', 'PSE']: - preds['maps'] = outputs[0] - elif self.det_algorithm == 'FCE': - for i, output in enumerate(outputs): - preds['level_{}'.format(i)] = output - else: - raise NotImplementedError - - #self.predictor.try_shrink_memory() - post_result = self.postprocess_op(preds, shape_list) - dt_boxes = post_result[0]['points'] - if (self.det_algorithm == "SAST" and self.det_sast_polygon) or ( - self.det_algorithm in ["PSE", "FCE"] and - self.postprocess_op.box_type == 'poly'): - dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape) - else: - dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) - - if self.args.benchmark: - self.autolog.times.end(stamp=True) - et = time.time() - return dt_boxes, et - st - - -if __name__ == "__main__": - from ppocr.metrics.eval_det_iou import DetectionIoUEvaluator - evaluator = DetectionIoUEvaluator() - args = utility.parse_args() - - # image_file_list = get_image_file_list(args.image_dir) - def _check_image_file(path): - img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif'} - return any([path.lower().endswith(e) for e in img_end]) - - def get_image_file_list_from_txt(img_file): - imgs_lists = [] - label_lists = [] - if img_file is None or not os.path.exists(img_file): - raise Exception("not found any img file in {}".format(img_file)) - - img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif'} - root_dir = img_file.split('/')[0] - with open(img_file, 'r') as f: - lines = f.readlines() - for line in lines: - line = line.replace('\n', '').split('\t') - file_path, label = line[0], line[1] - file_path = os.path.join(root_dir, file_path) - if os.path.isfile(file_path) and _check_image_file(file_path): - imgs_lists.append(file_path) - label_lists.append(label) - - if len(imgs_lists) == 0: - raise Exception("not found any img file in {}".format(img_file)) - return imgs_lists, label_lists - - image_file_list, label_list = get_image_file_list_from_txt(args.image_dir) - - text_detector = TextDetector(args) - count = 0 - total_time = 0 - draw_img_save = "./inference_results" - - if args.warmup: - img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8) - for i in range(2): - res = text_detector(img) - - if not os.path.exists(draw_img_save): - os.makedirs(draw_img_save) - save_results = [] - results = [] - for idx in range(len(image_file_list)): - image_file = image_file_list[idx] - label = json.loads(label_list[idx]) - img, flag = check_and_read_gif(image_file) - if not flag: - img = cv2.imread(image_file) - if img is None: - logger.info("error in loading image:{}".format(image_file)) - continue - st = time.time() - dt_boxes, _ = text_detector(img) - elapse = time.time() - st - if count > 0: - total_time += elapse - count += 1 - save_pred = os.path.basename(image_file) + "\t" + str( - json.dumps([x.tolist() for x in dt_boxes])) + "\n" - save_results.append(save_pred) - - # for eval - gt_info_list = [] - det_info_list = [] - for dt_box in dt_boxes: - det_info = { - 'points': np.array( - dt_box, dtype=np.float32), - 'text': '' - } - det_info_list.append(det_info) - for lab in label: - gt_info = { - 'points': np.array( - lab['points'], dtype=np.float32), - 'text': '', - 'ignore': False - } - gt_info_list.append(gt_info) - result = evaluator.evaluate_image(gt_info_list, det_info_list) - results.append(result) - - metrics = evaluator.combine_results(results) - print('predict det eval on ', args.image_dir) - print('metrics: ', metrics) - -# logger.info(save_pred) -# logger.info("The predict time of {}: {}".format(image_file, elapse)) -# src_im = utility.draw_text_det_res(dt_boxes, image_file) -# img_name_pure = os.path.split(image_file)[-1] -# img_path = os.path.join(draw_img_save, -# "det_res_{}".format(img_name_pure)) -# cv2.imwrite(img_path, src_im) -# logger.info("The visualized image saved in {}".format(img_path)) - -# with open(os.path.join(draw_img_save, "det_results.txt"), 'w') as f: -# f.writelines(save_results) -# f.close() -# if args.benchmark: -# text_detector.autolog.report() diff --git a/tools/infer/predict_rec_eval.py b/tools/infer/predict_rec_eval.py deleted file mode 100755 index 3150d11d..00000000 --- a/tools/infer/predict_rec_eval.py +++ /dev/null @@ -1,534 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -import sys -from PIL import Image -__dir__ = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(__dir__) -sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..'))) - -os.environ["FLAGS_allocator_strategy"] = 'auto_growth' - -import cv2 -import numpy as np -import math -import time -import traceback -import paddle - -import tools.infer.utility as utility -from ppocr.postprocess import build_post_process -from ppocr.utils.logging import get_logger -from ppocr.utils.utility import get_image_file_list, check_and_read_gif - -logger = get_logger() - - -class TextRecognizer(object): - def __init__(self, args): - self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")] - self.rec_batch_num = args.rec_batch_num - self.rec_algorithm = args.rec_algorithm - postprocess_params = { - 'name': 'CTCLabelDecode', - "character_dict_path": args.rec_char_dict_path, - "use_space_char": args.use_space_char - } - if self.rec_algorithm == "SRN": - postprocess_params = { - 'name': 'SRNLabelDecode', - "character_dict_path": args.rec_char_dict_path, - "use_space_char": args.use_space_char - } - elif self.rec_algorithm == "RARE": - postprocess_params = { - 'name': 'AttnLabelDecode', - "character_dict_path": args.rec_char_dict_path, - "use_space_char": args.use_space_char - } - elif self.rec_algorithm == 'NRTR': - postprocess_params = { - 'name': 'NRTRLabelDecode', - "character_dict_path": args.rec_char_dict_path, - "use_space_char": args.use_space_char - } - elif self.rec_algorithm == "SAR": - postprocess_params = { - 'name': 'SARLabelDecode', - "character_dict_path": args.rec_char_dict_path, - "use_space_char": args.use_space_char - } - elif self.rec_algorithm == 'ViTSTR': - postprocess_params = { - 'name': 'ViTSTRLabelDecode', - "character_dict_path": args.rec_char_dict_path, - "use_space_char": args.use_space_char - } - elif self.rec_algorithm == 'ABINet': - postprocess_params = { - 'name': 'ABINetLabelDecode', - "character_dict_path": args.rec_char_dict_path, - "use_space_char": args.use_space_char - } - self.postprocess_op = build_post_process(postprocess_params) - self.predictor, self.input_tensor, self.output_tensors, self.config = \ - utility.create_predictor(args, 'rec', logger) - self.benchmark = args.benchmark - self.use_onnx = args.use_onnx - if args.benchmark: - import auto_log - pid = os.getpid() - gpu_id = utility.get_infer_gpuid() - self.autolog = auto_log.AutoLogger( - model_name="rec", - model_precision=args.precision, - batch_size=args.rec_batch_num, - data_shape="dynamic", - save_path=None, #args.save_log_path, - inference_config=self.config, - pids=pid, - process_name=None, - gpu_ids=gpu_id if args.use_gpu else None, - time_keys=[ - 'preprocess_time', 'inference_time', 'postprocess_time' - ], - warmup=0, - logger=logger) - - def resize_norm_img(self, img, max_wh_ratio): - imgC, imgH, imgW = self.rec_image_shape - if self.rec_algorithm == 'NRTR' or self.rec_algorithm == 'ViTSTR': - img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) - # return padding_im - image_pil = Image.fromarray(np.uint8(img)) - if self.rec_algorithm == 'ViTSTR': - img = image_pil.resize([imgW, imgH], Image.BICUBIC) - else: - img = image_pil.resize([imgW, imgH], Image.ANTIALIAS) - img = np.array(img) - norm_img = np.expand_dims(img, -1) - norm_img = norm_img.transpose((2, 0, 1)) - if self.rec_algorithm == 'ViTSTR': - norm_img = norm_img.astype(np.float32) / 255. - else: - norm_img = norm_img.astype(np.float32) / 128. - 1. - return norm_img - - assert imgC == img.shape[2] - imgW = int((imgH * max_wh_ratio)) - if self.use_onnx: - w = self.input_tensor.shape[3:][0] - if w is not None and w > 0: - imgW = w - - h, w = img.shape[:2] - ratio = w / float(h) - if math.ceil(imgH * ratio) > imgW: - resized_w = imgW - else: - resized_w = int(math.ceil(imgH * ratio)) - if self.rec_algorithm == 'RARE': - if resized_w > self.rec_image_shape[2]: - resized_w = self.rec_image_shape[2] - imgW = self.rec_image_shape[2] - resized_image = cv2.resize(img, (resized_w, imgH)) - resized_image = resized_image.astype('float32') - resized_image = resized_image.transpose((2, 0, 1)) / 255 - resized_image -= 0.5 - resized_image /= 0.5 - padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) - padding_im[:, :, 0:resized_w] = resized_image - return padding_im - - def resize_norm_img_srn(self, img, image_shape): - imgC, imgH, imgW = image_shape - - img_black = np.zeros((imgH, imgW)) - im_hei = img.shape[0] - im_wid = img.shape[1] - - if im_wid <= im_hei * 1: - img_new = cv2.resize(img, (imgH * 1, imgH)) - elif im_wid <= im_hei * 2: - img_new = cv2.resize(img, (imgH * 2, imgH)) - elif im_wid <= im_hei * 3: - img_new = cv2.resize(img, (imgH * 3, imgH)) - else: - img_new = cv2.resize(img, (imgW, imgH)) - - img_np = np.asarray(img_new) - img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY) - img_black[:, 0:img_np.shape[1]] = img_np - img_black = img_black[:, :, np.newaxis] - - row, col, c = img_black.shape - c = 1 - - return np.reshape(img_black, (c, row, col)).astype(np.float32) - - def srn_other_inputs(self, image_shape, num_heads, max_text_length): - - imgC, imgH, imgW = image_shape - feature_dim = int((imgH / 8) * (imgW / 8)) - - encoder_word_pos = np.array(range(0, feature_dim)).reshape( - (feature_dim, 1)).astype('int64') - gsrm_word_pos = np.array(range(0, max_text_length)).reshape( - (max_text_length, 1)).astype('int64') - - gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length)) - gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape( - [-1, 1, max_text_length, max_text_length]) - gsrm_slf_attn_bias1 = np.tile( - gsrm_slf_attn_bias1, - [1, num_heads, 1, 1]).astype('float32') * [-1e9] - - gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape( - [-1, 1, max_text_length, max_text_length]) - gsrm_slf_attn_bias2 = np.tile( - gsrm_slf_attn_bias2, - [1, num_heads, 1, 1]).astype('float32') * [-1e9] - - encoder_word_pos = encoder_word_pos[np.newaxis, :] - gsrm_word_pos = gsrm_word_pos[np.newaxis, :] - - return [ - encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, - gsrm_slf_attn_bias2 - ] - - def process_image_srn(self, img, image_shape, num_heads, max_text_length): - norm_img = self.resize_norm_img_srn(img, image_shape) - norm_img = norm_img[np.newaxis, :] - - [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \ - self.srn_other_inputs(image_shape, num_heads, max_text_length) - - gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32) - gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32) - encoder_word_pos = encoder_word_pos.astype(np.int64) - gsrm_word_pos = gsrm_word_pos.astype(np.int64) - - return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, - gsrm_slf_attn_bias2) - - def resize_norm_img_sar(self, img, image_shape, - width_downsample_ratio=0.25): - imgC, imgH, imgW_min, imgW_max = image_shape - h = img.shape[0] - w = img.shape[1] - valid_ratio = 1.0 - # make sure new_width is an integral multiple of width_divisor. - width_divisor = int(1 / width_downsample_ratio) - # resize - ratio = w / float(h) - resize_w = math.ceil(imgH * ratio) - if resize_w % width_divisor != 0: - resize_w = round(resize_w / width_divisor) * width_divisor - if imgW_min is not None: - resize_w = max(imgW_min, resize_w) - if imgW_max is not None: - valid_ratio = min(1.0, 1.0 * resize_w / imgW_max) - resize_w = min(imgW_max, resize_w) - resized_image = cv2.resize(img, (resize_w, imgH)) - resized_image = resized_image.astype('float32') - # norm - if image_shape[0] == 1: - resized_image = resized_image / 255 - resized_image = resized_image[np.newaxis, :] - else: - resized_image = resized_image.transpose((2, 0, 1)) / 255 - resized_image -= 0.5 - resized_image /= 0.5 - resize_shape = resized_image.shape - padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32) - padding_im[:, :, 0:resize_w] = resized_image - pad_shape = padding_im.shape - - return padding_im, resize_shape, pad_shape, valid_ratio - - def resize_norm_img_svtr(self, img, image_shape): - - imgC, imgH, imgW = image_shape - resized_image = cv2.resize( - img, (imgW, imgH), interpolation=cv2.INTER_LINEAR) - resized_image = resized_image.astype('float32') - resized_image = resized_image.transpose((2, 0, 1)) / 255 - resized_image -= 0.5 - resized_image /= 0.5 - return resized_image - - def resize_norm_img_abinet(self, img, image_shape): - - imgC, imgH, imgW = image_shape - - resized_image = cv2.resize( - img, (imgW, imgH), interpolation=cv2.INTER_LINEAR) - resized_image = resized_image.astype('float32') - resized_image = resized_image / 255. - - mean = np.array([0.485, 0.456, 0.406]) - std = np.array([0.229, 0.224, 0.225]) - resized_image = ( - resized_image - mean[None, None, ...]) / std[None, None, ...] - resized_image = resized_image.transpose((2, 0, 1)) - resized_image = resized_image.astype('float32') - - return resized_image - - def __call__(self, img_list): - img_num = len(img_list) - # Calculate the aspect ratio of all text bars - width_list = [] - for img in img_list: - width_list.append(img.shape[1] / float(img.shape[0])) - # Sorting can speed up the recognition process - indices = np.argsort(np.array(width_list)) - rec_res = [['', 0.0]] * img_num - batch_num = self.rec_batch_num - st = time.time() - if self.benchmark: - self.autolog.times.start() - for beg_img_no in range(0, img_num, batch_num): - end_img_no = min(img_num, beg_img_no + batch_num) - norm_img_batch = [] - imgC, imgH, imgW = self.rec_image_shape - max_wh_ratio = imgW / imgH - # max_wh_ratio = 0 - for ino in range(beg_img_no, end_img_no): - h, w = img_list[indices[ino]].shape[0:2] - wh_ratio = w * 1.0 / h - max_wh_ratio = max(max_wh_ratio, wh_ratio) - for ino in range(beg_img_no, end_img_no): - - if self.rec_algorithm == "SAR": - norm_img, _, _, valid_ratio = self.resize_norm_img_sar( - img_list[indices[ino]], self.rec_image_shape) - norm_img = norm_img[np.newaxis, :] - valid_ratio = np.expand_dims(valid_ratio, axis=0) - valid_ratios = [] - valid_ratios.append(valid_ratio) - norm_img_batch.append(norm_img) - elif self.rec_algorithm == "SRN": - norm_img = self.process_image_srn( - img_list[indices[ino]], self.rec_image_shape, 8, 25) - encoder_word_pos_list = [] - gsrm_word_pos_list = [] - gsrm_slf_attn_bias1_list = [] - gsrm_slf_attn_bias2_list = [] - encoder_word_pos_list.append(norm_img[1]) - gsrm_word_pos_list.append(norm_img[2]) - gsrm_slf_attn_bias1_list.append(norm_img[3]) - gsrm_slf_attn_bias2_list.append(norm_img[4]) - norm_img_batch.append(norm_img[0]) - elif self.rec_algorithm == "SVTR": - norm_img = self.resize_norm_img_svtr(img_list[indices[ino]], - self.rec_image_shape) - norm_img = norm_img[np.newaxis, :] - norm_img_batch.append(norm_img) - elif self.rec_algorithm == "ABINet": - norm_img = self.resize_norm_img_abinet( - img_list[indices[ino]], self.rec_image_shape) - norm_img = norm_img[np.newaxis, :] - norm_img_batch.append(norm_img) - else: - norm_img = self.resize_norm_img(img_list[indices[ino]], - max_wh_ratio) - norm_img = norm_img[np.newaxis, :] - norm_img_batch.append(norm_img) - norm_img_batch = np.concatenate(norm_img_batch) - norm_img_batch = norm_img_batch.copy() - if self.benchmark: - self.autolog.times.stamp() - - if self.rec_algorithm == "SRN": - encoder_word_pos_list = np.concatenate(encoder_word_pos_list) - gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list) - gsrm_slf_attn_bias1_list = np.concatenate( - gsrm_slf_attn_bias1_list) - gsrm_slf_attn_bias2_list = np.concatenate( - gsrm_slf_attn_bias2_list) - - inputs = [ - norm_img_batch, - encoder_word_pos_list, - gsrm_word_pos_list, - gsrm_slf_attn_bias1_list, - gsrm_slf_attn_bias2_list, - ] - if self.use_onnx: - input_dict = {} - input_dict[self.input_tensor.name] = norm_img_batch - outputs = self.predictor.run(self.output_tensors, - input_dict) - preds = {"predict": outputs[2]} - else: - input_names = self.predictor.get_input_names() - for i in range(len(input_names)): - input_tensor = self.predictor.get_input_handle( - input_names[i]) - input_tensor.copy_from_cpu(inputs[i]) - self.predictor.run() - outputs = [] - for output_tensor in self.output_tensors: - output = output_tensor.copy_to_cpu() - outputs.append(output) - if self.benchmark: - self.autolog.times.stamp() - preds = {"predict": outputs[2]} - elif self.rec_algorithm == "SAR": - valid_ratios = np.concatenate(valid_ratios) - inputs = [ - norm_img_batch, - valid_ratios, - ] - if self.use_onnx: - input_dict = {} - input_dict[self.input_tensor.name] = norm_img_batch - outputs = self.predictor.run(self.output_tensors, - input_dict) - preds = outputs[0] - else: - input_names = self.predictor.get_input_names() - for i in range(len(input_names)): - input_tensor = self.predictor.get_input_handle( - input_names[i]) - input_tensor.copy_from_cpu(inputs[i]) - self.predictor.run() - outputs = [] - for output_tensor in self.output_tensors: - output = output_tensor.copy_to_cpu() - outputs.append(output) - if self.benchmark: - self.autolog.times.stamp() - preds = outputs[0] - else: - if self.use_onnx: - input_dict = {} - input_dict[self.input_tensor.name] = norm_img_batch - outputs = self.predictor.run(self.output_tensors, - input_dict) - preds = outputs[0] - else: - self.input_tensor.copy_from_cpu(norm_img_batch) - self.predictor.run() - outputs = [] - for output_tensor in self.output_tensors: - output = output_tensor.copy_to_cpu() - outputs.append(output) - if self.benchmark: - self.autolog.times.stamp() - if len(outputs) != 1: - preds = outputs - else: - preds = outputs[0] - rec_result = self.postprocess_op(preds) - for rno in range(len(rec_result)): - rec_res[indices[beg_img_no + rno]] = rec_result[rno] - if self.benchmark: - self.autolog.times.end(stamp=True) - return rec_res, time.time() - st - - -def main(args): - # image_file_list = get_image_file_list(args.image_dir) - - def _check_image_file(path): - img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif'} - return any([path.lower().endswith(e) for e in img_end]) - - def get_image_file_list_from_txt(img_file): - imgs_lists = [] - label_lists = [] - if img_file is None or not os.path.exists(img_file): - raise Exception("not found any img file in {}".format(img_file)) - - img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif'} - root_dir = img_file.split('/')[0] - with open(img_file, 'r') as f: - lines = f.readlines() - for line in lines: - line = line.replace('\n', '').split('\t') - file_path, label = line[0], line[1] - file_path = os.path.join(root_dir, file_path) - if os.path.isfile(file_path) and _check_image_file(file_path): - imgs_lists.append(file_path) - label_lists.append(label) - - if len(imgs_lists) == 0: - raise Exception("not found any img file in {}".format(img_file)) - return imgs_lists, label_lists - - image_file_list, label_list = get_image_file_list_from_txt(args.image_dir) - - text_recognizer = TextRecognizer(args) - valid_image_file_list = [] - img_list = [] - - logger.info( - "In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', " - "if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320" - ) - # warmup 2 times - if args.warmup: - img = np.random.uniform(0, 255, [48, 320, 3]).astype(np.uint8) - for i in range(2): - res = text_recognizer([img] * int(args.rec_batch_num)) - - for image_file in image_file_list: - img, flag = check_and_read_gif(image_file) - if not flag: - img = cv2.imread(image_file) - if img is None: - logger.info("error in loading image:{}".format(image_file)) - continue - valid_image_file_list.append(image_file) - img_list.append(img) - - try: - rec_res, _ = text_recognizer(img_list) - except Exception as E: - logger.info(traceback.format_exc()) - logger.info(E) - exit() - correct_num = 0 - for ino in range(len(img_list)): - pred = rec_res[ino][0] - gt = label_list[ino] - if pred == gt: - correct_num += 1 - acc = correct_num * 1.0 / len(img_list) - print('predict rec eval on ', args.image_dir) - print('acc: ', acc) - - # for debug bad case - bad_case_lines = [] - for ino in range(len(img_list)): - pred = rec_res[ino][0] - gt = label_list[ino] - if pred != gt and len(gt) <= 25: - bad_case = valid_image_file_list[ - ino] + '\t' + 'pred:' + pred + '\t' + 'gt:' + gt + '\n' - bad_case_lines.append(bad_case) - - with open('bad_case_hwdb2.txt', 'a+') as f: - f.writelines(bad_case_lines) - # end debug case - - if args.benchmark: - text_recognizer.autolog.report() - - -if __name__ == "__main__": - main(utility.parse_args()) -- GitLab