未验证 提交 633d0552 编写于 作者: D Double_V 提交者: GitHub

fix bug for PPOCRV2_rec_pact export model and inference (#5903)

* fix bug for rec distill pact infer

* Update rec_postprocess.py
上级 d529f7ae
......@@ -127,6 +127,7 @@ def main():
arch_config = config["Architecture"]
if arch_config["algorithm"] in ["Distillation", ]: # distillation model
for idx, name in enumerate(model.model_name_list):
model.model_list[idx].eval()
sub_model_save_path = os.path.join(save_path, name, "inference")
export_single_model(quanter, model.model_list[idx], infer_shape,
sub_model_save_path, logger)
......
......@@ -87,7 +87,7 @@ class CTCLabelDecode(BaseRecLabelDecode):
use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, tuple):
if isinstance(preds, (tuple, list)):
preds = preds[-1]
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
......
......@@ -305,12 +305,22 @@ def create_predictor(args, mode, logger):
input_names = predictor.get_input_names()
for name in input_names:
input_tensor = predictor.get_input_handle(name)
output_tensors = get_output_tensors(args, mode, predictor)
return predictor, input_tensor, output_tensors, config
def get_output_tensors(args, mode, predictor):
output_names = predictor.get_output_names()
output_tensors = []
if mode == "rec" and args.rec_algorithm == "CRNN":
output_name = 'softmax_0.tmp_0'
if output_name in output_names:
return [predictor.get_output_handle(output_name)]
else:
for output_name in output_names:
output_tensor = predictor.get_output_handle(output_name)
output_tensors.append(output_tensor)
return predictor, input_tensor, output_tensors, config
return output_tensors
def get_infer_gpuid():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册