From 0c3b5d8e763e1c15a21f5c1d7cffea8fc44857db Mon Sep 17 00:00:00 2001 From: licx Date: Tue, 18 Aug 2020 20:32:00 +0800 Subject: [PATCH] fix bug in predict_det for sast & update docs --- doc/doc_ch/inference.md | 6 +++++- tools/infer/predict_det.py | 19 ++++++++++++++++--- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/doc/doc_ch/inference.md b/doc/doc_ch/inference.md index 52aa4c9f..28671d17 100644 --- a/doc/doc_ch/inference.md +++ b/doc/doc_ch/inference.md @@ -296,7 +296,11 @@ python3 tools/infer/predict_system.py --image_dir="./doc/imgs/2.jpg" --det_model ### 2. 其他模型推理 -如果想尝试使用其他检测算法或者识别算法,请参考上述文本检测模型推理和文本识别模型推理,更新相应配置和模型,下面给出基于EAST文本检测和STAR-Net文本识别执行命令: +如果想尝试使用其他检测算法或者识别算法,请参考上述文本检测模型推理和文本识别模型推理,更新相应配置和模型。 + +**注意:由于检测框矫正逻辑的局限性,SAST弯曲文本检测模型(即,使用参数`--det_sast_polygon=True`时)暂时无法用来模型串联。** + +下面给出基于EAST文本检测和STAR-Net文本识别执行命令: ``` python3 tools/infer/predict_system.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_east/" --det_algorithm="EAST" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en" diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index b556c7f0..af1d60c3 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -58,7 +58,8 @@ class TextDetector(object): self.preprocess_op = SASTProcessTest(preprocess_params) postprocess_params["score_thresh"] = args.det_sast_score_thresh postprocess_params["nms_thresh"] = args.det_sast_nms_thresh - if args.det_sast_polygon: + self.det_sast_polygon = args.det_sast_polygon + if self.det_sast_polygon: postprocess_params["sample_pts_num"] = 6 postprocess_params["expand_scale"] = 1.2 postprocess_params["shrink_ratio_of_width"] = 0.2 @@ -99,7 +100,7 @@ class TextDetector(object): return rect def clip_det_res(self, points, img_height, img_width): - for pno in range(4): + for pno in range(points.shape[0]): points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1)) points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1)) return points @@ -118,6 +119,15 @@ class TextDetector(object): dt_boxes = np.array(dt_boxes_new) return dt_boxes + def filter_tag_det_res_only_clip(self, dt_boxes, image_shape): + img_height, img_width = image_shape[0:2] + dt_boxes_new = [] + for box in dt_boxes: + box = self.clip_det_res(box, img_height, img_width) + dt_boxes_new.append(box) + dt_boxes = np.array(dt_boxes_new) + return dt_boxes + def __call__(self, img): ori_im = img.copy() im, ratio_list = self.preprocess_op(img) @@ -145,7 +155,10 @@ class TextDetector(object): dt_boxes_list = self.postprocess_op(outs_dict, [ratio_list]) dt_boxes = dt_boxes_list[0] -# dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) + if self.det_algorithm == "SAST" and self.det_sast_polygon: + dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape) + else: + dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) elapse = time.time() - starttime return dt_boxes, elapse -- GitLab