From abdd9411b4487d5a67e87f7376918140b5e01045 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Fri, 11 Jan 2019 16:01:30 +0800 Subject: [PATCH] fix test=develop --- python/paddle/fluid/executor.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 1a940b30c..0d06d0f2c 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -382,9 +382,11 @@ class Executor(object): """ Close this executor. - You can no long use this executor after calling this method. + You can no longer use this executor after calling this method. For the distributed training, this method would free the resource on PServers related to the current Trainer. + TODO(typhoonzero): Define "no longer use" meaning? Can user create + a new Executor for the same program and run? TODO(panyx0718): Why ParallelExecutor doesn't have close? Example: @@ -397,7 +399,7 @@ class Executor(object): self.executor.close() self._closed = True - def _run_parallel(self, scope, feed, fetch_list, fetch_var_name, + def _run_parallel(self, program, scope, feed, fetch_list, fetch_var_name, return_numpy): if isinstance(feed, dict): feed_tensor_dict = dict() @@ -413,7 +415,7 @@ class Executor(object): self.executor.feed_and_split_tensor_into_local_scopes( feed_tensor_dict) elif isinstance(feed, list) or isinstance(feed, tuple): - if len(feed) != len(self._places): + if len(feed) != len(program._places): raise ValueError( "Feed a list of tensor, the list should be the same size as places" ) @@ -428,7 +430,7 @@ class Executor(object): tensor = each[feed_name] if not isinstance(tensor, core.LoDTensor): tmp = core.LoDTensor() - tmp.set(tensor, self._places[i]) + tmp.set(tensor, program._places[i]) tensor = tmp res_dict[feed_name] = tensor res.append(res_dict) @@ -462,7 +464,7 @@ class Executor(object): Args: program(Program|CompiledProgram): the program that need to run, - if not provided, then default_main_program will be used. + if not provided, then default_main_program (not compiled) will be used. feed(dict): feed variable map, e.g. {"image": ImageData, "label": LabelData} fetch_list(list): a list of variable or variable names that user want to get, run will return them according to this list. feed_var_name(str): the name for the input variable of feed Operator. @@ -525,6 +527,7 @@ class Executor(object): self.executor = program._executor if program._is_data_parallel: return self._run_parallel( + program, scope=scope, feed=feed, fetch_list=fetch_list, -- GitLab