提交 abdd9411 编写于 作者: X Xin Pan

fix

test=develop
上级 c27008c4
......@@ -382,9 +382,11 @@ class Executor(object):
"""
Close this executor.
You can no long use this executor after calling this method.
You can no longer use this executor after calling this method.
For the distributed training, this method would free the resource on PServers related to
the current Trainer.
TODO(typhoonzero): Define "no longer use" meaning? Can user create
a new Executor for the same program and run?
TODO(panyx0718): Why ParallelExecutor doesn't have close?
Example:
......@@ -397,7 +399,7 @@ class Executor(object):
self.executor.close()
self._closed = True
def _run_parallel(self, scope, feed, fetch_list, fetch_var_name,
def _run_parallel(self, program, scope, feed, fetch_list, fetch_var_name,
return_numpy):
if isinstance(feed, dict):
feed_tensor_dict = dict()
......@@ -413,7 +415,7 @@ class Executor(object):
self.executor.feed_and_split_tensor_into_local_scopes(
feed_tensor_dict)
elif isinstance(feed, list) or isinstance(feed, tuple):
if len(feed) != len(self._places):
if len(feed) != len(program._places):
raise ValueError(
"Feed a list of tensor, the list should be the same size as places"
)
......@@ -428,7 +430,7 @@ class Executor(object):
tensor = each[feed_name]
if not isinstance(tensor, core.LoDTensor):
tmp = core.LoDTensor()
tmp.set(tensor, self._places[i])
tmp.set(tensor, program._places[i])
tensor = tmp
res_dict[feed_name] = tensor
res.append(res_dict)
......@@ -462,7 +464,7 @@ class Executor(object):
Args:
program(Program|CompiledProgram): the program that need to run,
if not provided, then default_main_program will be used.
if not provided, then default_main_program (not compiled) will be used.
feed(dict): feed variable map, e.g. {"image": ImageData, "label": LabelData}
fetch_list(list): a list of variable or variable names that user want to get, run will return them according to this list.
feed_var_name(str): the name for the input variable of feed Operator.
......@@ -525,6 +527,7 @@ class Executor(object):
self.executor = program._executor
if program._is_data_parallel:
return self._run_parallel(
program,
scope=scope,
feed=feed,
fetch_list=fetch_list,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册