From dfe4e67e7aac67f3741bffd2031a1bd0cd8f40ce Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Tue, 21 Jul 2020 10:38:55 +0800 Subject: [PATCH] Add friendly Error message in save_inference_model (#25617) --- python/paddle/fluid/io.py | 7 +++++++ .../tests/unittests/test_io_save_load.py | 21 +++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index fadd247e0df..260033f9ef0 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -1064,6 +1064,13 @@ def prepend_feed_ops(inference_program, persistable=True) for i, name in enumerate(feed_target_names): + if not global_block.has_var(name): + raise ValueError( + "The feeded_var_names[{i}]: '{name}' doesn't exist in pruned inference program. " + "Please check whether '{name}' is a valid feed_var name, or remove it from feeded_var_names " + "if '{name}' is not involved in the target_vars calculation.". + format( + i=i, name=name)) out = global_block.var(name) global_block._prepend_op( type='feed', diff --git a/python/paddle/fluid/tests/unittests/test_io_save_load.py b/python/paddle/fluid/tests/unittests/test_io_save_load.py index 01665597fac..c532c1bdbaa 100644 --- a/python/paddle/fluid/tests/unittests/test_io_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_io_save_load.py @@ -48,5 +48,26 @@ class TestSaveLoadAPIError(unittest.TestCase): vars="vars") +class TestSaveInferenceModelAPIError(unittest.TestCase): + def test_useless_feeded_var_names(self): + start_prog = fluid.Program() + main_prog = fluid.Program() + with fluid.program_guard(main_prog, start_prog): + x = fluid.data(name='x', shape=[10, 16], dtype='float32') + y = fluid.data(name='y', shape=[10, 16], dtype='float32') + z = fluid.layers.fc(x, 4) + + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(start_prog) + with self.assertRaisesRegexp( + ValueError, "not involved in the target_vars calculation"): + fluid.io.save_inference_model( + dirname='./model', + feeded_var_names=['x', 'y'], + target_vars=[z], + executor=exe, + main_program=main_prog) + + if __name__ == '__main__': unittest.main() -- GitLab