diff --git a/configs/det/det_sast_icdar15_reader.yml b/configs/det/det_sast_icdar15_reader.yml index 1fdea8756c3d80b291aba4afbb8c7d340654b6cb..cfa5b8b64052d421e63e373cec82f4d735238a8c 100644 --- a/configs/det/det_sast_icdar15_reader.yml +++ b/configs/det/det_sast_icdar15_reader.yml @@ -20,7 +20,5 @@ EvalReader: TestReader: reader_function: ppocr.data.det.dataset_traversal,EvalTestReader process_function: ppocr.data.det.sast_process,SASTProcessTest - infer_img: - img_set_dir: ./train_data/icdar2015/text_localization/ - label_file_path: ./train_data/icdar2015/text_localization/test_icdar2015_label.txt - do_eval: True + infer_img: ./train_data/icdar2015/text_localization/ch4_test_images/img_11.jpg + max_side_len: 1536 diff --git a/configs/det/det_sast_totaltext_reader.yml b/configs/det/det_sast_totaltext_reader.yml index 74c320fcc5e1c8340d0439e350e2a4e88824b16d..c1a3d25133a24556ad754da9294bd41533b6235c 100644 --- a/configs/det/det_sast_totaltext_reader.yml +++ b/configs/det/det_sast_totaltext_reader.yml @@ -20,5 +20,5 @@ EvalReader: TestReader: reader_function: ppocr.data.det.dataset_traversal,EvalTestReader process_function: ppocr.data.det.sast_process,SASTProcessTest - infer_img: + infer_img: ./train_data/afs/total_text/Images/Test/img623.jpg max_side_len: 768 diff --git a/ppocr/modeling/heads/det_sast_head.py b/ppocr/modeling/heads/det_sast_head.py index b5e19b844abda003d65a1f95026685cfe0cfffd6..0097913dd7e08c76c45064940416e7c9ffc32f26 100644 --- a/ppocr/modeling/heads/det_sast_head.py +++ b/ppocr/modeling/heads/det_sast_head.py @@ -49,7 +49,7 @@ class SASTHead(object): for i in range(4): if i == 0: g[i] = deconv_bn_layer(input=h[i], num_filters=num_outputs[i + 1], act=None, name='fpn_up_g0') - print("g[{}] shape: {}".format(i, g[i].shape)) + #print("g[{}] shape: {}".format(i, g[i].shape)) else: g[i] = fluid.layers.elementwise_add(x=g[i - 1], y=h[i]) g[i] = fluid.layers.relu(g[i]) @@ -58,7 +58,7 @@ class SASTHead(object): g[i] = conv_bn_layer(input=g[i], num_filters=num_outputs[i], filter_size=3, stride=1, act='relu', name='fpn_up_g%d_1'%i) g[i] = deconv_bn_layer(input=g[i], num_filters=num_outputs[i + 1], act=None, name='fpn_up_g%d_2'%i) - print("g[{}] shape: {}".format(i, g[i].shape)) + #print("g[{}] shape: {}".format(i, g[i].shape)) g[4] = fluid.layers.elementwise_add(x=g[3], y=h[4]) g[4] = fluid.layers.relu(g[4]) diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index 75644aeb990ab95edb51f2809bb8cc8fbdf3e2be..74ffb0b56d8ae6810a15c4d914a3975448a5e7ee 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -22,10 +22,12 @@ from ppocr.utils.utility import initial_logger logger = initial_logger() from ppocr.utils.utility import get_image_file_list, check_and_read_gif import cv2 +from ppocr.data.det.sast_process import SASTProcessTest from ppocr.data.det.east_process import EASTProcessTest from ppocr.data.det.db_process import DBProcessTest from ppocr.postprocess.db_postprocess import DBPostProcess from ppocr.postprocess.east_postprocess import EASTPostPocess +from ppocr.postprocess.sast_postprocess import SASTPostProcess import copy import numpy as np import math @@ -52,6 +54,14 @@ class TextDetector(object): postprocess_params["cover_thresh"] = args.det_east_cover_thresh postprocess_params["nms_thresh"] = args.det_east_nms_thresh self.postprocess_op = EASTPostPocess(postprocess_params) + elif self.det_algorithm == "SAST": + self.preprocess_op = SASTProcessTest(preprocess_params) + postprocess_params["score_thresh"] = args.det_sast_score_thresh + postprocess_params["nms_thresh"] = args.det_sast_nms_thresh + postprocess_params["sample_pts_num"] = args.det_sast_sample_pts_num + postprocess_params["expand_scale"] = args.det_sast_expand_scale + postprocess_params["shrink_ratio_of_width"] = args.det_sast_shrink_ratio_of_width + self.postprocess_op = SASTPostProcess(postprocess_params) else: logger.info("unknown det_algorithm:{}".format(self.det_algorithm)) sys.exit(0) @@ -120,8 +130,14 @@ class TextDetector(object): if self.det_algorithm == "EAST": outs_dict['f_geo'] = outputs[0] outs_dict['f_score'] = outputs[1] + elif self.det_algorithm == 'SAST': + outs_dict['f_border'] = outputs[0] + outs_dict['f_score'] = outputs[1] + outs_dict['f_tco'] = outputs[2] + outs_dict['f_tvo'] = outputs[3] else: outs_dict['maps'] = outputs[0] + dt_boxes_list = self.postprocess_op(outs_dict, [ratio_list]) dt_boxes = dt_boxes_list[0] dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index b0a0ec1f3037b35d22212cb6fb3555746edc2e99..f31b0f652cfd74cd60118c064af21ab5cbe2697a 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -53,6 +53,13 @@ def parse_args(): parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) parser.add_argument("--det_east_nms_thresh", type=float, default=0.2) + #SAST parmas + parser.add_argument("--det_sast_score_thresh", type=float, default=0.5) + parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2) + parser.add_argument("--det_sast_sample_pts_num", type=float, default=2) + parser.add_argument("--det_sast_expand_scale", type=float, default=1.0) + parser.add_argument("--det_sast_shrink_ratio_of_width", type=float, default=0.3) + #params for text recognizer parser.add_argument("--rec_algorithm", type=str, default='CRNN') parser.add_argument("--rec_model_dir", type=str) diff --git a/tools/infer_det.py b/tools/infer_det.py index a8b49b6b075ba509e17c37c8d1f05dee9822edec..93c24dd24a99a5f696c0515b0daa74a536d18dd4 100755 --- a/tools/infer_det.py +++ b/tools/infer_det.py @@ -134,8 +134,10 @@ def main(): dic = {'f_score': outs[0], 'f_geo': outs[1]} elif config['Global']['algorithm'] == 'DB': dic = {'maps': outs[0]} + elif config['Global']['algorithm'] == 'SAST': + dic = {'f_score': outs[0], 'f_border': outs[1], 'f_tvo': outs[2], 'f_tco': outs[3]} else: - raise Exception("only support algorithm: ['EAST', 'DB']") + raise Exception("only support algorithm: ['EAST', 'DB', 'SAST']") dt_boxes_list = postprocess(dic, ratio_list) for ino in range(img_num): dt_boxes = dt_boxes_list[ino] @@ -149,7 +151,7 @@ def main(): fout.write(otstr.encode()) src_img = cv2.imread(img_name) draw_det_res(dt_boxes, config, src_img, img_name) - + logger.info("success!")