From 633d0552851d41ae77c2f28ba83fda350648c341 Mon Sep 17 00:00:00 2001 From: Double_V Date: Thu, 7 Apr 2022 17:26:07 +0800 Subject: [PATCH] fix bug for PPOCRV2_rec_pact export model and inference (#5903) * fix bug for rec distill pact infer * Update rec_postprocess.py --- deploy/slim/quantization/export_model.py | 1 + ppocr/postprocess/rec_postprocess.py | 2 +- tools/infer/utility.py | 16 +++++++++++++--- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/deploy/slim/quantization/export_model.py b/deploy/slim/quantization/export_model.py index 0cb86108..79d95d04 100755 --- a/deploy/slim/quantization/export_model.py +++ b/deploy/slim/quantization/export_model.py @@ -127,6 +127,7 @@ def main(): arch_config = config["Architecture"] if arch_config["algorithm"] in ["Distillation", ]: # distillation model for idx, name in enumerate(model.model_name_list): + model.model_list[idx].eval() sub_model_save_path = os.path.join(save_path, name, "inference") export_single_model(quanter, model.model_list[idx], infer_shape, sub_model_save_path, logger) diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index caaa2948..cddc263a 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -87,7 +87,7 @@ class CTCLabelDecode(BaseRecLabelDecode): use_space_char) def __call__(self, preds, label=None, *args, **kwargs): - if isinstance(preds, tuple): + if isinstance(preds, (tuple, list)): preds = preds[-1] if isinstance(preds, paddle.Tensor): preds = preds.numpy() diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 7b7b81e3..37122f0a 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -305,12 +305,22 @@ def create_predictor(args, mode, logger): input_names = predictor.get_input_names() for name in input_names: input_tensor = predictor.get_input_handle(name) - output_names = predictor.get_output_names() - output_tensors = [] + output_tensors = get_output_tensors(args, mode, predictor) + return predictor, input_tensor, output_tensors, config + + +def get_output_tensors(args, mode, predictor): + output_names = predictor.get_output_names() + output_tensors = [] + if mode == "rec" and args.rec_algorithm == "CRNN": + output_name = 'softmax_0.tmp_0' + if output_name in output_names: + return [predictor.get_output_handle(output_name)] + else: for output_name in output_names: output_tensor = predictor.get_output_handle(output_name) output_tensors.append(output_tensor) - return predictor, input_tensor, output_tensors, config + return output_tensors def get_infer_gpuid(): -- GitLab