提交 cf2a4833 编写于 作者: W WenmuZhou

分类支持传参置信度

上级 06430c93
...@@ -39,6 +39,7 @@ class TextClassifier(object): ...@@ -39,6 +39,7 @@ class TextClassifier(object):
self.cls_batch_num = args.rec_batch_num self.cls_batch_num = args.rec_batch_num
self.label_list = args.label_list self.label_list = args.label_list
self.use_zero_copy_run = args.use_zero_copy_run self.use_zero_copy_run = args.use_zero_copy_run
self.cls_thresh = args.cls_thresh
def resize_norm_img(self, img): def resize_norm_img(self, img):
imgC, imgH, imgW = self.cls_image_shape imgC, imgH, imgW = self.cls_image_shape
...@@ -110,7 +111,7 @@ class TextClassifier(object): ...@@ -110,7 +111,7 @@ class TextClassifier(object):
score = prob_out[rno][label_idx] score = prob_out[rno][label_idx]
label = self.label_list[label_idx] label = self.label_list[label_idx]
cls_res[indices[beg_img_no + rno]] = [label, score] cls_res[indices[beg_img_no + rno]] = [label, score]
if '180' in label and score > 0.9999: if '180' in label and score > self.cls_thresh:
img_list[indices[beg_img_no + rno]] = cv2.rotate( img_list[indices[beg_img_no + rno]] = cv2.rotate(
img_list[indices[beg_img_no + rno]], 1) img_list[indices[beg_img_no + rno]], 1)
return img_list, cls_res, predict_time return img_list, cls_res, predict_time
......
...@@ -78,6 +78,7 @@ def parse_args(): ...@@ -78,6 +78,7 @@ def parse_args():
parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192") parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
parser.add_argument("--label_list", type=list, default=['0', '180']) parser.add_argument("--label_list", type=list, default=['0', '180'])
parser.add_argument("--cls_batch_num", type=int, default=30) parser.add_argument("--cls_batch_num", type=int, default=30)
parser.add_argument("--cls_thresh", type=float, default=0.9)
parser.add_argument("--enable_mkldnn", type=str2bool, default=False) parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
parser.add_argument("--use_zero_copy_run", type=str2bool, default=False) parser.add_argument("--use_zero_copy_run", type=str2bool, default=False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册