From 4c2a06df6afab19399e9b83b3cdbaaaf699548da Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Tue, 21 Sep 2021 22:19:08 -0500 Subject: [PATCH] fix feed for new executor (#35803) * fix feed, test=develop * delete one test case, test=develop --- python/paddle/fluid/executor.py | 27 +++++++++++-------- .../interpreter/test_standalone_executor.py | 3 --- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 8cd8bc39941..4c7537d8d5c 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -540,19 +540,24 @@ class _StandaloneExecutor(object): Returns: feed:(list|dict) updated feed. """ - global_block = self._main_program.global_block() if feed is None: feed = {} - elif isinstance(feed, dict): - for feed_name in list(feed.keys()): - if not global_block.has_var(feed_name): - feed.pop(feed_name) - warnings.warn( - "The variable %s is not found in program. It is not declared or is pruned." - % feed_name) - else: - raise TypeError("Only support feed with `dict`, but received {}". - format(type(feed).__name__)) + elif isinstance(feed, (list, tuple)): + assert len(feed) == 1, "Not compiled with data parallel" + feed = feed[0] + + if not isinstance(feed, dict): + raise TypeError( + "feed requires dict as its Parameter. But you passed in %s" % + (type(feed))) + + global_block = self._main_program.global_block() + for feed_name in list(feed.keys()): + if not global_block.has_var(feed_name): + feed.pop(feed_name) + warnings.warn( + "The variable %s is not found in program. It is not declared or is pruned." + % feed_name) return feed diff --git a/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py b/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py index 425c62ad9e2..f269979746a 100644 --- a/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py +++ b/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py @@ -242,9 +242,6 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase): def test_with_error(self): feed = [{'a': np.ones([2, 2], dtype="float32")}] - with self.assertRaises(TypeError): - res = self.run_new_executor(feed) - with self.assertRaises(TypeError): os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1' self._run(feed[0], add_wrong_fetch=True) -- GitLab