diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index 52c225d2b3913cf8c0dc88abcc07f7ccfd3cc914..08441d054d6bc5f2cabccb882f67047b7194899e 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -67,6 +67,7 @@ class TextDetector(object): postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio postprocess_params["use_dilation"] = args.use_dilation postprocess_params["score_mode"] = args.det_db_score_mode + postprocess_params["use_polygon"] = args.det_use_polygon elif self.det_algorithm == "DB++": postprocess_params['name'] = 'DBPostProcess' postprocess_params["thresh"] = args.det_db_thresh @@ -75,6 +76,7 @@ class TextDetector(object): postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio postprocess_params["use_dilation"] = args.use_dilation postprocess_params["score_mode"] = args.det_db_score_mode + postprocess_params["use_polygon"] = args.det_use_polygon pre_process_list[1] = { 'NormalizeImage': { 'std': [1.0, 1.0, 1.0], @@ -204,6 +206,8 @@ class TextDetector(object): img_height, img_width = image_shape[0:2] dt_boxes_new = [] for box in dt_boxes: + if type(box) is list: + box = np.array(box) box = self.clip_det_res(box, img_height, img_width) dt_boxes_new.append(box) dt_boxes = np.array(dt_boxes_new) @@ -262,13 +266,15 @@ class TextDetector(object): else: raise NotImplementedError - #self.predictor.try_shrink_memory() post_result = self.postprocess_op(preds, shape_list) dt_boxes = post_result[0]['points'] + print("det_boxes", dt_boxes) if (self.det_algorithm == "SAST" and self.det_sast_polygon) or ( self.det_algorithm in ["PSE", "FCE", "CT"] and self.postprocess_op.box_type == 'poly'): dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape) + elif 'DB' in self.det_algorithm and self.postprocess_op.use_polygon is True: + dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape) else: dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index e555dbec1b314510aaaf6b31f1b35bf60fefa98e..30ae712d1772503e60b65afbaffe3151bb3be9f3 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -58,6 +58,7 @@ def init_args(): parser.add_argument("--max_batch_size", type=int, default=10) parser.add_argument("--use_dilation", type=str2bool, default=False) parser.add_argument("--det_db_score_mode", type=str, default="fast") + parser.add_argument("--det_use_polygon", type=str2bool, default=False) # EAST parmas parser.add_argument("--det_east_score_thresh", type=float, default=0.8) parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)