From c7819af417a89c188f6f7ab3a1874a4f60bf11c8 Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Wed, 16 Sep 2020 10:43:51 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AF=B9=E8=BE=93=E5=87=BA=E7=9A=84scores?= =?UTF-8?q?=E5=92=8Clabel=E8=BF=9B=E8=A1=8C=E9=A1=BA=E5=BA=8F=E5=88=A4?= =?UTF-8?q?=E6=96=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tools/infer/predict_cls.py | 2 ++ tools/infer_cls.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/tools/infer/predict_cls.py b/tools/infer/predict_cls.py index 5c54224e..f5e358e9 100755 --- a/tools/infer/predict_cls.py +++ b/tools/infer/predict_cls.py @@ -100,6 +100,8 @@ class TextClassifier(object): prob_out = self.output_tensors[0].copy_to_cpu() label_out = self.output_tensors[1].copy_to_cpu() + if len(label_out.shape) != 1: + prob_out, label_out = label_out, prob_out elapse = time.time() - starttime predict_time += elapse diff --git a/tools/infer_cls.py b/tools/infer_cls.py index 1f78cdf9..aebdc076 100755 --- a/tools/infer_cls.py +++ b/tools/infer_cls.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np import os import sys + __dir__ = os.path.dirname(__file__) sys.path.append(__dir__) sys.path.append(os.path.join(__dir__, '..')) @@ -40,6 +41,7 @@ set_paddle_flags( import tools.program as program from paddle import fluid from ppocr.utils.utility import initial_logger + logger = initial_logger() from ppocr.data.reader_main import reader_main from ppocr.utils.save_load import init_model @@ -87,6 +89,8 @@ def main(): return_numpy=False) scores = np.array(predict[0]) label = np.array(predict[1]) + if len(label.shape) != 1: + label, scores = scores, label logger.info('\t scores: {}'.format(scores)) logger.info('\t label: {}'.format(label)) # save for inference model -- GitLab