未验证 提交 4c2a06df 编写于 作者: W wanghuancoder 提交者: GitHub

fix feed for new executor (#35803)

* fix feed, test=develop

* delete one test case, test=develop
上级 5574c8cf
......@@ -540,19 +540,24 @@ class _StandaloneExecutor(object):
Returns:
feed:(list|dict) updated feed.
"""
global_block = self._main_program.global_block()
if feed is None:
feed = {}
elif isinstance(feed, dict):
elif isinstance(feed, (list, tuple)):
assert len(feed) == 1, "Not compiled with data parallel"
feed = feed[0]
if not isinstance(feed, dict):
raise TypeError(
"feed requires dict as its Parameter. But you passed in %s" %
(type(feed)))
global_block = self._main_program.global_block()
for feed_name in list(feed.keys()):
if not global_block.has_var(feed_name):
feed.pop(feed_name)
warnings.warn(
"The variable %s is not found in program. It is not declared or is pruned."
% feed_name)
else:
raise TypeError("Only support feed with `dict`, but received {}".
format(type(feed).__name__))
return feed
......
......@@ -242,9 +242,6 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase):
def test_with_error(self):
feed = [{'a': np.ones([2, 2], dtype="float32")}]
with self.assertRaises(TypeError):
res = self.run_new_executor(feed)
with self.assertRaises(TypeError):
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1'
self._run(feed[0], add_wrong_fetch=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册