diff --git a/tools/infer/predict_cls.py b/tools/infer/predict_cls.py index f5e358e95e5b1c9a0134c473877f1e53047f09db..3c14011a24cf5afcecc5edd5a54e395a0f171f53 100755 --- a/tools/infer/predict_cls.py +++ b/tools/infer/predict_cls.py @@ -39,6 +39,7 @@ class TextClassifier(object): self.cls_batch_num = args.rec_batch_num self.label_list = args.label_list self.use_zero_copy_run = args.use_zero_copy_run + self.cls_thresh = args.cls_thresh def resize_norm_img(self, img): imgC, imgH, imgW = self.cls_image_shape @@ -110,7 +111,7 @@ class TextClassifier(object): score = prob_out[rno][label_idx] label = self.label_list[label_idx] cls_res[indices[beg_img_no + rno]] = [label, score] - if '180' in label and score > 0.9999: + if '180' in label and score > self.cls_thresh: img_list[indices[beg_img_no + rno]] = cv2.rotate( img_list[indices[beg_img_no + rno]], 1) return img_list, cls_res, predict_time diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 92212afd5f3e16601939d0ca7882fb3b90c3a9ac..50d934efe91faa2956e63a2344c8b6b6090e4f7a 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -78,6 +78,7 @@ def parse_args(): parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192") parser.add_argument("--label_list", type=list, default=['0', '180']) parser.add_argument("--cls_batch_num", type=int, default=30) + parser.add_argument("--cls_thresh", type=float, default=0.9) parser.add_argument("--enable_mkldnn", type=str2bool, default=False) parser.add_argument("--use_zero_copy_run", type=str2bool, default=False)