From 35744e7b36f3c7202080feeabc0d8f207839b2e1 Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Thu, 15 Mar 2018 16:30:16 +0800
Subject: [PATCH] Polish code

---
 paddle/fluid/framework/parallel_executor.cc   | 100 ++++++++++++++----
 paddle/fluid/framework/parallel_executor.h    |   2 +
 .../tests/unittests/test_parallel_executor.py |   2 +-
 3 files changed, 82 insertions(+), 22 deletions(-)

diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index dd726f1fab0..7af5cc075c2 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -20,6 +20,12 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
+#ifdef PADDLE_WITH_CUDA
+
+// FIXME: CHECK the return value of x;
+#define NCCL_INVOKE(x) x
+#endif
+
 struct OpHandle;
 
 struct VarHandle {
@@ -71,9 +77,51 @@ class ParallelExecutorPrivate {
 
   std::unordered_map<platform::Place, Scope *> local_scopes_;
-  std::unordered_map<...>
-      dev_ctxs_;
+
+#ifdef PADDLE_WITH_CUDA
+  struct NCCLContext {
+    std::unique_ptr<platform::CUDADeviceContext> ctx_;
+    ncclComm_t comm;
+
+    explicit NCCLContext(int dev_id) {
+      ctx_.reset(new platform::CUDADeviceContext(platform::CUDAPlace(dev_id)));
+    }
+
+    cudaStream_t stream() const { return ctx_->stream(); }
+
+    int device_id() const {
+      return boost::get<platform::CUDAPlace>(ctx_->GetPlace()).device;
+    }
+
+    static void InitNCCLContext(std::map<int, NCCLContext> &contexts) {
+      std::vector<ncclComm_t> comms;
+      std::vector<int> devs;
+      comms.resize(contexts.size());
+      devs.reserve(contexts.size());
+
+      for (auto &ctx : contexts) {
+        devs.push_back(ctx.first);
+      }
+
+      NCCL_INVOKE(platform::dynload::ncclCommInitAll(
+          &comms[0], static_cast<int>(contexts.size()), &devs[0]));
+
+      int i = 0;
+      for (auto &ctx : contexts) {
+        ctx.second.comm = comms[i++];
+      }
+    }
+  };
+
+  std::map<int, NCCLContext> communication_streams_;
+
+  NCCLContext &GetNCCLCtx(platform::Place p) {
+    int dev_id = boost::get<platform::CUDAPlace>(p).device;
+    return communication_streams_.at(dev_id);
+  }
+
+#endif
+
   platform::Place main_place_;
 
   std::unordered_map<...
@@ ... @@
   member_->main_place_ = places[0];
 
   // Bcast Parameters to all GPUs
-  if (platform::is_gpu_place(member_->main_place_)) {  // Is CUDA
-    // BCastParamsToGPUs(startup_program);
+  if (platform::is_gpu_place(member_->main_place_) &&
+      member_->local_scopes_.size() != 1) {  // Is CUDA
+    BuildNCCLCommunicator();
+    BCastParamsToGPUs(startup_program);
   }
 
   // Startup Program has been run. All local scopes has correct parameters.
@@ -241,20 +291,20 @@ VarHandle *ParallelExecutor::GetVarHandle(const std::string &each_var_name,
 
 void ParallelExecutor::BCastParamsToGPUs(
     const ProgramDesc &startup_program) const {
+#ifdef PADDLE_WITH_CUDA
   auto *main_scope = member_->local_scopes_[member_->main_place_];
+
   for (auto *var_desc : startup_program.Block(0).AllVars()) {
     if (var_desc->GetType() == proto::VarType::LOD_TENSOR) {
       auto &main_tensor =
           main_scope->FindVar(var_desc->Name())->Get<LoDTensor>();
-
       ncclDataType_t data_type = ToNCCLDataType(main_tensor.type());
       auto &dims = main_tensor.dims();
       size_t numel = main_tensor.numel();
-      std::vector<std::pair<void *, platform::CUDADeviceContext *>> mems;
-      mems.emplace_back(
-          const_cast<void *>(main_tensor.data<void>()),
-          new platform::CUDADeviceContext(
-              boost::get<platform::CUDAPlace>(member_->main_place_)));
+      std::vector<std::pair<void *, ParallelExecutorPrivate::NCCLContext *>>
+          mems;
+      mems.emplace_back(const_cast<void *>(main_tensor.data<void>()),
+                        &member_->GetNCCLCtx(member_->main_place_));
 
       for (auto &pair : member_->local_scopes_) {
         if (pair.first == member_->main_place_) {
@@ -265,8 +315,7 @@ void ParallelExecutor::BCastParamsToGPUs(
         auto *t = local_scope->Var(var_desc->Name())->GetMutable<LoDTensor>();
         t->Resize(dims);
         mems.emplace_back(t->mutable_data(pair.first, main_tensor.type()),
-                          new platform::CUDADeviceContext(
-                              boost::get<platform::CUDAPlace>(pair.first)));
+                          &member_->GetNCCLCtx(member_->main_place_));
       }
 
       // TODO(yy): Invoke ncclBCast here. mems, numel, data_type. The mems[0]
@@ -274,17 +323,26 @@ void ParallelExecutor::BCastParamsToGPUs(
       (void)(data_type);
       (void)(numel);
+    }
+  }
+#else
+  PADDLE_THROW("Not compiled with CUDA");
+#endif
+}
 
-      // Free Communication Ctx
-      for (auto &pair : mems) {
-        // Release Communication Ctx
+void ParallelExecutor::BuildNCCLCommunicator() const {
+#ifdef PADDLE_WITH_CUDA
+  for (auto &place_pair : member_->local_scopes_) {
+    auto place = place_pair.first;
+    int dev_id = boost::get<platform::CUDAPlace>(place).device;
 
-        // FIXME: Store CUDA DevCtx to member. Since NCCL All Reduce will use
-        // this
-        delete pair.second;
-      }
-    }
+    member_->communication_streams_.emplace(
+        dev_id, ParallelExecutorPrivate::NCCLContext(dev_id));
   }
+
+  ParallelExecutorPrivate::NCCLContext::InitNCCLContext(
+      member_->communication_streams_);
+#endif
 }
 
 std::vector<LoDTensor> ParallelExecutor::Run(

diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h
index ec80f89f0e8..805b7e5aa9f 100644
--- a/paddle/fluid/framework/parallel_executor.h
+++ b/paddle/fluid/framework/parallel_executor.h
@@ -55,6 +55,8 @@ class ParallelExecutor {
   void ConstructDependencyGraph(const std::unordered_set<std::string>& params,
                                 const ProgramDesc& main_program,
                                 const std::string& loss_var_name) const;
+
+  void BuildNCCLCommunicator() const;
 };
 
 }  // namespace framework

diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor.py b/python/paddle/fluid/tests/unittests/test_parallel_executor.py
index 2b41b2c9b4e..65b43448a44 100644
--- a/python/paddle/fluid/tests/unittests/test_parallel_executor.py
+++ b/python/paddle/fluid/tests/unittests/test_parallel_executor.py
@@ -35,7 +35,7 @@ class ParallelExecutor(unittest.TestCase):
         adam = fluid.optimizer.Adam()
         adam.minimize(loss)
         act_places = []
-        for each in [fluid.CUDAPlace(0), fluid.CUDAPlace(1)]:
+        for each in [fluid.CUDAPlace(0)]:
             p = fluid.core.Place()
             p.set_place(each)
             act_places.append(p)
-- 
GitLab
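
The patch leaves two stubs open: NCCL_INVOKE ignores the ncclResult_t of the
call it wraps ("FIXME: CHECK the return value of x"), and BCastParamsToGPUs
collects the per-device buffers in mems but stops short of the collective
("TODO(yy): Invoke ncclBCast here"). Below is a minimal standalone sketch of
what both could look like. It is not part of the patch: it assumes NCCL 2.x
(for ncclGroupStart/ncclGroupEnd), and CHECK_NCCL and BroadcastToAll are
illustrative names, not Paddle APIs.

// Standalone sketch, not part of the patch above.
#include <cuda_runtime.h>
#include <nccl.h>

#include <cstdio>
#include <cstdlib>
#include <vector>

// One way to resolve "FIXME: CHECK the return value of x": fail fast with
// the NCCL error string instead of silently dropping the status.
#define CHECK_NCCL(x)                                                    \
  do {                                                                   \
    ncclResult_t r = (x);                                                \
    if (r != ncclSuccess) {                                              \
      fprintf(stderr, "NCCL error %s at %s:%d\n", ncclGetErrorString(r), \
              __FILE__, __LINE__);                                       \
      exit(EXIT_FAILURE);                                                \
    }                                                                    \
  } while (0)

// One way to resolve "TODO(yy): Invoke ncclBCast here": mems[0] is the
// source buffer on the root device and the rest are destination buffers,
// one per GPU. comms[i] and streams[i] belong to the i-th device,
// mirroring the patch's per-device NCCLContext.
void BroadcastToAll(const std::vector<void *> &mems, size_t numel,
                    ncclDataType_t data_type,
                    const std::vector<ncclComm_t> &comms,
                    const std::vector<cudaStream_t> &streams) {
  const int root = 0;  // rank 0 owns mems[0], i.e. main_place_
  // Group the per-device calls so a single host thread can issue the
  // collective for every GPU without deadlocking.
  CHECK_NCCL(ncclGroupStart());
  for (size_t i = 0; i < mems.size(); ++i) {
    CHECK_NCCL(
        ncclBcast(mems[i], numel, data_type, root, comms[i], streams[i]));
  }
  CHECK_NCCL(ncclGroupEnd());
}

With communicators created by ncclCommInitAll, as in InitNCCLContext above,
rank i corresponds to the i-th entry of the device list, so root = 0
broadcasts from the first device (main_place_) to every other GPU.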