未验证 提交 aecbbe1c 编写于 作者: M MissPenguin 提交者: GitHub

Merge pull request #552 from MissPenguin/develop

modify infer tools for sast
...@@ -20,7 +20,5 @@ EvalReader: ...@@ -20,7 +20,5 @@ EvalReader:
TestReader: TestReader:
reader_function: ppocr.data.det.dataset_traversal,EvalTestReader reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
process_function: ppocr.data.det.sast_process,SASTProcessTest process_function: ppocr.data.det.sast_process,SASTProcessTest
infer_img: infer_img: ./train_data/icdar2015/text_localization/ch4_test_images/img_11.jpg
img_set_dir: ./train_data/icdar2015/text_localization/ max_side_len: 1536
label_file_path: ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
do_eval: True
...@@ -20,5 +20,5 @@ EvalReader: ...@@ -20,5 +20,5 @@ EvalReader:
TestReader: TestReader:
reader_function: ppocr.data.det.dataset_traversal,EvalTestReader reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
process_function: ppocr.data.det.sast_process,SASTProcessTest process_function: ppocr.data.det.sast_process,SASTProcessTest
infer_img: infer_img: ./train_data/afs/total_text/Images/Test/img623.jpg
max_side_len: 768 max_side_len: 768
...@@ -49,7 +49,7 @@ class SASTHead(object): ...@@ -49,7 +49,7 @@ class SASTHead(object):
for i in range(4): for i in range(4):
if i == 0: if i == 0:
g[i] = deconv_bn_layer(input=h[i], num_filters=num_outputs[i + 1], act=None, name='fpn_up_g0') g[i] = deconv_bn_layer(input=h[i], num_filters=num_outputs[i + 1], act=None, name='fpn_up_g0')
print("g[{}] shape: {}".format(i, g[i].shape)) #print("g[{}] shape: {}".format(i, g[i].shape))
else: else:
g[i] = fluid.layers.elementwise_add(x=g[i - 1], y=h[i]) g[i] = fluid.layers.elementwise_add(x=g[i - 1], y=h[i])
g[i] = fluid.layers.relu(g[i]) g[i] = fluid.layers.relu(g[i])
...@@ -58,7 +58,7 @@ class SASTHead(object): ...@@ -58,7 +58,7 @@ class SASTHead(object):
g[i] = conv_bn_layer(input=g[i], num_filters=num_outputs[i], g[i] = conv_bn_layer(input=g[i], num_filters=num_outputs[i],
filter_size=3, stride=1, act='relu', name='fpn_up_g%d_1'%i) filter_size=3, stride=1, act='relu', name='fpn_up_g%d_1'%i)
g[i] = deconv_bn_layer(input=g[i], num_filters=num_outputs[i + 1], act=None, name='fpn_up_g%d_2'%i) g[i] = deconv_bn_layer(input=g[i], num_filters=num_outputs[i + 1], act=None, name='fpn_up_g%d_2'%i)
print("g[{}] shape: {}".format(i, g[i].shape)) #print("g[{}] shape: {}".format(i, g[i].shape))
g[4] = fluid.layers.elementwise_add(x=g[3], y=h[4]) g[4] = fluid.layers.elementwise_add(x=g[3], y=h[4])
g[4] = fluid.layers.relu(g[4]) g[4] = fluid.layers.relu(g[4])
......
...@@ -22,10 +22,12 @@ from ppocr.utils.utility import initial_logger ...@@ -22,10 +22,12 @@ from ppocr.utils.utility import initial_logger
logger = initial_logger() logger = initial_logger()
from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.utility import get_image_file_list, check_and_read_gif
import cv2 import cv2
from ppocr.data.det.sast_process import SASTProcessTest
from ppocr.data.det.east_process import EASTProcessTest from ppocr.data.det.east_process import EASTProcessTest
from ppocr.data.det.db_process import DBProcessTest from ppocr.data.det.db_process import DBProcessTest
from ppocr.postprocess.db_postprocess import DBPostProcess from ppocr.postprocess.db_postprocess import DBPostProcess
from ppocr.postprocess.east_postprocess import EASTPostPocess from ppocr.postprocess.east_postprocess import EASTPostPocess
from ppocr.postprocess.sast_postprocess import SASTPostProcess
import copy import copy
import numpy as np import numpy as np
import math import math
...@@ -52,6 +54,14 @@ class TextDetector(object): ...@@ -52,6 +54,14 @@ class TextDetector(object):
postprocess_params["cover_thresh"] = args.det_east_cover_thresh postprocess_params["cover_thresh"] = args.det_east_cover_thresh
postprocess_params["nms_thresh"] = args.det_east_nms_thresh postprocess_params["nms_thresh"] = args.det_east_nms_thresh
self.postprocess_op = EASTPostPocess(postprocess_params) self.postprocess_op = EASTPostPocess(postprocess_params)
elif self.det_algorithm == "SAST":
self.preprocess_op = SASTProcessTest(preprocess_params)
postprocess_params["score_thresh"] = args.det_sast_score_thresh
postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
postprocess_params["sample_pts_num"] = args.det_sast_sample_pts_num
postprocess_params["expand_scale"] = args.det_sast_expand_scale
postprocess_params["shrink_ratio_of_width"] = args.det_sast_shrink_ratio_of_width
self.postprocess_op = SASTPostProcess(postprocess_params)
else: else:
logger.info("unknown det_algorithm:{}".format(self.det_algorithm)) logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
sys.exit(0) sys.exit(0)
...@@ -120,8 +130,14 @@ class TextDetector(object): ...@@ -120,8 +130,14 @@ class TextDetector(object):
if self.det_algorithm == "EAST": if self.det_algorithm == "EAST":
outs_dict['f_geo'] = outputs[0] outs_dict['f_geo'] = outputs[0]
outs_dict['f_score'] = outputs[1] outs_dict['f_score'] = outputs[1]
elif self.det_algorithm == 'SAST':
outs_dict['f_border'] = outputs[0]
outs_dict['f_score'] = outputs[1]
outs_dict['f_tco'] = outputs[2]
outs_dict['f_tvo'] = outputs[3]
else: else:
outs_dict['maps'] = outputs[0] outs_dict['maps'] = outputs[0]
dt_boxes_list = self.postprocess_op(outs_dict, [ratio_list]) dt_boxes_list = self.postprocess_op(outs_dict, [ratio_list])
dt_boxes = dt_boxes_list[0] dt_boxes = dt_boxes_list[0]
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
......
...@@ -53,6 +53,13 @@ def parse_args(): ...@@ -53,6 +53,13 @@ def parse_args():
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
parser.add_argument("--det_east_nms_thresh", type=float, default=0.2) parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
#SAST parmas
parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
parser.add_argument("--det_sast_sample_pts_num", type=float, default=2)
parser.add_argument("--det_sast_expand_scale", type=float, default=1.0)
parser.add_argument("--det_sast_shrink_ratio_of_width", type=float, default=0.3)
#params for text recognizer #params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='CRNN') parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_model_dir", type=str) parser.add_argument("--rec_model_dir", type=str)
......
...@@ -134,8 +134,10 @@ def main(): ...@@ -134,8 +134,10 @@ def main():
dic = {'f_score': outs[0], 'f_geo': outs[1]} dic = {'f_score': outs[0], 'f_geo': outs[1]}
elif config['Global']['algorithm'] == 'DB': elif config['Global']['algorithm'] == 'DB':
dic = {'maps': outs[0]} dic = {'maps': outs[0]}
elif config['Global']['algorithm'] == 'SAST':
dic = {'f_score': outs[0], 'f_border': outs[1], 'f_tvo': outs[2], 'f_tco': outs[3]}
else: else:
raise Exception("only support algorithm: ['EAST', 'DB']") raise Exception("only support algorithm: ['EAST', 'DB', 'SAST']")
dt_boxes_list = postprocess(dic, ratio_list) dt_boxes_list = postprocess(dic, ratio_list)
for ino in range(img_num): for ino in range(img_num):
dt_boxes = dt_boxes_list[ino] dt_boxes = dt_boxes_list[ino]
...@@ -149,7 +151,7 @@ def main(): ...@@ -149,7 +151,7 @@ def main():
fout.write(otstr.encode()) fout.write(otstr.encode())
src_img = cv2.imread(img_name) src_img = cv2.imread(img_name)
draw_det_res(dt_boxes, config, src_img, img_name) draw_det_res(dt_boxes, config, src_img, img_name)
logger.info("success!") logger.info("success!")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册