From 35cda13e9fd65cb2f41c5e7e58fe513c19a84f5b Mon Sep 17 00:00:00 2001
From: Yancey1989
Date: Sat, 29 Dec 2018 17:09:28 +0800
Subject: [PATCH] fix unittest test=develop

---
 .../details/parallel_ssa_graph_executor.cc    |  8 +++-
 paddle/fluid/framework/parallel_executor.cc   | 42 +++++++++----------
 paddle/fluid/framework/parallel_executor.h    |  7 ++++
 paddle/fluid/pybind/pybind.cc                 |  2 +-
 python/paddle/fluid/parallel_executor.py      |  2 +-
 5 files changed, 37 insertions(+), 24 deletions(-)

diff --git a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc
index 2377f2c963d..bb1f415128e 100644
--- a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc
@@ -28,7 +28,13 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
       places_(std::move(places)),
       graphs_(std::move(graphs)) {
   PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
-  // do not use threadpool for each graph execution.
+
+  // set the correct size of the thread pool for each device.
+  strategy_.num_threads_ = strategy_.num_threads_ < places_.size()
+                               ? 1UL
+                               : strategy_.num_threads_ / places_.size();
+  VLOG(1) << "set num_threads: " << strategy_.num_threads_
+          << " to schedule operators on each device.";
   for (size_t i = 0; i < places.size(); ++i) {
     executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
         strategy_, {local_scopes_[i]}, {places_[i]}, std::move(graphs_[i])));
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 934cf34cbd4..176c1db349c 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -21,10 +21,6 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/ir/graph.h"
 
-#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-#include "paddle/fluid/platform/nccl_helper.h"
-#endif
-
 #include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
 #include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
@@ -39,6 +35,8 @@ limitations under the License. */
 DEFINE_string(pe_profile_fname, "",
               "Profiler filename for PE, which generated by gperftools."
               "Only valid when compiled `WITH_PRIFILER=ON`. Empty if disable.");
+DEFINE_bool(enable_parallel_graph, true,
+            "Force disable parallel graph execution mode if set to false.");
 
 namespace paddle {
 namespace framework {
@@ -211,15 +209,6 @@ ParallelExecutor::ParallelExecutor(
         "the number of places must be greater than 1.");
   }
 
-  // FIXME(Yancey1989): parallel graph mode get better performance
-  // in GPU allreduce distributed training. Need an elegant way to
-  // choice the execution strategy.
-  build_strategy.enable_parallel_graph_ =
-      EnableParallelGraphExecution(main_program, exec_strategy, build_strategy);
-
-  VLOG(1) << "Enable ParallelGraph Execution: "
-          << build_strategy.enable_parallel_graph_;
-
   // Step 1. Bcast the bcast_vars to devs.
   // Create local scopes
   if (local_scopes.empty()) {
@@ -236,24 +225,35 @@ ParallelExecutor::ParallelExecutor(
     }
   }
 
+  // FIXME(Yancey1989): parallel graph mode gets better performance
+  // in GPU allreduce distributed training. Need an elegant way to
+  // choose the execution strategy.
+  build_strategy.enable_parallel_graph_ =
+      EnableParallelGraphExecution(main_program, exec_strategy, build_strategy);
+
+  VLOG(1) << "Enable ParallelGraph Execution: "
+          << build_strategy.enable_parallel_graph_;
+
   if (member_->use_cuda_) {
     // Bcast Parameters to all GPUs
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+    ncclUniqueId *nccl_id = nullptr;
+    // gen_nccl_id operator can broadcast the ncclUniqueId for nccl2 collective
+    // distributed training
     auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
-    std::unique_ptr<ncclUniqueId> nccl_id;
-    // nccl collective would broadcast ncclUniqueId by gen_nccl_id operator.
     if (nccl_id_var != nullptr) {
-      nccl_id.reset(nccl_id_var->GetMutable<ncclUniqueId>());
+      nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
     }
     if (build_strategy.enable_parallel_graph_ && member_->nranks_ > 1UL) {
-      if (nccl_id.get() == nullptr) {
-        nccl_id.reset(new ncclUniqueId());
-        platform::dynload::ncclGetUniqueId(nccl_id.get());
+      if (nccl_id == nullptr) {
+        local_nccl_id_.reset(new ncclUniqueId());
+        platform::dynload::ncclGetUniqueId(local_nccl_id_.get());
+        nccl_id = local_nccl_id_.get();
       }
     }
 
     member_->nccl_ctxs_.reset(new platform::NCCLContextMap(
-        member_->places_, nccl_id.get(), num_trainers, trainer_id));
+        member_->places_, nccl_id, num_trainers, trainer_id));
 #else
     PADDLE_THROW("Not compiled with CUDA");
 #endif
@@ -492,7 +492,7 @@ bool ParallelExecutor::EnableParallelGraphExecution(
   if (build_strategy.enable_sequential_execution_ ||
       exec_strategy.type_ == ExecutionStrategy::ExecutorType::kExperimental)
     enable_parallel_graph = false;
-  return enable_parallel_graph;
+  return enable_parallel_graph && FLAGS_enable_parallel_graph;
 }
 
 ParallelExecutor::~ParallelExecutor() {
diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h
index dc70894dbdb..49d3f0d3f6f 100644
--- a/paddle/fluid/framework/parallel_executor.h
+++ b/paddle/fluid/framework/parallel_executor.h
@@ -28,6 +28,10 @@ limitations under the License. */
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/platform/device_context.h"
 
+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+#include "paddle/fluid/platform/nccl_helper.h"
+#endif
+
 namespace paddle {
 namespace framework {
 
@@ -73,6 +77,9 @@ class ParallelExecutor {
       const BuildStrategy &build_strategy) const;
 
   ParallelExecutorPrivate *member_;
+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+  std::unique_ptr<ncclUniqueId> local_nccl_id_;
+#endif
 };
 
 }  // namespace framework
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index d664107d570..1473603a747 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -810,7 +810,7 @@ All parameter, weight, gradient are variables in Paddle.
             If :math:`num\_threads=1`, all the operators will execute one by one,
             but the order maybe difference between iterations.
             If it is not set, it will be set in ParallelExecutor according to the
-            device type and device count, for GPU, :math:`num\_threads=device\_count*4`, for CPU,
+            device type and device count, for GPU, :math:`num\_threads=device\_count`, for CPU,
             :math:`num\_threads=CPU\_NUM*4`, the explanation of:math:`CPU\_NUM` is in ParallelExecutor.
             if it is not set, ParallelExecutor will get the cpu count by calling
             `multiprocessing.cpu_count()`. Default 0.)DOC")
diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py
index c97a93ec36d..97099612868 100644
--- a/python/paddle/fluid/parallel_executor.py
+++ b/python/paddle/fluid/parallel_executor.py
@@ -117,7 +117,7 @@ class ParallelExecutor(object):
         if use_cuda:
             # Experiments on se-resnext shows that too many threads hurt
             # performance. Worth tunning for other models in the future.
-            exec_strategy.num_threads = len(self._places) * 4
+            exec_strategy.num_threads = len(self._places)
         else:
             cpu_num = int(
                 os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
--
GitLab
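
For reference, a minimal standalone sketch of the per-device thread-pool split that the first hunk introduces. ThreadsPerDevice is a hypothetical helper written only for illustration, not a PaddlePaddle API; in the patch the same expression lives inline in ParallelSSAGraphExecutor's constructor.

#include <cstddef>
#include <iostream>

// Hypothetical helper mirroring the patched constructor logic: divide the
// executor's scheduling threads evenly across devices, falling back to one
// thread per device when there are fewer threads than devices.
std::size_t ThreadsPerDevice(std::size_t num_threads, std::size_t num_places) {
  return num_threads < num_places ? 1UL : num_threads / num_places;
}

int main() {
  std::cout << ThreadsPerDevice(8, 4) << "\n";  // 8 threads over 4 GPUs -> 2
  std::cout << ThreadsPerDevice(2, 4) << "\n";  // fewer threads than GPUs -> 1
  return 0;
}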