未验证 提交 bc02adfa 编写于 作者: W wangguanzhong 提交者: GitHub

add py_func check in export_model (#1046)

* add py_func check in export_model

* refine error message
上级 c3aad6b6
......@@ -25,7 +25,12 @@ import six
import paddle.version as fluid_version
logger = logging.getLogger(__name__)
__all__ = ['check_gpu', 'check_version', 'check_config']
__all__ = [
'check_gpu',
'check_version',
'check_config',
'check_py_func',
]
def check_gpu(use_gpu):
......@@ -107,3 +112,17 @@ def check_config(cfg):
actual_num_classes))
return cfg
def check_py_func(program):
for block in program.blocks:
for op in block.ops:
if op.type == 'py_func':
input_arg = op.input_arg_names
output_arg = op.output_arg_names
err = "The program contains py_func with input: {}, "\
"output: {}. It is not supported in Paddle inference "\
"engine. please replace it by paddle ops. For example, "\
"if you use MultiClassSoftNMS, better to replace it by "\
"MultiClassNMS.".format(input_arg, output_arg)
raise Exception(err)
......@@ -28,7 +28,7 @@ from paddle import fluid
from ppdet.core.workspace import load_config, merge_config, create
from ppdet.utils.cli import ArgsParser
import ppdet.utils.checkpoint as checkpoint
from ppdet.utils.check import check_config, check_version
from ppdet.utils.check import check_config, check_version, check_py_func
import yaml
import logging
from collections import OrderedDict
......@@ -195,6 +195,7 @@ def main():
feed_vars, _ = model.build_inputs(**inputs_def)
test_fetches = model.test(feed_vars)
infer_prog = infer_prog.clone(True)
check_py_func(infer_prog)
exe.run(startup_prog)
checkpoint.load_params(exe, infer_prog, cfg.weights)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册