From fe7ed285d131ba99e82538e76cb7ac5381e97809 Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Wed, 21 Mar 2018 14:49:02 +0800
Subject: [PATCH] Extract NCCLCtxMap

---
 paddle/fluid/framework/CMakeLists.txt         |   2 +-
 paddle/fluid/framework/details/CMakeLists.txt |   1 +
 .../fluid/framework/details/op_handle_base.cc |  84 +++++++++++++
 .../fluid/framework/details/op_handle_base.h  |  48 ++++++++
 paddle/fluid/framework/details/var_handle.h   |   4 +-
 paddle/fluid/framework/parallel_executor.cc   | 114 +++---------------
 paddle/fluid/platform/nccl_helper.h           |  46 +++++++
 7 files changed, 196 insertions(+), 103 deletions(-)
 create mode 100644 paddle/fluid/framework/details/op_handle_base.cc
 create mode 100644 paddle/fluid/framework/details/op_handle_base.h

diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index 9d2dc290282..afc7ec9d663 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -88,7 +88,7 @@ cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glo
 cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward glog lod_rank_table feed_fetch_method)
 
 cc_library(parallel_executor SRCS parallel_executor.cc DEPS op_registry device_context scope
-           framework_proto backward glog lod_rank_table feed_fetch_method executor simple_threadpool var_handle)
+           framework_proto backward glog lod_rank_table simple_threadpool var_handle op_handle_base)
 
 cc_library(prune SRCS prune.cc DEPS framework_proto)
 cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt
index 5074715e2ef..d9bdf0b94d6 100644
--- a/paddle/fluid/framework/details/CMakeLists.txt
+++ b/paddle/fluid/framework/details/CMakeLists.txt
@@ -1 +1,2 @@
 cc_library(var_handle SRCS var_handle.cc DEPS place)
+cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context)
diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc
new file mode 100644
index 00000000000..094b62cc945
--- /dev/null
+++ b/paddle/fluid/framework/details/op_handle_base.cc
@@ -0,0 +1,84 @@
+//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/fluid/framework/details/op_handle_base.h" + +namespace paddle { +namespace framework { +namespace details { +std::string OpHandleBase::DebugString() const { + std::stringstream ss; + ss << "("; + for (auto *var : inputs_) { + ss << var->DebugString() << ", "; + } + ss << ") --> ("; + for (auto *var : outputs_) { + ss << var->DebugString() << ", "; + } + ss << ")\n"; + return ss.str(); +} + +OpHandleBase::~OpHandleBase() {} + +void OpHandleBase::Run(bool use_event) { +#ifdef PADDLE_WITH_CUDA + if (events_.empty() && use_event) { + for (auto &p : dev_ctx_) { + int dev_id = boost::get(p.first).device; + cudaSetDevice(dev_id); + cudaEventCreateWithFlags(&events_[dev_id], cudaEventDisableTiming); + } + } +#else + PADDLE_ENFORCE(!use_event); +#endif + + RunImpl(); + +#ifdef PADDLE_WITH_CUDA + if (use_event) { + for (auto &p : dev_ctx_) { + int dev_id = boost::get(p.first).device; + auto stream = + static_cast(p.second)->stream(); + cudaEventRecord(events_.at(dev_id), stream); + } + } +#endif +} + +void OpHandleBase::Wait(platform::DeviceContext *waited_dev) { +#ifdef PADDLE_WITH_CUDA + if (platform::is_cpu_place(waited_dev->GetPlace()) || events_.empty()) { + for (auto &dev_ctx : dev_ctx_) { + dev_ctx.second->Wait(); + } + } else { + auto stream = + static_cast(waited_dev)->stream(); + for (auto &ev : events_) { + PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev.second, 0)); + } + } +#else + for (auto &dev_ctx : dev_ctx_) { + dev_ctx.second->Wait(); + } +#endif +} +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h new file mode 100644 index 00000000000..bdfd1f78ad8 --- /dev/null +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -0,0 +1,48 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/details/var_handle.h" +#include "paddle/fluid/platform/device_context.h" +namespace paddle { +namespace framework { +namespace details { + +struct OpHandleBase { + std::vector inputs_; + std::vector outputs_; + std::unordered_map + dev_ctx_; + +#ifdef PADDLE_WITH_CUDA + std::unordered_map events_; +#endif + + std::string DebugString() const; + + virtual ~OpHandleBase(); + + void Run(bool use_event); + + virtual void Wait(platform::DeviceContext *waited_dev); + + protected: + virtual void RunImpl() = 0; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/var_handle.h b/paddle/fluid/framework/details/var_handle.h index 613ff901b15..893cc15f6c8 100644 --- a/paddle/fluid/framework/details/var_handle.h +++ b/paddle/fluid/framework/details/var_handle.h @@ -21,10 +21,8 @@ namespace paddle { namespace framework { - -struct OpHandleBase; - namespace details { +struct OpHandleBase; // VarHandleBase is the var node in the dependency graph. 
 // A variable can only be generated by a single operator. i.e.
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 2b094eba1e1..3c24fa4bdf6 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -14,86 +14,22 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/parallel_executor.h"
 #include "ThreadPool.h"
-#include "executor.h"
 #include "lod_tensor.h"
 #include "lod_tensor_array.h"
 #include "op_registry.h"
+#include "paddle/fluid/framework/details/op_handle_base.h"
 #include "paddle/fluid/framework/details/var_handle.h"
 #include "paddle/fluid/framework/feed_fetch_type.h"
-#include "paddle/fluid/operators/math/concat.h"
 #include "paddle/fluid/platform/nccl_helper.h"
 
 namespace paddle {
 namespace framework {
 
 using details::DummyVarHandle;
+using details::OpHandleBase;
 using details::VarHandle;
 using details::VarHandleBase;
 
-struct OpHandleBase {
-  std::vector<VarHandleBase *> inputs_;
-  std::vector<VarHandleBase *> outputs_;
-  std::unordered_map<platform::Place, platform::DeviceContext *,
-                     platform::PlaceHash>
-      dev_ctx_;
-
-  std::unordered_map<int, cudaEvent_t> events_;
-
-  std::string DebugString() {
-    std::stringstream ss;
-    ss << "(";
-    for (auto *var : inputs_) {
-      ss << var->DebugString() << ", ";
-    }
-    ss << ") --> (";
-    for (auto *var : outputs_) {
-      ss << var->DebugString() << ", ";
-    }
-    ss << ")\n";
-    return ss.str();
-  }
-
-  virtual ~OpHandleBase() {}
-
-  void Run(bool use_event) {
-    if (events_.empty() && use_event) {
-      for (auto &p : dev_ctx_) {
-        int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
-        cudaSetDevice(dev_id);
-        cudaEventCreateWithFlags(&events_[dev_id], cudaEventDisableTiming);
-      }
-    }
-
-    RunImpl();
-
-    if (use_event) {
-      for (auto &p : dev_ctx_) {
-        int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
-        auto stream =
-            static_cast<platform::CUDADeviceContext *>(p.second)->stream();
-        cudaEventRecord(events_.at(dev_id), stream);
-      }
-    }
-  }
-
-  virtual void Wait(platform::DeviceContext *waited_dev) {
-    if (platform::is_cpu_place(waited_dev->GetPlace()) || events_.empty()) {
-      for (auto &dev_ctx : dev_ctx_) {
-        dev_ctx.second->Wait();
-      }
-    } else {
-      auto stream =
-          static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
-      for (auto &ev : events_) {
-        PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev.second, 0));
-      }
-    }
-  }
-
- protected:
-  virtual void RunImpl() = 0;
-};
-
 struct ScaleLossGradOpHandle : public OpHandleBase {
   float coeff_;
   Scope *scope_;
@@ -193,12 +129,7 @@ class ParallelExecutorPrivate {
   std::vector<Scope *> local_scopes_;
   Scope *global_scope_;
 
-  std::unordered_map<int, platform::NCCLContext> communication_streams_;
-
-  platform::NCCLContext &GetNCCLCtx(platform::Place p) {
-    int dev_id = boost::get<platform::CUDAPlace>(p).device;
-    return communication_streams_.at(dev_id);
-  }
+  std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
 
   platform::DeviceContext *CommunicationDevCtx(const platform::Place &place) {
     if (platform::is_cpu_place(place) || local_scopes_.size() == 1) {
@@ -206,7 +137,7 @@ class ParallelExecutorPrivate {
           platform::DeviceContextPool::Instance().Get(place));
     } else {
 #ifdef PADDLE_WITH_CUDA
-      return GetNCCLCtx(place).ctx_.get();
+      return nccl_ctxs_->DevCtx(place);
 #else
       PADDLE_THROW("Not compiled with CUDA")
 #endif
@@ -293,15 +224,12 @@ class ParallelExecutorPrivate {
 struct NCCLAllReduceOpHandle : public OpHandleBase {
   const std::vector<Scope *> &local_scopes_;
   const std::vector<platform::Place> &places_;
-  const std::unordered_map<int, platform::NCCLContext> &communication_ctxs_;
+  const platform::NCCLContextMap &nccl_ctxs_;
 
-  explicit NCCLAllReduceOpHandle(
-      const std::vector<Scope *> &local_scopes,
-      const std::vector<platform::Place> &places,
-      const std::unordered_map<int, platform::NCCLContext> &ctxs)
-      : local_scopes_(local_scopes),
-        places_(places),
-        communication_ctxs_(ctxs) {}
+  explicit NCCLAllReduceOpHandle(const std::vector<Scope *> &local_scopes,
+                                 const std::vector<platform::Place> &places,
+                                 const platform::NCCLContextMap &ctxs)
+      : local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) {}
 
   void Wait(platform::DeviceContext *waited_dev) override {
     OpHandleBase::Wait(waited_dev);
@@ -343,7 +271,7 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
         if (numel == 0) {
           numel = static_cast<size_t>(lod_tensor.numel());
         }
-        auto &nccl_ctx = communication_ctxs_.at(dev_id);
+        auto &nccl_ctx = nccl_ctxs_.at(dev_id);
         PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
             buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
             nccl_ctx.comm_, nccl_ctx.stream()));
@@ -491,8 +419,7 @@ void ParallelExecutor::ConstructDependencyGraph(
       if (grads.count(og) != 0) {  // is param grad
         // Insert NCCL AllReduce Op
         member_->ops_.emplace_back(new NCCLAllReduceOpHandle(
-            member_->local_scopes_, member_->places_,
-            member_->communication_streams_));
+            member_->local_scopes_, member_->places_, *member_->nccl_ctxs_));
         auto *op_handle = member_->ops_.back().get();
 
         for (size_t i = 0; i < member_->places_.size(); ++i) {
@@ -598,15 +525,12 @@ void ParallelExecutor::BCastParamsToGPUs(
           buffer = t->mutable_data(place, main_tensor.type());
         }
 
-        auto &nccl_ctx = member_->GetNCCLCtx(place);
+        auto &nccl_ctx = member_->nccl_ctxs_->at(place);
         platform::dynload::ncclBcast(buffer, numel, data_type, 0,
                                      nccl_ctx.comm_, nccl_ctx.stream());
       }
     }
-
-    for (auto &stream : member_->communication_streams_) {
-      stream.second.ctx_->Wait();
-    }
+    member_->nccl_ctxs_->WaitAll();
   }
 #else
   PADDLE_THROW("Not compiled with CUDA");
@@ -615,15 +539,7 @@ void ParallelExecutor::BCastParamsToGPUs(
 
 void ParallelExecutor::BuildNCCLCommunicator() const {
 #ifdef PADDLE_WITH_CUDA
-  for (auto &place : member_->places_) {
-    int dev_id = boost::get<platform::CUDAPlace>(place).device;
-
-    member_->communication_streams_.emplace(dev_id,
-                                            platform::NCCLContext(dev_id));
-  }
-
-  platform::NCCLContext::InitNCCLContext(member_->communication_streams_,
-                                         member_->places_);
+  member_->nccl_ctxs_.reset(new platform::NCCLContextMap(member_->places_));
 #endif
 }
 
@@ -682,7 +598,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
     op->offset_ = i;
     op->local_scopes_ = &member_->local_scopes_;
     for (auto &p : member_->places_) {
-      op->dev_ctx_[p] = member_->GetNCCLCtx(p).ctx_.get();
+      op->dev_ctx_[p] = member_->nccl_ctxs_->DevCtx(p);
     }
 
     for (auto *var : vars) {
diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h
index 3db846b0247..29990043206 100644
--- a/paddle/fluid/platform/nccl_helper.h
+++ b/paddle/fluid/platform/nccl_helper.h
@@ -87,5 +87,51 @@ struct NCCLContext {
   }
 };
 
+struct NCCLContextMap {
+  std::unordered_map<int, NCCLContext> contexts_;
+  std::vector<int> order_;
+
+  explicit NCCLContextMap(const std::vector<platform::Place> &places) {
+    order_.reserve(places.size());
+    for (auto &p : places) {
+      int dev_id = boost::get<CUDAPlace>(p).device;
+      order_.emplace_back(dev_id);
+      contexts_.emplace(dev_id, NCCLContext(dev_id));
+    }
+    PADDLE_ENFORCE_EQ(
+        order_.size(), contexts_.size(),
+        "NCCLContextMap does not support two or more contexts on one device");
+
+    std::vector<ncclComm_t> comms;
+    comms.resize(order_.size());
+
+    PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
+        &comms[0], static_cast<int>(order_.size()), &order_[0]));
+
+    int i = 0;
+    for (auto &dev_id : order_) {
+      contexts_.at(dev_id).comm_ = comms[i++];
+    }
+  }
+
+  CUDADeviceContext *DevCtx(int dev_id) const { return at(dev_id).ctx_.get(); }
+
+  CUDADeviceContext *DevCtx(platform::Place p) const {
+    return DevCtx(boost::get<CUDAPlace>(p).device);
+  }
+
+  const NCCLContext &at(platform::Place p) const {
+    return this->at(boost::get<CUDAPlace>(p).device);
+  }
+
+  const NCCLContext &at(int dev_id) const { return contexts_.at(dev_id); }
+
+  void WaitAll() {
+    for (auto &p : contexts_) {
+      p.second.ctx_->Wait();
+    }
+  }
+};
+
 }  // namespace platform
 }  // namespace paddle
--
GitLab
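
Note on the extracted interface: the sketch below is not part of the patch; it shows how an operator handle is expected to plug into the new details::OpHandleBase. The subclass name NoOpHandle is hypothetical, and the comments restate the Run/Wait contract implemented in op_handle_base.cc above.

```cpp
// Hypothetical subclass, for illustration only -- not part of this patch.
#include "paddle/fluid/framework/details/op_handle_base.h"

namespace paddle {
namespace framework {
namespace details {

struct NoOpHandle : public OpHandleBase {
 protected:
  // RunImpl() is the only pure-virtual member. When Run(true) is called,
  // the base class records a cudaEvent_t on every registered device
  // context after RunImpl() returns, so a downstream op's Wait() can make
  // its stream wait on those events instead of blocking the CPU.
  void RunImpl() override {
    // Launch device work on the streams held in dev_ctx_ here.
  }
};

}  // namespace details
}  // namespace framework
}  // namespace paddle
```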
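
Likewise, a minimal sketch of the call pattern for the new platform::NCCLContextMap, inferred from the parallel_executor.cc hunks above; the free function BroadcastExample and the assumption that `places` holds only CUDAPlace entries are illustrative, not from the patch.

```cpp
// Illustrative usage only -- mirrors how ParallelExecutor drives the map.
#include <vector>

#include "paddle/fluid/platform/nccl_helper.h"

void BroadcastExample(const std::vector<paddle::platform::Place> &places) {
  // One constructor replaces the old per-device NCCLContext map plus
  // NCCLContext::InitNCCLContext(): it creates a context per device and
  // initializes every communicator with a single ncclCommInitAll call.
  paddle::platform::NCCLContextMap nccl_ctxs(places);

  // Contexts can be looked up by place or by device id; each NCCLContext
  // carries its communicator (comm_) and CUDA stream (stream()).
  auto &ctx = nccl_ctxs.at(places[0]);
  // The owned CUDADeviceContext is exposed for wiring into op handles,
  // as ParallelExecutor::Run now does via nccl_ctxs_->DevCtx(p).
  auto *dev_ctx = nccl_ctxs.DevCtx(places[0]);
  (void)ctx;
  (void)dev_ctx;

  // Drain all communication streams, as BCastParamsToGPUs now does
  // through member_->nccl_ctxs_->WaitAll().
  nccl_ctxs.WaitAll();
}
```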