未验证 提交 e60d5598 编写于 作者: W wangguanzhong 提交者: GitHub

[cherry-pick] add py_func check in export_model (#1047)

* add py_func check in export_model

* refine error message
上级 41c2b7a0
...@@ -23,7 +23,12 @@ import paddle.fluid as fluid ...@@ -23,7 +23,12 @@ import paddle.fluid as fluid
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
__all__ = ['check_gpu', 'check_version', 'check_config'] __all__ = [
'check_gpu',
'check_version',
'check_config',
'check_py_func',
]
def check_gpu(use_gpu): def check_gpu(use_gpu):
...@@ -96,3 +101,17 @@ def check_config(cfg): ...@@ -96,3 +101,17 @@ def check_config(cfg):
actual_num_classes)) actual_num_classes))
return cfg return cfg
def check_py_func(program):
for block in program.blocks:
for op in block.ops:
if op.type == 'py_func':
input_arg = op.input_arg_names
output_arg = op.output_arg_names
err = "The program contains py_func with input: {}, "\
"output: {}. It is not supported in Paddle inference "\
"engine. please replace it by paddle ops. For example, "\
"if you use MultiClassSoftNMS, better to replace it by "\
"MultiClassNMS.".format(input_arg, output_arg)
raise Exception(err)
...@@ -28,7 +28,7 @@ from paddle import fluid ...@@ -28,7 +28,7 @@ from paddle import fluid
from ppdet.core.workspace import load_config, merge_config, create from ppdet.core.workspace import load_config, merge_config, create
from ppdet.utils.cli import ArgsParser from ppdet.utils.cli import ArgsParser
import ppdet.utils.checkpoint as checkpoint import ppdet.utils.checkpoint as checkpoint
from ppdet.utils.check import check_config, check_version from ppdet.utils.check import check_config, check_version, check_py_func
import yaml import yaml
import logging import logging
from collections import OrderedDict from collections import OrderedDict
...@@ -195,6 +195,7 @@ def main(): ...@@ -195,6 +195,7 @@ def main():
feed_vars, _ = model.build_inputs(**inputs_def) feed_vars, _ = model.build_inputs(**inputs_def)
test_fetches = model.test(feed_vars) test_fetches = model.test(feed_vars)
infer_prog = infer_prog.clone(True) infer_prog = infer_prog.clone(True)
check_py_func(infer_prog)
exe.run(startup_prog) exe.run(startup_prog)
checkpoint.load_params(exe, infer_prog, cfg.weights) checkpoint.load_params(exe, infer_prog, cfg.weights)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册