提交 abdd9411 编写于 作者: X Xin Pan

fix

test=develop
上级 c27008c4
...@@ -382,9 +382,11 @@ class Executor(object): ...@@ -382,9 +382,11 @@ class Executor(object):
""" """
Close this executor. Close this executor.
You can no long use this executor after calling this method. You can no longer use this executor after calling this method.
For the distributed training, this method would free the resource on PServers related to For the distributed training, this method would free the resource on PServers related to
the current Trainer. the current Trainer.
TODO(typhoonzero): Define "no longer use" meaning? Can user create
a new Executor for the same program and run?
TODO(panyx0718): Why ParallelExecutor doesn't have close? TODO(panyx0718): Why ParallelExecutor doesn't have close?
Example: Example:
...@@ -397,7 +399,7 @@ class Executor(object): ...@@ -397,7 +399,7 @@ class Executor(object):
self.executor.close() self.executor.close()
self._closed = True self._closed = True
def _run_parallel(self, scope, feed, fetch_list, fetch_var_name, def _run_parallel(self, program, scope, feed, fetch_list, fetch_var_name,
return_numpy): return_numpy):
if isinstance(feed, dict): if isinstance(feed, dict):
feed_tensor_dict = dict() feed_tensor_dict = dict()
...@@ -413,7 +415,7 @@ class Executor(object): ...@@ -413,7 +415,7 @@ class Executor(object):
self.executor.feed_and_split_tensor_into_local_scopes( self.executor.feed_and_split_tensor_into_local_scopes(
feed_tensor_dict) feed_tensor_dict)
elif isinstance(feed, list) or isinstance(feed, tuple): elif isinstance(feed, list) or isinstance(feed, tuple):
if len(feed) != len(self._places): if len(feed) != len(program._places):
raise ValueError( raise ValueError(
"Feed a list of tensor, the list should be the same size as places" "Feed a list of tensor, the list should be the same size as places"
) )
...@@ -428,7 +430,7 @@ class Executor(object): ...@@ -428,7 +430,7 @@ class Executor(object):
tensor = each[feed_name] tensor = each[feed_name]
if not isinstance(tensor, core.LoDTensor): if not isinstance(tensor, core.LoDTensor):
tmp = core.LoDTensor() tmp = core.LoDTensor()
tmp.set(tensor, self._places[i]) tmp.set(tensor, program._places[i])
tensor = tmp tensor = tmp
res_dict[feed_name] = tensor res_dict[feed_name] = tensor
res.append(res_dict) res.append(res_dict)
...@@ -462,7 +464,7 @@ class Executor(object): ...@@ -462,7 +464,7 @@ class Executor(object):
Args: Args:
program(Program|CompiledProgram): the program that need to run, program(Program|CompiledProgram): the program that need to run,
if not provided, then default_main_program will be used. if not provided, then default_main_program (not compiled) will be used.
feed(dict): feed variable map, e.g. {"image": ImageData, "label": LabelData} feed(dict): feed variable map, e.g. {"image": ImageData, "label": LabelData}
fetch_list(list): a list of variable or variable names that user want to get, run will return them according to this list. fetch_list(list): a list of variable or variable names that user want to get, run will return them according to this list.
feed_var_name(str): the name for the input variable of feed Operator. feed_var_name(str): the name for the input variable of feed Operator.
...@@ -525,6 +527,7 @@ class Executor(object): ...@@ -525,6 +527,7 @@ class Executor(object):
self.executor = program._executor self.executor = program._executor
if program._is_data_parallel: if program._is_data_parallel:
return self._run_parallel( return self._run_parallel(
program,
scope=scope, scope=scope,
feed=feed, feed=feed,
fetch_list=fetch_list, fetch_list=fetch_list,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册