提交 22df230e 编写于 作者: J JiayiFeng

rename 'feed_dict' in ParallelExecutor.run() to 'feed'

上级 a78b9285
...@@ -61,8 +61,8 @@ class ParallelExecutor(object): ...@@ -61,8 +61,8 @@ class ParallelExecutor(object):
main_program=test_program, main_program=test_program,
share_vars_from=train_exe) share_vars_from=train_exe)
train_loss, = train_exe.run([loss.name], feed_dict=feed_dict) train_loss, = train_exe.run([loss.name], feed=feed_dict)
test_loss, = test_exe.run([loss.name], feed_dict=feed_dict) test_loss, = test_exe.run([loss.name], feed=feed_dict)
""" """
self._places = [] self._places = []
...@@ -123,22 +123,23 @@ class ParallelExecutor(object): ...@@ -123,22 +123,23 @@ class ParallelExecutor(object):
allow_op_delay) allow_op_delay)
self.scope = scope self.scope = scope
def run(self, fetch_list, feed_dict={}): def run(self, fetch_list, feed={}, feed_dict={}):
""" """
:param fetch_list: A list of variable names that will be fetched. :param fetch_list: A list of variable names that will be fetched.
:param feed_dict: A dict mapping for feed variable name to LoDTensor :param feed: A dict mapping for feed variable name to LoDTensor
or numpy array. or numpy array.
:return: fetched value list. :return: fetched value list.
""" """
if not isinstance(feed_dict, dict): feed = feed_dict
raise TypeError("feed_dict should be a dict") if not isinstance(feed, dict):
raise TypeError("feed should be a dict")
feed_tensor_dict = {} feed_tensor_dict = {}
for i, feed_name in enumerate(feed_dict): for i, feed_name in enumerate(feed):
feed_tensor = feed_dict[feed_name] feed_tensor = feed[feed_name]
if not isinstance(feed_tensor, core.LoDTensor): if not isinstance(feed_tensor, core.LoDTensor):
feed_tensor = core.LoDTensor() feed_tensor = core.LoDTensor()
feed_tensor.set(feed_dict[feed_name], self._act_places[0]) feed_tensor.set(feed[feed_name], self._act_places[0])
feed_tensor_dict[feed_name] = feed_tensor feed_tensor_dict[feed_name] = feed_tensor
fetch_var_name = '@FETCHED_VAR_NAME@' fetch_var_name = '@FETCHED_VAR_NAME@'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册