From 03d384860fce041f1b18633c8969c097ba6537d2 Mon Sep 17 00:00:00 2001 From: andyjpaddle Date: Mon, 8 Aug 2022 08:58:07 +0000 Subject: [PATCH] fix amp train for re --- tools/infer/predict_det_eval.py | 363 ++++++++++++++++++++++ tools/infer/predict_rec_eval.py | 534 ++++++++++++++++++++++++++++++++ tools/program.py | 8 +- 3 files changed, 902 insertions(+), 3 deletions(-) create mode 100755 tools/infer/predict_det_eval.py create mode 100755 tools/infer/predict_rec_eval.py diff --git a/tools/infer/predict_det_eval.py b/tools/infer/predict_det_eval.py new file mode 100755 index 00000000..d1f83203 --- /dev/null +++ b/tools/infer/predict_det_eval.py @@ -0,0 +1,363 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..'))) + +os.environ["FLAGS_allocator_strategy"] = 'auto_growth' + +import cv2 +import numpy as np +import time +import sys + +import tools.infer.utility as utility +from ppocr.utils.logging import get_logger +from ppocr.utils.utility import get_image_file_list, check_and_read_gif +from ppocr.data import create_operators, transform +from ppocr.postprocess import build_post_process +import json +logger = get_logger() + + +class TextDetector(object): + def __init__(self, args): + self.args = args + self.det_algorithm = args.det_algorithm + self.use_onnx = args.use_onnx + pre_process_list = [{ + 'DetResizeForTest': { + 'limit_side_len': args.det_limit_side_len, + 'limit_type': args.det_limit_type, + } + }, { + 'NormalizeImage': { + 'std': [0.229, 0.224, 0.225], + 'mean': [0.485, 0.456, 0.406], + 'scale': '1./255.', + 'order': 'hwc' + } + }, { + 'ToCHWImage': None + }, { + 'KeepKeys': { + 'keep_keys': ['image', 'shape'] + } + }] + postprocess_params = {} + if self.det_algorithm == "DB": + postprocess_params['name'] = 'DBPostProcess' + postprocess_params["thresh"] = args.det_db_thresh + postprocess_params["box_thresh"] = args.det_db_box_thresh + postprocess_params["max_candidates"] = 1000 + postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio + postprocess_params["use_dilation"] = args.use_dilation + postprocess_params["score_mode"] = args.det_db_score_mode + elif self.det_algorithm == "EAST": + postprocess_params['name'] = 'EASTPostProcess' + postprocess_params["score_thresh"] = args.det_east_score_thresh + postprocess_params["cover_thresh"] = args.det_east_cover_thresh + postprocess_params["nms_thresh"] = args.det_east_nms_thresh + elif self.det_algorithm == "SAST": + pre_process_list[0] = { + 'DetResizeForTest': { + 'resize_long': args.det_limit_side_len + } + } + postprocess_params['name'] = 'SASTPostProcess' + postprocess_params["score_thresh"] = args.det_sast_score_thresh + postprocess_params["nms_thresh"] = args.det_sast_nms_thresh + self.det_sast_polygon = args.det_sast_polygon + if self.det_sast_polygon: + postprocess_params["sample_pts_num"] = 6 + postprocess_params["expand_scale"] = 1.2 + postprocess_params["shrink_ratio_of_width"] = 0.2 + else: + postprocess_params["sample_pts_num"] = 2 + postprocess_params["expand_scale"] = 1.0 + postprocess_params["shrink_ratio_of_width"] = 0.3 + elif self.det_algorithm == "PSE": + postprocess_params['name'] = 'PSEPostProcess' + postprocess_params["thresh"] = args.det_pse_thresh + postprocess_params["box_thresh"] = args.det_pse_box_thresh + postprocess_params["min_area"] = args.det_pse_min_area + postprocess_params["box_type"] = args.det_pse_box_type + postprocess_params["scale"] = args.det_pse_scale + self.det_pse_box_type = args.det_pse_box_type + elif self.det_algorithm == "FCE": + pre_process_list[0] = { + 'DetResizeForTest': { + 'rescale_img': [1080, 736] + } + } + postprocess_params['name'] = 'FCEPostProcess' + postprocess_params["scales"] = args.scales + postprocess_params["alpha"] = args.alpha + postprocess_params["beta"] = args.beta + postprocess_params["fourier_degree"] = args.fourier_degree + postprocess_params["box_type"] = args.det_fce_box_type + else: + logger.info("unknown det_algorithm:{}".format(self.det_algorithm)) + sys.exit(0) + + self.preprocess_op = create_operators(pre_process_list) + self.postprocess_op = build_post_process(postprocess_params) + self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor( + args, 'det', logger) + + if self.use_onnx: + img_h, img_w = self.input_tensor.shape[2:] + if img_h is not None and img_w is not None and img_h > 0 and img_w > 0: + pre_process_list[0] = { + 'DetResizeForTest': { + 'image_shape': [img_h, img_w] + } + } + self.preprocess_op = create_operators(pre_process_list) + + if args.benchmark: + import auto_log + pid = os.getpid() + gpu_id = utility.get_infer_gpuid() + self.autolog = auto_log.AutoLogger( + model_name="det", + model_precision=args.precision, + batch_size=1, + data_shape="dynamic", + save_path=None, + inference_config=self.config, + pids=pid, + process_name=None, + gpu_ids=gpu_id if args.use_gpu else None, + time_keys=[ + 'preprocess_time', 'inference_time', 'postprocess_time' + ], + warmup=2, + logger=logger) + + def order_points_clockwise(self, pts): + rect = np.zeros((4, 2), dtype="float32") + s = pts.sum(axis=1) + rect[0] = pts[np.argmin(s)] + rect[2] = pts[np.argmax(s)] + tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0) + diff = np.diff(np.array(tmp), axis=1) + rect[1] = tmp[np.argmin(diff)] + rect[3] = tmp[np.argmax(diff)] + return rect + + def clip_det_res(self, points, img_height, img_width): + for pno in range(points.shape[0]): + points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1)) + points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1)) + return points + + def filter_tag_det_res(self, dt_boxes, image_shape): + img_height, img_width = image_shape[0:2] + dt_boxes_new = [] + for box in dt_boxes: + box = self.order_points_clockwise(box) + box = self.clip_det_res(box, img_height, img_width) + rect_width = int(np.linalg.norm(box[0] - box[1])) + rect_height = int(np.linalg.norm(box[0] - box[3])) + if rect_width <= 3 or rect_height <= 3: + continue + dt_boxes_new.append(box) + dt_boxes = np.array(dt_boxes_new) + return dt_boxes + + def filter_tag_det_res_only_clip(self, dt_boxes, image_shape): + img_height, img_width = image_shape[0:2] + dt_boxes_new = [] + for box in dt_boxes: + box = self.clip_det_res(box, img_height, img_width) + dt_boxes_new.append(box) + dt_boxes = np.array(dt_boxes_new) + return dt_boxes + + def __call__(self, img): + ori_im = img.copy() + data = {'image': img} + + st = time.time() + + if self.args.benchmark: + self.autolog.times.start() + + data = transform(data, self.preprocess_op) + img, shape_list = data + if img is None: + return None, 0 + img = np.expand_dims(img, axis=0) + shape_list = np.expand_dims(shape_list, axis=0) + img = img.copy() + + if self.args.benchmark: + self.autolog.times.stamp() + if self.use_onnx: + input_dict = {} + input_dict[self.input_tensor.name] = img + outputs = self.predictor.run(self.output_tensors, input_dict) + else: + self.input_tensor.copy_from_cpu(img) + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + if self.args.benchmark: + self.autolog.times.stamp() + + preds = {} + if self.det_algorithm == "EAST": + preds['f_geo'] = outputs[0] + preds['f_score'] = outputs[1] + elif self.det_algorithm == 'SAST': + preds['f_border'] = outputs[0] + preds['f_score'] = outputs[1] + preds['f_tco'] = outputs[2] + preds['f_tvo'] = outputs[3] + elif self.det_algorithm in ['DB', 'PSE']: + preds['maps'] = outputs[0] + elif self.det_algorithm == 'FCE': + for i, output in enumerate(outputs): + preds['level_{}'.format(i)] = output + else: + raise NotImplementedError + + #self.predictor.try_shrink_memory() + post_result = self.postprocess_op(preds, shape_list) + dt_boxes = post_result[0]['points'] + if (self.det_algorithm == "SAST" and self.det_sast_polygon) or ( + self.det_algorithm in ["PSE", "FCE"] and + self.postprocess_op.box_type == 'poly'): + dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape) + else: + dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) + + if self.args.benchmark: + self.autolog.times.end(stamp=True) + et = time.time() + return dt_boxes, et - st + + +if __name__ == "__main__": + from ppocr.metrics.eval_det_iou import DetectionIoUEvaluator + evaluator = DetectionIoUEvaluator() + args = utility.parse_args() + + # image_file_list = get_image_file_list(args.image_dir) + def _check_image_file(path): + img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif'} + return any([path.lower().endswith(e) for e in img_end]) + + def get_image_file_list_from_txt(img_file): + imgs_lists = [] + label_lists = [] + if img_file is None or not os.path.exists(img_file): + raise Exception("not found any img file in {}".format(img_file)) + + img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif'} + root_dir = img_file.split('/')[0] + with open(img_file, 'r') as f: + lines = f.readlines() + for line in lines: + line = line.replace('\n', '').split('\t') + file_path, label = line[0], line[1] + file_path = os.path.join(root_dir, file_path) + if os.path.isfile(file_path) and _check_image_file(file_path): + imgs_lists.append(file_path) + label_lists.append(label) + + if len(imgs_lists) == 0: + raise Exception("not found any img file in {}".format(img_file)) + return imgs_lists, label_lists + + image_file_list, label_list = get_image_file_list_from_txt(args.image_dir) + + text_detector = TextDetector(args) + count = 0 + total_time = 0 + draw_img_save = "./inference_results" + + if args.warmup: + img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8) + for i in range(2): + res = text_detector(img) + + if not os.path.exists(draw_img_save): + os.makedirs(draw_img_save) + save_results = [] + results = [] + for idx in range(len(image_file_list)): + image_file = image_file_list[idx] + label = json.loads(label_list[idx]) + img, flag = check_and_read_gif(image_file) + if not flag: + img = cv2.imread(image_file) + if img is None: + logger.info("error in loading image:{}".format(image_file)) + continue + st = time.time() + dt_boxes, _ = text_detector(img) + elapse = time.time() - st + if count > 0: + total_time += elapse + count += 1 + save_pred = os.path.basename(image_file) + "\t" + str( + json.dumps([x.tolist() for x in dt_boxes])) + "\n" + save_results.append(save_pred) + + # for eval + gt_info_list = [] + det_info_list = [] + for dt_box in dt_boxes: + det_info = { + 'points': np.array( + dt_box, dtype=np.float32), + 'text': '' + } + det_info_list.append(det_info) + for lab in label: + gt_info = { + 'points': np.array( + lab['points'], dtype=np.float32), + 'text': '', + 'ignore': False + } + gt_info_list.append(gt_info) + result = evaluator.evaluate_image(gt_info_list, det_info_list) + results.append(result) + + metrics = evaluator.combine_results(results) + print('predict det eval on ', args.image_dir) + print('metrics: ', metrics) + +# logger.info(save_pred) +# logger.info("The predict time of {}: {}".format(image_file, elapse)) +# src_im = utility.draw_text_det_res(dt_boxes, image_file) +# img_name_pure = os.path.split(image_file)[-1] +# img_path = os.path.join(draw_img_save, +# "det_res_{}".format(img_name_pure)) +# cv2.imwrite(img_path, src_im) +# logger.info("The visualized image saved in {}".format(img_path)) + +# with open(os.path.join(draw_img_save, "det_results.txt"), 'w') as f: +# f.writelines(save_results) +# f.close() +# if args.benchmark: +# text_detector.autolog.report() diff --git a/tools/infer/predict_rec_eval.py b/tools/infer/predict_rec_eval.py new file mode 100755 index 00000000..3150d11d --- /dev/null +++ b/tools/infer/predict_rec_eval.py @@ -0,0 +1,534 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys +from PIL import Image +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..'))) + +os.environ["FLAGS_allocator_strategy"] = 'auto_growth' + +import cv2 +import numpy as np +import math +import time +import traceback +import paddle + +import tools.infer.utility as utility +from ppocr.postprocess import build_post_process +from ppocr.utils.logging import get_logger +from ppocr.utils.utility import get_image_file_list, check_and_read_gif + +logger = get_logger() + + +class TextRecognizer(object): + def __init__(self, args): + self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")] + self.rec_batch_num = args.rec_batch_num + self.rec_algorithm = args.rec_algorithm + postprocess_params = { + 'name': 'CTCLabelDecode', + "character_dict_path": args.rec_char_dict_path, + "use_space_char": args.use_space_char + } + if self.rec_algorithm == "SRN": + postprocess_params = { + 'name': 'SRNLabelDecode', + "character_dict_path": args.rec_char_dict_path, + "use_space_char": args.use_space_char + } + elif self.rec_algorithm == "RARE": + postprocess_params = { + 'name': 'AttnLabelDecode', + "character_dict_path": args.rec_char_dict_path, + "use_space_char": args.use_space_char + } + elif self.rec_algorithm == 'NRTR': + postprocess_params = { + 'name': 'NRTRLabelDecode', + "character_dict_path": args.rec_char_dict_path, + "use_space_char": args.use_space_char + } + elif self.rec_algorithm == "SAR": + postprocess_params = { + 'name': 'SARLabelDecode', + "character_dict_path": args.rec_char_dict_path, + "use_space_char": args.use_space_char + } + elif self.rec_algorithm == 'ViTSTR': + postprocess_params = { + 'name': 'ViTSTRLabelDecode', + "character_dict_path": args.rec_char_dict_path, + "use_space_char": args.use_space_char + } + elif self.rec_algorithm == 'ABINet': + postprocess_params = { + 'name': 'ABINetLabelDecode', + "character_dict_path": args.rec_char_dict_path, + "use_space_char": args.use_space_char + } + self.postprocess_op = build_post_process(postprocess_params) + self.predictor, self.input_tensor, self.output_tensors, self.config = \ + utility.create_predictor(args, 'rec', logger) + self.benchmark = args.benchmark + self.use_onnx = args.use_onnx + if args.benchmark: + import auto_log + pid = os.getpid() + gpu_id = utility.get_infer_gpuid() + self.autolog = auto_log.AutoLogger( + model_name="rec", + model_precision=args.precision, + batch_size=args.rec_batch_num, + data_shape="dynamic", + save_path=None, #args.save_log_path, + inference_config=self.config, + pids=pid, + process_name=None, + gpu_ids=gpu_id if args.use_gpu else None, + time_keys=[ + 'preprocess_time', 'inference_time', 'postprocess_time' + ], + warmup=0, + logger=logger) + + def resize_norm_img(self, img, max_wh_ratio): + imgC, imgH, imgW = self.rec_image_shape + if self.rec_algorithm == 'NRTR' or self.rec_algorithm == 'ViTSTR': + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + # return padding_im + image_pil = Image.fromarray(np.uint8(img)) + if self.rec_algorithm == 'ViTSTR': + img = image_pil.resize([imgW, imgH], Image.BICUBIC) + else: + img = image_pil.resize([imgW, imgH], Image.ANTIALIAS) + img = np.array(img) + norm_img = np.expand_dims(img, -1) + norm_img = norm_img.transpose((2, 0, 1)) + if self.rec_algorithm == 'ViTSTR': + norm_img = norm_img.astype(np.float32) / 255. + else: + norm_img = norm_img.astype(np.float32) / 128. - 1. + return norm_img + + assert imgC == img.shape[2] + imgW = int((imgH * max_wh_ratio)) + if self.use_onnx: + w = self.input_tensor.shape[3:][0] + if w is not None and w > 0: + imgW = w + + h, w = img.shape[:2] + ratio = w / float(h) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + if self.rec_algorithm == 'RARE': + if resized_w > self.rec_image_shape[2]: + resized_w = self.rec_image_shape[2] + imgW = self.rec_image_shape[2] + resized_image = cv2.resize(img, (resized_w, imgH)) + resized_image = resized_image.astype('float32') + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) + padding_im[:, :, 0:resized_w] = resized_image + return padding_im + + def resize_norm_img_srn(self, img, image_shape): + imgC, imgH, imgW = image_shape + + img_black = np.zeros((imgH, imgW)) + im_hei = img.shape[0] + im_wid = img.shape[1] + + if im_wid <= im_hei * 1: + img_new = cv2.resize(img, (imgH * 1, imgH)) + elif im_wid <= im_hei * 2: + img_new = cv2.resize(img, (imgH * 2, imgH)) + elif im_wid <= im_hei * 3: + img_new = cv2.resize(img, (imgH * 3, imgH)) + else: + img_new = cv2.resize(img, (imgW, imgH)) + + img_np = np.asarray(img_new) + img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY) + img_black[:, 0:img_np.shape[1]] = img_np + img_black = img_black[:, :, np.newaxis] + + row, col, c = img_black.shape + c = 1 + + return np.reshape(img_black, (c, row, col)).astype(np.float32) + + def srn_other_inputs(self, image_shape, num_heads, max_text_length): + + imgC, imgH, imgW = image_shape + feature_dim = int((imgH / 8) * (imgW / 8)) + + encoder_word_pos = np.array(range(0, feature_dim)).reshape( + (feature_dim, 1)).astype('int64') + gsrm_word_pos = np.array(range(0, max_text_length)).reshape( + (max_text_length, 1)).astype('int64') + + gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length)) + gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape( + [-1, 1, max_text_length, max_text_length]) + gsrm_slf_attn_bias1 = np.tile( + gsrm_slf_attn_bias1, + [1, num_heads, 1, 1]).astype('float32') * [-1e9] + + gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape( + [-1, 1, max_text_length, max_text_length]) + gsrm_slf_attn_bias2 = np.tile( + gsrm_slf_attn_bias2, + [1, num_heads, 1, 1]).astype('float32') * [-1e9] + + encoder_word_pos = encoder_word_pos[np.newaxis, :] + gsrm_word_pos = gsrm_word_pos[np.newaxis, :] + + return [ + encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, + gsrm_slf_attn_bias2 + ] + + def process_image_srn(self, img, image_shape, num_heads, max_text_length): + norm_img = self.resize_norm_img_srn(img, image_shape) + norm_img = norm_img[np.newaxis, :] + + [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \ + self.srn_other_inputs(image_shape, num_heads, max_text_length) + + gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32) + gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32) + encoder_word_pos = encoder_word_pos.astype(np.int64) + gsrm_word_pos = gsrm_word_pos.astype(np.int64) + + return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, + gsrm_slf_attn_bias2) + + def resize_norm_img_sar(self, img, image_shape, + width_downsample_ratio=0.25): + imgC, imgH, imgW_min, imgW_max = image_shape + h = img.shape[0] + w = img.shape[1] + valid_ratio = 1.0 + # make sure new_width is an integral multiple of width_divisor. + width_divisor = int(1 / width_downsample_ratio) + # resize + ratio = w / float(h) + resize_w = math.ceil(imgH * ratio) + if resize_w % width_divisor != 0: + resize_w = round(resize_w / width_divisor) * width_divisor + if imgW_min is not None: + resize_w = max(imgW_min, resize_w) + if imgW_max is not None: + valid_ratio = min(1.0, 1.0 * resize_w / imgW_max) + resize_w = min(imgW_max, resize_w) + resized_image = cv2.resize(img, (resize_w, imgH)) + resized_image = resized_image.astype('float32') + # norm + if image_shape[0] == 1: + resized_image = resized_image / 255 + resized_image = resized_image[np.newaxis, :] + else: + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + resize_shape = resized_image.shape + padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32) + padding_im[:, :, 0:resize_w] = resized_image + pad_shape = padding_im.shape + + return padding_im, resize_shape, pad_shape, valid_ratio + + def resize_norm_img_svtr(self, img, image_shape): + + imgC, imgH, imgW = image_shape + resized_image = cv2.resize( + img, (imgW, imgH), interpolation=cv2.INTER_LINEAR) + resized_image = resized_image.astype('float32') + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + return resized_image + + def resize_norm_img_abinet(self, img, image_shape): + + imgC, imgH, imgW = image_shape + + resized_image = cv2.resize( + img, (imgW, imgH), interpolation=cv2.INTER_LINEAR) + resized_image = resized_image.astype('float32') + resized_image = resized_image / 255. + + mean = np.array([0.485, 0.456, 0.406]) + std = np.array([0.229, 0.224, 0.225]) + resized_image = ( + resized_image - mean[None, None, ...]) / std[None, None, ...] + resized_image = resized_image.transpose((2, 0, 1)) + resized_image = resized_image.astype('float32') + + return resized_image + + def __call__(self, img_list): + img_num = len(img_list) + # Calculate the aspect ratio of all text bars + width_list = [] + for img in img_list: + width_list.append(img.shape[1] / float(img.shape[0])) + # Sorting can speed up the recognition process + indices = np.argsort(np.array(width_list)) + rec_res = [['', 0.0]] * img_num + batch_num = self.rec_batch_num + st = time.time() + if self.benchmark: + self.autolog.times.start() + for beg_img_no in range(0, img_num, batch_num): + end_img_no = min(img_num, beg_img_no + batch_num) + norm_img_batch = [] + imgC, imgH, imgW = self.rec_image_shape + max_wh_ratio = imgW / imgH + # max_wh_ratio = 0 + for ino in range(beg_img_no, end_img_no): + h, w = img_list[indices[ino]].shape[0:2] + wh_ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, wh_ratio) + for ino in range(beg_img_no, end_img_no): + + if self.rec_algorithm == "SAR": + norm_img, _, _, valid_ratio = self.resize_norm_img_sar( + img_list[indices[ino]], self.rec_image_shape) + norm_img = norm_img[np.newaxis, :] + valid_ratio = np.expand_dims(valid_ratio, axis=0) + valid_ratios = [] + valid_ratios.append(valid_ratio) + norm_img_batch.append(norm_img) + elif self.rec_algorithm == "SRN": + norm_img = self.process_image_srn( + img_list[indices[ino]], self.rec_image_shape, 8, 25) + encoder_word_pos_list = [] + gsrm_word_pos_list = [] + gsrm_slf_attn_bias1_list = [] + gsrm_slf_attn_bias2_list = [] + encoder_word_pos_list.append(norm_img[1]) + gsrm_word_pos_list.append(norm_img[2]) + gsrm_slf_attn_bias1_list.append(norm_img[3]) + gsrm_slf_attn_bias2_list.append(norm_img[4]) + norm_img_batch.append(norm_img[0]) + elif self.rec_algorithm == "SVTR": + norm_img = self.resize_norm_img_svtr(img_list[indices[ino]], + self.rec_image_shape) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) + elif self.rec_algorithm == "ABINet": + norm_img = self.resize_norm_img_abinet( + img_list[indices[ino]], self.rec_image_shape) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) + else: + norm_img = self.resize_norm_img(img_list[indices[ino]], + max_wh_ratio) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) + norm_img_batch = np.concatenate(norm_img_batch) + norm_img_batch = norm_img_batch.copy() + if self.benchmark: + self.autolog.times.stamp() + + if self.rec_algorithm == "SRN": + encoder_word_pos_list = np.concatenate(encoder_word_pos_list) + gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list) + gsrm_slf_attn_bias1_list = np.concatenate( + gsrm_slf_attn_bias1_list) + gsrm_slf_attn_bias2_list = np.concatenate( + gsrm_slf_attn_bias2_list) + + inputs = [ + norm_img_batch, + encoder_word_pos_list, + gsrm_word_pos_list, + gsrm_slf_attn_bias1_list, + gsrm_slf_attn_bias2_list, + ] + if self.use_onnx: + input_dict = {} + input_dict[self.input_tensor.name] = norm_img_batch + outputs = self.predictor.run(self.output_tensors, + input_dict) + preds = {"predict": outputs[2]} + else: + input_names = self.predictor.get_input_names() + for i in range(len(input_names)): + input_tensor = self.predictor.get_input_handle( + input_names[i]) + input_tensor.copy_from_cpu(inputs[i]) + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + if self.benchmark: + self.autolog.times.stamp() + preds = {"predict": outputs[2]} + elif self.rec_algorithm == "SAR": + valid_ratios = np.concatenate(valid_ratios) + inputs = [ + norm_img_batch, + valid_ratios, + ] + if self.use_onnx: + input_dict = {} + input_dict[self.input_tensor.name] = norm_img_batch + outputs = self.predictor.run(self.output_tensors, + input_dict) + preds = outputs[0] + else: + input_names = self.predictor.get_input_names() + for i in range(len(input_names)): + input_tensor = self.predictor.get_input_handle( + input_names[i]) + input_tensor.copy_from_cpu(inputs[i]) + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + if self.benchmark: + self.autolog.times.stamp() + preds = outputs[0] + else: + if self.use_onnx: + input_dict = {} + input_dict[self.input_tensor.name] = norm_img_batch + outputs = self.predictor.run(self.output_tensors, + input_dict) + preds = outputs[0] + else: + self.input_tensor.copy_from_cpu(norm_img_batch) + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + if self.benchmark: + self.autolog.times.stamp() + if len(outputs) != 1: + preds = outputs + else: + preds = outputs[0] + rec_result = self.postprocess_op(preds) + for rno in range(len(rec_result)): + rec_res[indices[beg_img_no + rno]] = rec_result[rno] + if self.benchmark: + self.autolog.times.end(stamp=True) + return rec_res, time.time() - st + + +def main(args): + # image_file_list = get_image_file_list(args.image_dir) + + def _check_image_file(path): + img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif'} + return any([path.lower().endswith(e) for e in img_end]) + + def get_image_file_list_from_txt(img_file): + imgs_lists = [] + label_lists = [] + if img_file is None or not os.path.exists(img_file): + raise Exception("not found any img file in {}".format(img_file)) + + img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif'} + root_dir = img_file.split('/')[0] + with open(img_file, 'r') as f: + lines = f.readlines() + for line in lines: + line = line.replace('\n', '').split('\t') + file_path, label = line[0], line[1] + file_path = os.path.join(root_dir, file_path) + if os.path.isfile(file_path) and _check_image_file(file_path): + imgs_lists.append(file_path) + label_lists.append(label) + + if len(imgs_lists) == 0: + raise Exception("not found any img file in {}".format(img_file)) + return imgs_lists, label_lists + + image_file_list, label_list = get_image_file_list_from_txt(args.image_dir) + + text_recognizer = TextRecognizer(args) + valid_image_file_list = [] + img_list = [] + + logger.info( + "In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', " + "if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320" + ) + # warmup 2 times + if args.warmup: + img = np.random.uniform(0, 255, [48, 320, 3]).astype(np.uint8) + for i in range(2): + res = text_recognizer([img] * int(args.rec_batch_num)) + + for image_file in image_file_list: + img, flag = check_and_read_gif(image_file) + if not flag: + img = cv2.imread(image_file) + if img is None: + logger.info("error in loading image:{}".format(image_file)) + continue + valid_image_file_list.append(image_file) + img_list.append(img) + + try: + rec_res, _ = text_recognizer(img_list) + except Exception as E: + logger.info(traceback.format_exc()) + logger.info(E) + exit() + correct_num = 0 + for ino in range(len(img_list)): + pred = rec_res[ino][0] + gt = label_list[ino] + if pred == gt: + correct_num += 1 + acc = correct_num * 1.0 / len(img_list) + print('predict rec eval on ', args.image_dir) + print('acc: ', acc) + + # for debug bad case + bad_case_lines = [] + for ino in range(len(img_list)): + pred = rec_res[ino][0] + gt = label_list[ino] + if pred != gt and len(gt) <= 25: + bad_case = valid_image_file_list[ + ino] + '\t' + 'pred:' + pred + '\t' + 'gt:' + gt + '\n' + bad_case_lines.append(bad_case) + + with open('bad_case_hwdb2.txt', 'a+') as f: + f.writelines(bad_case_lines) + # end debug case + + if args.benchmark: + text_recognizer.autolog.report() + + +if __name__ == "__main__": + main(utility.parse_args()) diff --git a/tools/program.py b/tools/program.py index 0fa0e609..dd8037e4 100755 --- a/tools/program.py +++ b/tools/program.py @@ -154,13 +154,14 @@ def check_xpu(use_xpu): except Exception as e: pass + def to_float32(preds): if isinstance(preds, dict): for k in preds: if isinstance(preds[k], dict) or isinstance(preds[k], list): preds[k] = to_float32(preds[k]) else: - preds[k] = preds[k].astype(paddle.float32) + preds[k] = paddle.to_tensor(preds[k], dtype='float32') elif isinstance(preds, list): for k in range(len(preds)): if isinstance(preds[k], dict): @@ -168,11 +169,12 @@ def to_float32(preds): elif isinstance(preds[k], list): preds[k] = to_float32(preds[k]) else: - preds[k] = preds[k].astype(paddle.float32) + preds[k] = paddle.to_tensor(preds[k], dtype='float32') else: - preds = preds.astype(paddle.float32) + preds[k] = paddle.to_tensor(preds[k], dtype='float32') return preds + def train(config, train_dataloader, valid_dataloader, -- GitLab