提交 22df230e 编写于 作者: J JiayiFeng

rename 'feed_dict' in ParallelExecutor.run() to 'feed'

上级 a78b9285
......@@ -61,8 +61,8 @@ class ParallelExecutor(object):
main_program=test_program,
share_vars_from=train_exe)
train_loss, = train_exe.run([loss.name], feed_dict=feed_dict)
test_loss, = test_exe.run([loss.name], feed_dict=feed_dict)
train_loss, = train_exe.run([loss.name], feed=feed_dict)
test_loss, = test_exe.run([loss.name], feed=feed_dict)
"""
self._places = []
......@@ -123,22 +123,23 @@ class ParallelExecutor(object):
allow_op_delay)
self.scope = scope
def run(self, fetch_list, feed_dict={}):
def run(self, fetch_list, feed={}, feed_dict={}):
"""
:param fetch_list: A list of variable names that will be fetched.
:param feed_dict: A dict mapping for feed variable name to LoDTensor
:param feed: A dict mapping for feed variable name to LoDTensor
or numpy array.
:return: fetched value list.
"""
if not isinstance(feed_dict, dict):
raise TypeError("feed_dict should be a dict")
feed = feed_dict
if not isinstance(feed, dict):
raise TypeError("feed should be a dict")
feed_tensor_dict = {}
for i, feed_name in enumerate(feed_dict):
feed_tensor = feed_dict[feed_name]
for i, feed_name in enumerate(feed):
feed_tensor = feed[feed_name]
if not isinstance(feed_tensor, core.LoDTensor):
feed_tensor = core.LoDTensor()
feed_tensor.set(feed_dict[feed_name], self._act_places[0])
feed_tensor.set(feed[feed_name], self._act_places[0])
feed_tensor_dict[feed_name] = feed_tensor
fetch_var_name = '@FETCHED_VAR_NAME@'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册