From b09ba8a796b2ffa31bd9be228ac40373d00be00b Mon Sep 17 00:00:00 2001 From: chengduo Date: Thu, 27 Jun 2019 13:14:42 +0800 Subject: [PATCH] Cherry pick Fix Bug-prone code of PE (#18355) * update pe reduce config test=release/1.5 * drop the local_exe_scopes of the previous parallel_executor test=release/1.5 --- paddle/fluid/framework/parallel_executor.cc | 64 +++++++++++---------- python/paddle/fluid/compiler.py | 2 + 2 files changed, 35 insertions(+), 31 deletions(-) diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index ac4e346959..6fe04cef2d 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -311,8 +311,8 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, member_->global_scope_ = scope; member_->use_cuda_ = exec_strategy.use_cuda_; member_->build_strategy_ = build_strategy; - member_->use_all_reduce_ = - build_strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce; + member_->use_all_reduce_ = member_->build_strategy_.reduce_ == + BuildStrategy::ReduceStrategy::kAllReduce; member_->nranks_ = build_strategy.num_trainers_ * places.size(); if (!member_->use_all_reduce_ && member_->nranks_ == 1) { LOG(INFO) << "If you set build_strategy.reduce with 'Reduce'," @@ -348,7 +348,7 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, } std::vector graphs; - if (build_strategy.async_mode_) { + if (member_->build_strategy_.async_mode_) { PADDLE_ENFORCE(!member_->use_cuda_, "gpu mode does not support async_mode_ now!"); graphs.push_back(graph); @@ -362,9 +362,10 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, // FIXME(Yancey1989): parallel graph mode get better performance // in GPU allreduce distributed training. Need an elegant way to // choice the execution strategy. - build_strategy.enable_parallel_graph_ = - EnableParallelGraphExecution(*graph, exec_strategy, build_strategy); - if (build_strategy.enable_parallel_graph_) { + member_->build_strategy_.enable_parallel_graph_ = + EnableParallelGraphExecution(*graph, exec_strategy, + member_->build_strategy_); + if (member_->build_strategy_.enable_parallel_graph_) { LOG(INFO) << "The Executor would execute the graph by ParallelGraph " "Execution which can get better performance," << "you can force it off by env FLAGS_enable_parallel_graph=0"; @@ -372,7 +373,7 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, if (member_->use_cuda_ && member_->nranks_ > 1) { #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) - member_->InitOrGetNCCLCommunicator(scope, build_strategy); + member_->InitOrGetNCCLCommunicator(scope, member_->build_strategy_); // Initialize device context's nccl comm, will be used by normal // Operators like sync_batch_norm, and collective ops. @@ -395,7 +396,7 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, } // broadcast parameters from the 0th device to others: auto need_broadcast = [&]() -> bool { - if (build_strategy.num_trainers_ > 1) { + if (member_->build_strategy_.num_trainers_ > 1) { // 1. num_tariners would be grater than 1 for nccl distributed training. return true; } else if (member_->local_scopes_.size() != 1 && local_scopes.empty()) { @@ -407,7 +408,7 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, }; // Bcast Parameters to all GPUs if (need_broadcast()) { - BCastParamsToDevices(bcast_vars, build_strategy.trainer_id_); + BCastParamsToDevices(bcast_vars, member_->build_strategy_.trainer_id_); } // Startup Program has been run. All local scopes has correct parameters. @@ -416,39 +417,40 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, // ncclOp std::vector async_graphs(places.size()); #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) - if (build_strategy.async_mode_) { + if (member_->build_strategy_.async_mode_) { VLOG(3) << "use local async mode"; - graph = build_strategy.Apply(graph, {member_->places_[0]}, loss_var_name, - {member_->local_scopes_[0]}, 1, - member_->use_cuda_, member_->nccl_ctxs_); + graph = member_->build_strategy_.Apply( + graph, {member_->places_[0]}, loss_var_name, + {member_->local_scopes_[0]}, 1, member_->use_cuda_, + member_->nccl_ctxs_); for (size_t i = 1; i < member_->places_.size(); ++i) { - graphs[i] = - build_strategy.Apply(graphs[i], {member_->places_[i]}, loss_var_name, - {member_->local_scopes_[i]}, 1, - member_->use_cuda_, member_->nccl_ctxs_); + graphs[i] = member_->build_strategy_.Apply( + graphs[i], {member_->places_[i]}, loss_var_name, + {member_->local_scopes_[i]}, 1, member_->use_cuda_, + member_->nccl_ctxs_); async_graphs[i] = graphs[i]; } } else { - graph = build_strategy.Apply(graph, member_->places_, loss_var_name, - member_->local_scopes_, member_->nranks_, - member_->use_cuda_, member_->nccl_ctxs_); + graph = member_->build_strategy_.Apply( + graph, member_->places_, loss_var_name, member_->local_scopes_, + member_->nranks_, member_->use_cuda_, member_->nccl_ctxs_); } #else - if (build_strategy.async_mode_) { + if (member_->build_strategy_.async_mode_) { VLOG(3) << "use local async mode"; - graph = build_strategy.Apply(graph, {member_->places_[0]}, loss_var_name, - {member_->local_scopes_[0]}, 1, - member_->use_cuda_); + graph = member_->build_strategy_.Apply( + graph, {member_->places_[0]}, loss_var_name, + {member_->local_scopes_[0]}, 1, member_->use_cuda_); for (size_t i = 1; i < member_->places_.size(); ++i) { - graphs[i] = build_strategy.Apply( + graphs[i] = member_->build_strategy_.Apply( graphs[i], {member_->places_[i]}, loss_var_name, {member_->local_scopes_[i]}, 1, member_->use_cuda_); async_graphs[i] = graphs[i]; } } else { - graph = build_strategy.Apply(graph, member_->places_, loss_var_name, - member_->local_scopes_, member_->nranks_, - member_->use_cuda_); + graph = member_->build_strategy_.Apply( + graph, member_->places_, loss_var_name, member_->local_scopes_, + member_->nranks_, member_->use_cuda_); } #endif @@ -489,11 +491,11 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, } } - if (build_strategy.async_mode_) { + if (member_->build_strategy_.async_mode_) { VLOG(3) << "use AsyncSSAGraphExecutor"; member_->executor_.reset(new details::AsyncSSAGraphExecutor( exec_strategy, member_->local_scopes_, member_->places_, async_graphs)); - } else if (build_strategy.enable_parallel_graph_) { + } else if (member_->build_strategy_.enable_parallel_graph_) { VLOG(3) << "use ParallelSSAGraphExecutor"; #ifdef PADDLE_WITH_CUDA // TODO(Yancey1989): Remove passing in the main_program when @@ -517,7 +519,7 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, } VLOG(3) << "use ScopeBufferedSSAGraphExecutor"; - if (!build_strategy.async_mode_) { + if (!member_->build_strategy_.async_mode_) { member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( exec_strategy, member_->local_scopes_, std::move(var_infos), member_->places_, std::move(member_->executor_))); diff --git a/python/paddle/fluid/compiler.py b/python/paddle/fluid/compiler.py index 87a6ce0881..7ffecb69a9 100644 --- a/python/paddle/fluid/compiler.py +++ b/python/paddle/fluid/compiler.py @@ -274,6 +274,8 @@ class CompiledProgram(object): "share_vars_from is not compiled and run, so there is no " "var to share.") self._local_scopes = self._share_vars_from._executor.local_scopes() + # drop the local_exe_scopes of the previous parallel_executor + self._share_vars_from._executor.drop_local_exe_scopes() else: assert scope is not None, "" self._local_scopes = [] -- GitLab