diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 8cd8bc39941c59519d51cca735acdf2f96eb12cb..4c7537d8d5c8ebc412a717b7220ac0dab12e2939 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -540,19 +540,24 @@ class _StandaloneExecutor(object): Returns: feed:(list|dict) updated feed. """ - global_block = self._main_program.global_block() if feed is None: feed = {} - elif isinstance(feed, dict): - for feed_name in list(feed.keys()): - if not global_block.has_var(feed_name): - feed.pop(feed_name) - warnings.warn( - "The variable %s is not found in program. It is not declared or is pruned." - % feed_name) - else: - raise TypeError("Only support feed with `dict`, but received {}". - format(type(feed).__name__)) + elif isinstance(feed, (list, tuple)): + assert len(feed) == 1, "Not compiled with data parallel" + feed = feed[0] + + if not isinstance(feed, dict): + raise TypeError( + "feed requires dict as its Parameter. But you passed in %s" % + (type(feed))) + + global_block = self._main_program.global_block() + for feed_name in list(feed.keys()): + if not global_block.has_var(feed_name): + feed.pop(feed_name) + warnings.warn( + "The variable %s is not found in program. It is not declared or is pruned." + % feed_name) return feed diff --git a/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py b/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py index 425c62ad9e26453694e564d31f26d61282298575..f269979746a08e94dc98f42b573d2079517f4e98 100644 --- a/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py +++ b/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py @@ -242,9 +242,6 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase): def test_with_error(self): feed = [{'a': np.ones([2, 2], dtype="float32")}] - with self.assertRaises(TypeError): - res = self.run_new_executor(feed) - with self.assertRaises(TypeError): os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1' self._run(feed[0], add_wrong_fetch=True)