From d4b461eb10c9da8affa4d6daae576bb0b61dcd6d Mon Sep 17 00:00:00 2001
From: chengduo
Date: Fri, 1 Mar 2019 09:51:08 -0600
Subject: [PATCH] Unified ParallelExecutor and Compiler (#15970)

* Unified ParallelExecutor and Compiler
---
 .../fast_threaded_ssa_graph_executor.cc  |   4 +-
 python/paddle/fluid/compiler.py          |  72 ++++----
 python/paddle/fluid/framework.py         |   9 -
 python/paddle/fluid/parallel_executor.py | 159 +++---------------
 4 files changed, 65 insertions(+), 179 deletions(-)

diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc
index f03646705..d4fbea9d9 100644
--- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc
@@ -12,7 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 #include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
+#include <memory>
 #include <string>
+#include <unordered_map>
 #include <vector>
 #include "paddle/fluid/framework/details/fetch_op_handle.h"
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
@@ -55,7 +57,7 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
   std::vector<FetchOpHandle *> fetch_ops;
 
   for (auto &fetch_var_name : fetch_tensors) {
-    for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
+    for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
       auto it = var_map.find(fetch_var_name);
       if (it != var_map.end()) {
         fetched_vars[fetch_var_name].push_back(*it->second.rbegin());
diff --git a/python/paddle/fluid/compiler.py b/python/paddle/fluid/compiler.py
index ab4011383..1b7bdfc33 100644
--- a/python/paddle/fluid/compiler.py
+++ b/python/paddle/fluid/compiler.py
@@ -17,7 +17,6 @@ import os
 import six
 import sys
 from .. import compat as cpt
-from . import framework
 from . import core
 from . import framework
 
@@ -36,6 +35,30 @@ def _place_obj(place):
     return p
 
 
+def _is_pserver_mode(main_program):
+    main = main_program if main_program \
+        else default_main_program()
+    for op in main.global_block().ops:
+        if op.type in ["send", "recv"]:
+            return True
+    return False
+
+
+def get_available_places(use_cuda):
+    if use_cuda:
+        gpus_env = os.getenv("FLAGS_selected_gpus")
+        if gpus_env:
+            gpus = [int(s) for s in gpus_env.split(",")]
+        else:
+            gpus = [i for i in six.moves.range(core.get_cuda_device_count())]
+        places = [core.CUDAPlace(i) for i in gpus]
+    else:
+        cpu_num = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
+        places = [core.CPUPlace() for _ in six.moves.range(cpu_num)]
+    assert places, "no place for execution"
+    return places
+
+
 class CompiledProgram(object):
     """
     Compiles to Graph for execution.
@@ -127,8 +150,7 @@ class CompiledProgram(object):
             self._exec_strategy = ExecutionStrategy()
         if self._build_strategy is None:
             self._build_strategy = BuildStrategy()
-        self._build_strategy.is_distribution = framework.is_pserver_mode(
-            self._program)
+        self._build_strategy.is_distribution = _is_pserver_mode(self._program)
         return self
 
     def with_inference_optimize(self, config):
@@ -153,9 +175,9 @@ class CompiledProgram(object):
     def _with_distributed(self):
         raise NotImplementedError()
 
-    def _compile_data_parallel(self):
+    def _compile_data_parallel(self, use_cuda=False, scope=None):
         if self._share_vars_from:
-            if self._scope:
+            if scope:
                 sys.stderr.write("share_vars_from is set, scope is ignored.\n")
             if not self._share_vars_from._is_data_parallel:
                 raise ValueError("share_vars_from is not data parallel. Cannot "
@@ -166,23 +188,11 @@ class CompiledProgram(object):
                                  "var to share.")
             self._local_scopes = self._share_vars_from._executor.local_scopes()
         else:
+            assert scope is not None, ""
             self._local_scopes = []
 
-        self._exec_strategy.use_cuda = isinstance(self._place, core.CUDAPlace)
-        if self._exec_strategy.use_cuda:
-            gpus_env = os.getenv("FLAGS_selected_gpus")
-            if gpus_env:
-                gpus = [int(s) for s in gpus_env.split(",")]
-            else:
-                gpus = [
-                    i for i in six.moves.range(core.get_cuda_device_count())
-                ]
-            self._places = [core.CUDAPlace(i) for i in gpus]
-        else:
-            cpu_num = int(
-                os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
-            self._places = [core.CPUPlace() for _ in six.moves.range(cpu_num)]
-        assert self._places, "no place for execution"
+        self._exec_strategy.use_cuda = use_cuda
+        self._places = get_available_places(self._exec_strategy.use_cuda)
 
         if self._exec_strategy.num_threads == 0:
             if self._exec_strategy.use_cuda:
@@ -197,9 +207,11 @@ class CompiledProgram(object):
         # FIXME(dzhwinter): enable_inplace should be after memory_optimize
         # if turn on python memory optimize, turn off the inplace_pass.
         if self._build_strategy.memory_optimize is None:
-            self._build_strategy.memory_optimize = False if self._program and self._program._is_mem_optimized else True
+            self._build_strategy.memory_optimize = False \
+                if self._program and self._program._is_mem_optimized else True
         if self._build_strategy.enable_inplace is None:
-            self._build_strategy.enable_inplace = False if self._program and self._program._is_mem_optimized else True
+            self._build_strategy.enable_inplace = False \
+                if self._program and self._program._is_mem_optimized else True
 
         # TODO(wuyi): trainer endpoings should be passed in through
         # build_strategy, not program.xxx.
@@ -221,12 +233,12 @@ class CompiledProgram(object):
 
         places = list(map(_place_obj, self._places))
 
-        return core.ParallelExecutor(
-            places,
-            set(self._persistable_vars),
-            cpt.to_text(self._loss_name)
-            if self._loss_name else six.u(''), self._scope, self._local_scopes,
-            self._exec_strategy, self._build_strategy, self._graph)
+        return core.ParallelExecutor(places,
+                                     set(self._persistable_vars),
+                                     cpt.to_text(self._loss_name)
+                                     if self._loss_name else six.u(''), scope,
+                                     self._local_scopes, self._exec_strategy,
+                                     self._build_strategy, self._graph)
 
     def _compile_inference(self):
         return core.create_paddle_predictor(self._infer_config)
@@ -253,7 +265,9 @@ class CompiledProgram(object):
         self._scope = scope
         self._place = place
         if self._is_data_parallel:
-            self._executor = self._compile_data_parallel()
+            self._executor = self._compile_data_parallel(
+                use_cuda=isinstance(self._place, core.CUDAPlace),
+                scope=self._scope)
         elif self._is_inference:
             self._executor = self._compile_inference()
         else:
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 54f4bc537..7dc917880 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -87,15 +87,6 @@ def _current_expected_place():
     return _imperative_current_expected_place_
 
 
-def is_pserver_mode(main_program):
-    main = main_program if main_program \
-        else default_main_program()
-    for op in main.global_block().ops:
-        if op.type in ["send", "recv"]:
-            return True
-    return False
-
-
 class NameScope(object):
     def __init__(self, name="", parent=None):
         self._children = dict()
diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py
index fa8d5ef5d..2ebaab3b1 100644
--- a/python/paddle/fluid/parallel_executor.py
+++ b/python/paddle/fluid/parallel_executor.py
@@ -13,15 +13,11 @@
 # limitations under the License.
 
 from __future__ import print_function
-import multiprocessing
 from . import core
 from . import framework
 from . import executor
-from .. import compat as cpt
-import warnings
+from . import compiler
 import sys
-import six
-import os
 
 __all__ = ['ParallelExecutor']
 
@@ -97,99 +93,27 @@ class ParallelExecutor(object):
             'Please use CompiledProgram and Executor. CompiledProgram '
            'is a central place for optimization and Executor is the '
            'unified executor. Example can be found in compiler.py.\n')
-        # step1: get places, the places are used in run too.
-        self._places = []
-        if use_cuda:
-            gpus_env = os.getenv("FLAGS_selected_gpus")
-            if gpus_env:
-                gpus = [int(s) for s in gpus_env.split(",")]
-            else:
-                gpus = [
-                    i for i in six.moves.range(core.get_cuda_device_count())
-                ]
-            self._places = [core.CUDAPlace(i) for i in gpus]
-        else:
-            cpu_num = int(
-                os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
-            self._places = [core.CPUPlace() for _ in six.moves.range(cpu_num)]
-        assert self._places, "no place for execution"
-        # step2: init exec_strategy
-        if exec_strategy is None:
-            exec_strategy = ExecutionStrategy()
-        exec_strategy.use_cuda = use_cuda
-        if exec_strategy.num_threads == 0:
-            if use_cuda:
-                # Experiments on se-resnext shows that too many threads hurt
-                # performance. Worth tunning for other models in the future.
-                exec_strategy.num_threads = len(self._places) * 4
-            else:
-                cpu_num = int(
-                    os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
-                exec_strategy.num_threads = cpu_num * 2
-
-        # step3: init build_strategy
         if build_strategy is None:
             build_strategy = BuildStrategy()
 
         build_strategy.num_trainers = num_trainers
         build_strategy.trainer_id = trainer_id
 
-        # FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
-        # num_trainers is 1, so the current fields of build_strategy doesn't tell if
-        # it's distributed model.
-        build_strategy.is_distribution = framework.is_pserver_mode(
-            main_program) or num_trainers > 1
-
-        # step4: get main_program, scope, local_scopes
-        main = main_program if main_program \
-            else framework.default_main_program()
-        # FIXME(dzhwinter): enable_inplace should be after memory_optimize
-        # if turn on python memory optimize, turn off the inplace_pass.
-        if build_strategy.memory_optimize is None:
-            build_strategy.memory_optimize = False if main._is_mem_optimized else True
-        if build_strategy.enable_inplace is None:
-            build_strategy.enable_inplace = False if main._is_mem_optimized else True
-        scope = scope if scope is not None else executor.global_scope()
-
-        if share_vars_from and not isinstance(share_vars_from,
-                                              ParallelExecutor):
-            raise TypeError("share_vars_from must be ParallelExecutor.")
-
-        local_scopes = share_vars_from.executor.local_scopes()\
-            if share_vars_from else []
-
-        # step5: check trainers_endpoints, it is used for distribution.
-        trainers_endpoints = main._trainers_endpoints
-        if num_trainers > 1 and trainers_endpoints:
-            assert num_trainers == len(
-                trainers_endpoints), "num_trainers == len(endpoints)"
-            build_strategy.trainers_endpoints = trainers_endpoints
-
-        # step6: get persistable_vars, places. persistable_vars
-        # need be broadcast to other local_scope.
-        persistable_vars = set([
-            cpt.to_text(v.name) for v in [
-                var for var in main.list_vars()
-                if var.persistable and var.type != core.VarDesc.VarType.RAW
-            ]
-        ])
-
-        def place_obj(place):
-            p = core.Place()
-            p.set_place(place)
-            return p
-
-        places = list(map(place_obj, self._places))
-        # step7: init ParallelExecutor
-        # ParallelExecutor API will be deprecated, don't support parallel graph.
-        self._graph = core.Graph(main.desc)
+        self._places = compiler.get_available_places(use_cuda)
+        self._scope = scope if scope is not None else executor.global_scope()
 
-        self.executor = core.ParallelExecutor(
-            places, persistable_vars,
-            cpt.to_text(loss_name) if loss_name else six.u(''), scope,
-            local_scopes, exec_strategy, build_strategy, self._graph)
+        main_program = main_program if main_program is not None \
+            else framework.default_main_program()
 
-        self.scope = scope
+        self._compiled_program = compiler.CompiledProgram(main_program)
+        self._compiled_program.with_data_parallel(
+            loss_name=loss_name,
+            build_strategy=build_strategy,
+            exec_strategy=exec_strategy,
+            share_vars_from=share_vars_from)
+        self._place = core.CUDAPlace(0) if use_cuda else core.CPUPlace()
+        self._executor = executor.Executor(self._place)
+        self._compiled_program._compile(place=self._place, scope=self._scope)
 
     def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=True):
         """
@@ -256,56 +180,11 @@ class ParallelExecutor(object):
                loss = pe.run(feed=feeder.feed(cur_batch),
                              fetch_list=[avg_cost.name]))
         """
-        if feed is None and feed_dict is not None:
-            feed = feed_dict
-            print(
-                "`feed_dict` is deprecated. Please use `feed=`",
-                file=sys.stderr)
-
-        if isinstance(feed, dict):
-            feed_tensor_dict = dict()
-            for feed_name in feed:
-                feed_tensor = feed[feed_name]
-                if not isinstance(feed_tensor, core.LoDTensor):
-                    feed_tensor = core.LoDTensor()
-                    # always set to CPU place, since the tensor need to be splitted
-                    # it is fast in CPU
-                    feed_tensor.set(feed[feed_name], core.CPUPlace())
-                feed_tensor_dict[feed_name] = feed_tensor
-
-            self.executor.feed_and_split_tensor_into_local_scopes(
-                feed_tensor_dict)
-        elif isinstance(feed, list) or isinstance(feed, tuple):
-            if len(feed) != len(self._places):
-                raise ValueError(
-                    "Feed a list of tensor, the list should be the same size as places"
-                )
-
-            res = list()
-
-            for i, each in enumerate(feed):
-                if not isinstance(each, dict):
-                    raise TypeError(
-                        "Each element of feed list should be a dict")
-                res_dict = dict()
-                for feed_name in each:
-                    tensor = each[feed_name]
-                    if not isinstance(tensor, core.LoDTensor):
-                        tmp = core.LoDTensor()
-                        tmp.set(tensor, self._places[i])
-                        tensor = tmp
-                    res_dict[feed_name] = tensor
-                res.append(res_dict)
-            self.executor.feed_tensors_into_local_scopes(res)
-
-        fetch_var_name = 'fetch'
-        self.executor.run(fetch_list, fetch_var_name)
-        arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array()
-
-        if return_numpy:
-            return executor.as_numpy(arr)
-
-        return [arr[i] for i in range(len(arr))]
+        return self._executor.run(program=self._compiled_program,
+                                  scope=self._scope,
+                                  feed=feed,
+                                  fetch_list=fetch_list,
+                                  return_numpy=return_numpy)
 
     @property
     def device_count(self):
-- 
GitLab
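
Context for the change above: after this patch, ParallelExecutor is a thin wrapper that builds a
CompiledProgram internally and routes run() through the ordinary Executor, which is the path the
deprecation notice in __init__ recommends. Below is a minimal usage sketch of that recommended
path, assuming the Fluid 1.3-era Python API; the toy network, batch size, and CPU_NUM value are
illustrative only and are not taken from the patch.

    import os
    os.environ.setdefault('CPU_NUM', '2')  # number of CPU places to replicate onto (illustrative)

    import numpy as np
    import paddle.fluid as fluid
    from paddle.fluid import compiler

    # Tiny regression network, for illustration only.
    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
    y = fluid.layers.data(name='y', shape=[1], dtype='float32')
    pred = fluid.layers.fc(input=x, size=1)
    loss = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))
    fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # The unified path: compile the program for data parallelism once,
    # then run it through the ordinary Executor.
    compiled = compiler.CompiledProgram(
        fluid.default_main_program()).with_data_parallel(loss_name=loss.name)

    feed = {
        'x': np.random.rand(8, 13).astype('float32'),
        'y': np.random.rand(8, 1).astype('float32'),
    }
    loss_val, = exe.run(compiled, feed=feed, fetch_list=[loss.name])
    print(loss_val)

ParallelExecutor itself keeps working after the patch, but as the rewritten __init__ and run()
show, it now only builds this same CompiledProgram internally and delegates to Executor.run.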