From f96b873aa4eb5bc0c167d1ee485725b59ec5e785 Mon Sep 17 00:00:00 2001 From: licx Date: Mon, 17 Aug 2020 20:29:28 +0800 Subject: [PATCH] modify infer tools for sast --- configs/det/det_sast_icdar15_reader.yml | 6 ++---- configs/det/det_sast_totaltext_reader.yml | 2 +- ppocr/modeling/heads/det_sast_head.py | 4 ++-- tools/infer/predict_det.py | 16 +++++++++++++++ tools/infer/utility.py | 7 +++++++ tools/infer_det.py | 25 +++++++++++++++++++++-- 6 files changed, 51 insertions(+), 9 deletions(-) diff --git a/configs/det/det_sast_icdar15_reader.yml b/configs/det/det_sast_icdar15_reader.yml index 1fdea875..cfa5b8b6 100644 --- a/configs/det/det_sast_icdar15_reader.yml +++ b/configs/det/det_sast_icdar15_reader.yml @@ -20,7 +20,5 @@ EvalReader: TestReader: reader_function: ppocr.data.det.dataset_traversal,EvalTestReader process_function: ppocr.data.det.sast_process,SASTProcessTest - infer_img: - img_set_dir: ./train_data/icdar2015/text_localization/ - label_file_path: ./train_data/icdar2015/text_localization/test_icdar2015_label.txt - do_eval: True + infer_img: ./train_data/icdar2015/text_localization/ch4_test_images/img_11.jpg + max_side_len: 1536 diff --git a/configs/det/det_sast_totaltext_reader.yml b/configs/det/det_sast_totaltext_reader.yml index 74c320fc..c1a3d251 100644 --- a/configs/det/det_sast_totaltext_reader.yml +++ b/configs/det/det_sast_totaltext_reader.yml @@ -20,5 +20,5 @@ EvalReader: TestReader: reader_function: ppocr.data.det.dataset_traversal,EvalTestReader process_function: ppocr.data.det.sast_process,SASTProcessTest - infer_img: + infer_img: ./train_data/afs/total_text/Images/Test/img623.jpg max_side_len: 768 diff --git a/ppocr/modeling/heads/det_sast_head.py b/ppocr/modeling/heads/det_sast_head.py index b5e19b84..0097913d 100644 --- a/ppocr/modeling/heads/det_sast_head.py +++ b/ppocr/modeling/heads/det_sast_head.py @@ -49,7 +49,7 @@ class SASTHead(object): for i in range(4): if i == 0: g[i] = deconv_bn_layer(input=h[i], num_filters=num_outputs[i + 1], act=None, name='fpn_up_g0') - print("g[{}] shape: {}".format(i, g[i].shape)) + #print("g[{}] shape: {}".format(i, g[i].shape)) else: g[i] = fluid.layers.elementwise_add(x=g[i - 1], y=h[i]) g[i] = fluid.layers.relu(g[i]) @@ -58,7 +58,7 @@ class SASTHead(object): g[i] = conv_bn_layer(input=g[i], num_filters=num_outputs[i], filter_size=3, stride=1, act='relu', name='fpn_up_g%d_1'%i) g[i] = deconv_bn_layer(input=g[i], num_filters=num_outputs[i + 1], act=None, name='fpn_up_g%d_2'%i) - print("g[{}] shape: {}".format(i, g[i].shape)) + #print("g[{}] shape: {}".format(i, g[i].shape)) g[4] = fluid.layers.elementwise_add(x=g[3], y=h[4]) g[4] = fluid.layers.relu(g[4]) diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index 75644aeb..74ffb0b5 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -22,10 +22,12 @@ from ppocr.utils.utility import initial_logger logger = initial_logger() from ppocr.utils.utility import get_image_file_list, check_and_read_gif import cv2 +from ppocr.data.det.sast_process import SASTProcessTest from ppocr.data.det.east_process import EASTProcessTest from ppocr.data.det.db_process import DBProcessTest from ppocr.postprocess.db_postprocess import DBPostProcess from ppocr.postprocess.east_postprocess import EASTPostPocess +from ppocr.postprocess.sast_postprocess import SASTPostProcess import copy import numpy as np import math @@ -52,6 +54,14 @@ class TextDetector(object): postprocess_params["cover_thresh"] = args.det_east_cover_thresh postprocess_params["nms_thresh"] = args.det_east_nms_thresh self.postprocess_op = EASTPostPocess(postprocess_params) + elif self.det_algorithm == "SAST": + self.preprocess_op = SASTProcessTest(preprocess_params) + postprocess_params["score_thresh"] = args.det_sast_score_thresh + postprocess_params["nms_thresh"] = args.det_sast_nms_thresh + postprocess_params["sample_pts_num"] = args.det_sast_sample_pts_num + postprocess_params["expand_scale"] = args.det_sast_expand_scale + postprocess_params["shrink_ratio_of_width"] = args.det_sast_shrink_ratio_of_width + self.postprocess_op = SASTPostProcess(postprocess_params) else: logger.info("unknown det_algorithm:{}".format(self.det_algorithm)) sys.exit(0) @@ -120,8 +130,14 @@ class TextDetector(object): if self.det_algorithm == "EAST": outs_dict['f_geo'] = outputs[0] outs_dict['f_score'] = outputs[1] + elif self.det_algorithm == 'SAST': + outs_dict['f_border'] = outputs[0] + outs_dict['f_score'] = outputs[1] + outs_dict['f_tco'] = outputs[2] + outs_dict['f_tvo'] = outputs[3] else: outs_dict['maps'] = outputs[0] + dt_boxes_list = self.postprocess_op(outs_dict, [ratio_list]) dt_boxes = dt_boxes_list[0] dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index b0a0ec1f..f31b0f65 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -53,6 +53,13 @@ def parse_args(): parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) parser.add_argument("--det_east_nms_thresh", type=float, default=0.2) + #SAST parmas + parser.add_argument("--det_sast_score_thresh", type=float, default=0.5) + parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2) + parser.add_argument("--det_sast_sample_pts_num", type=float, default=2) + parser.add_argument("--det_sast_expand_scale", type=float, default=1.0) + parser.add_argument("--det_sast_shrink_ratio_of_width", type=float, default=0.3) + #params for text recognizer parser.add_argument("--rec_algorithm", type=str, default='CRNN') parser.add_argument("--rec_model_dir", type=str) diff --git a/tools/infer_det.py b/tools/infer_det.py index a8b49b6b..ab0af6f4 100755 --- a/tools/infer_det.py +++ b/tools/infer_det.py @@ -66,6 +66,25 @@ def draw_det_res(dt_boxes, config, img, img_name): cv2.imwrite(save_path, src_im) logger.info("The detected Image saved in {}".format(save_path)) +def gen_im_detection(src_im, detections): + """ + Generate image with detection results. + """ + im_detection = src_im.copy() + + h, w, _ = im_detection.shape + thickness = int(max((h + w) / 2000, 1)) + + for poly in detections: + # Draw the first point + cv2.putText(im_detection, '0', org=(int(poly[0, 0]), int(poly[0, 1])), + fontFace=cv2.FONT_HERSHEY_COMPLEX, fontScale=thickness, color=(255, 0, 0), + thickness=thickness) + + cv2.polylines(im_detection, np.array(poly).reshape((1, -1, 2)).astype(np.int32), isClosed=True, + color=(0, 0, 255), thickness=thickness) + + return im_detection def main(): config = program.load_config(FLAGS.config) @@ -134,8 +153,10 @@ def main(): dic = {'f_score': outs[0], 'f_geo': outs[1]} elif config['Global']['algorithm'] == 'DB': dic = {'maps': outs[0]} + elif config['Global']['algorithm'] == 'SAST': + dic = {'f_score': outs[0], 'f_border': outs[1], 'f_tvo': outs[2], 'f_tco': outs[3]} else: - raise Exception("only support algorithm: ['EAST', 'DB']") + raise Exception("only support algorithm: ['EAST', 'DB', 'SAST']") dt_boxes_list = postprocess(dic, ratio_list) for ino in range(img_num): dt_boxes = dt_boxes_list[ino] @@ -149,7 +170,7 @@ def main(): fout.write(otstr.encode()) src_img = cv2.imread(img_name) draw_det_res(dt_boxes, config, src_img, img_name) - + logger.info("success!") -- GitLab