From cf2a483369f66ad22babd88949e32dc4f046a59e Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Fri, 18 Sep 2020 11:29:55 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=86=E7=B1=BB=E6=94=AF=E6=8C=81=E4=BC=A0?= =?UTF-8?q?=E5=8F=82=E7=BD=AE=E4=BF=A1=E5=BA=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tools/infer/predict_cls.py | 3 ++- tools/infer/utility.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tools/infer/predict_cls.py b/tools/infer/predict_cls.py index f5e358e9..3c14011a 100755 --- a/tools/infer/predict_cls.py +++ b/tools/infer/predict_cls.py @@ -39,6 +39,7 @@ class TextClassifier(object): self.cls_batch_num = args.rec_batch_num self.label_list = args.label_list self.use_zero_copy_run = args.use_zero_copy_run + self.cls_thresh = args.cls_thresh def resize_norm_img(self, img): imgC, imgH, imgW = self.cls_image_shape @@ -110,7 +111,7 @@ class TextClassifier(object): score = prob_out[rno][label_idx] label = self.label_list[label_idx] cls_res[indices[beg_img_no + rno]] = [label, score] - if '180' in label and score > 0.9999: + if '180' in label and score > self.cls_thresh: img_list[indices[beg_img_no + rno]] = cv2.rotate( img_list[indices[beg_img_no + rno]], 1) return img_list, cls_res, predict_time diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 92212afd..50d934ef 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -78,6 +78,7 @@ def parse_args(): parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192") parser.add_argument("--label_list", type=list, default=['0', '180']) parser.add_argument("--cls_batch_num", type=int, default=30) + parser.add_argument("--cls_thresh", type=float, default=0.9) parser.add_argument("--enable_mkldnn", type=str2bool, default=False) parser.add_argument("--use_zero_copy_run", type=str2bool, default=False) -- GitLab