From e60d559851b17bc36abfb8c70662116065042bf4 Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Fri, 10 Jul 2020 14:42:34 +0800 Subject: [PATCH] [cherry-pick] add py_func check in export_model (#1047) * add py_func check in export_model * refine error message --- ppdet/utils/check.py | 21 ++++++++++++++++++++- tools/export_model.py | 3 ++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/ppdet/utils/check.py b/ppdet/utils/check.py index 10c4478ac..3cca56ecd 100644 --- a/ppdet/utils/check.py +++ b/ppdet/utils/check.py @@ -23,7 +23,12 @@ import paddle.fluid as fluid import logging logger = logging.getLogger(__name__) -__all__ = ['check_gpu', 'check_version', 'check_config'] +__all__ = [ + 'check_gpu', + 'check_version', + 'check_config', + 'check_py_func', +] def check_gpu(use_gpu): @@ -96,3 +101,17 @@ def check_config(cfg): actual_num_classes)) return cfg + + +def check_py_func(program): + for block in program.blocks: + for op in block.ops: + if op.type == 'py_func': + input_arg = op.input_arg_names + output_arg = op.output_arg_names + err = "The program contains py_func with input: {}, "\ + "output: {}. It is not supported in Paddle inference "\ + "engine. please replace it by paddle ops. For example, "\ + "if you use MultiClassSoftNMS, better to replace it by "\ + "MultiClassNMS.".format(input_arg, output_arg) + raise Exception(err) diff --git a/tools/export_model.py b/tools/export_model.py index 2cdf9c58d..786fee49b 100644 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -28,7 +28,7 @@ from paddle import fluid from ppdet.core.workspace import load_config, merge_config, create from ppdet.utils.cli import ArgsParser import ppdet.utils.checkpoint as checkpoint -from ppdet.utils.check import check_config, check_version +from ppdet.utils.check import check_config, check_version, check_py_func import yaml import logging from collections import OrderedDict @@ -195,6 +195,7 @@ def main(): feed_vars, _ = model.build_inputs(**inputs_def) test_fetches = model.test(feed_vars) infer_prog = infer_prog.clone(True) + check_py_func(infer_prog) exe.run(startup_prog) checkpoint.load_params(exe, infer_prog, cfg.weights) -- GitLab