未验证 提交 dfe4e67e 编写于 作者: A Aurelius84 提交者: GitHub

Add friendly Error message in save_inference_model (#25617)

上级 ca1185d0
...@@ -1064,6 +1064,13 @@ def prepend_feed_ops(inference_program, ...@@ -1064,6 +1064,13 @@ def prepend_feed_ops(inference_program,
persistable=True) persistable=True)
for i, name in enumerate(feed_target_names): for i, name in enumerate(feed_target_names):
if not global_block.has_var(name):
raise ValueError(
"The feeded_var_names[{i}]: '{name}' doesn't exist in pruned inference program. "
"Please check whether '{name}' is a valid feed_var name, or remove it from feeded_var_names "
"if '{name}' is not involved in the target_vars calculation.".
format(
i=i, name=name))
out = global_block.var(name) out = global_block.var(name)
global_block._prepend_op( global_block._prepend_op(
type='feed', type='feed',
......
...@@ -48,5 +48,26 @@ class TestSaveLoadAPIError(unittest.TestCase): ...@@ -48,5 +48,26 @@ class TestSaveLoadAPIError(unittest.TestCase):
vars="vars") vars="vars")
class TestSaveInferenceModelAPIError(unittest.TestCase):
def test_useless_feeded_var_names(self):
start_prog = fluid.Program()
main_prog = fluid.Program()
with fluid.program_guard(main_prog, start_prog):
x = fluid.data(name='x', shape=[10, 16], dtype='float32')
y = fluid.data(name='y', shape=[10, 16], dtype='float32')
z = fluid.layers.fc(x, 4)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(start_prog)
with self.assertRaisesRegexp(
ValueError, "not involved in the target_vars calculation"):
fluid.io.save_inference_model(
dirname='./model',
feeded_var_names=['x', 'y'],
target_vars=[z],
executor=exe,
main_program=main_prog)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册