diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 207dd1b93a0efbea95673793fe441b83a8b18f2a..2258f2168cc6c7a40df5dcd8a23626d7f89f32c5 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -12,16 +12,16 @@ cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder) cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows) if(WITH_GPU) - nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory + nv_library(nccl_all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory dynload_cuda variable_visitor) set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle) nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda) nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda) else() - cc_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory + cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory variable_visitor) - set(multi_devices_graph_builder_deps) + set(multi_devices_graph_builder_deps all_reduce_op_handle) cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim) cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) endif() diff --git a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc b/paddle/fluid/framework/details/all_reduce_op_handle.cc similarity index 88% rename from paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc rename to paddle/fluid/framework/details/all_reduce_op_handle.cc index ab5dc6761370f50b4afc69eb2150f17582930b0e..aafb5bf07d9af92775f249c5f8e69330736c9696 100644 --- a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/all_reduce_op_handle.cc @@ -13,8 +13,8 @@ // limitations under the License. #include +#include "paddle/fluid/framework/details/all_reduce_op_handle.h" #include "paddle/fluid/framework/details/container_cast.h" -#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h" #include "paddle/fluid/framework/details/reduce_and_gather.h" #include "paddle/fluid/framework/details/variable_visitor.h" @@ -23,25 +23,23 @@ namespace framework { namespace details { #ifdef PADDLE_WITH_CUDA -NCCLAllReduceOpHandle::NCCLAllReduceOpHandle( - const std::vector &local_scopes, - const std::vector &places, - const platform::NCCLContextMap *ctxs) +AllReduceOpHandle::AllReduceOpHandle(const std::vector &local_scopes, + const std::vector &places, + const platform::NCCLContextMap *ctxs) : local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) { - if (ctxs) { + if (nccl_ctxs_) { for (auto &p : places_) { this->dev_ctxes_[p] = nccl_ctxs_->DevCtx(p); } } } #else -NCCLAllReduceOpHandle::NCCLAllReduceOpHandle( - const std::vector &local_scopes, - const std::vector &places) +AllReduceOpHandle::AllReduceOpHandle(const std::vector &local_scopes, + const std::vector &places) : local_scopes_(local_scopes), places_(places) {} #endif -void NCCLAllReduceOpHandle::RunImpl() { +void AllReduceOpHandle::RunImpl() { if (NoDummyInputSize() == 1) { return; // No need to all reduce when GPU count = 1; } else { @@ -133,7 +131,7 @@ void NCCLAllReduceOpHandle::RunImpl() { } } -std::string NCCLAllReduceOpHandle::Name() const { return "nccl_all_reduce"; } +std::string AllReduceOpHandle::Name() const { return "nccl_all_reduce"; } } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h b/paddle/fluid/framework/details/all_reduce_op_handle.h similarity index 79% rename from paddle/fluid/framework/details/nccl_all_reduce_op_handle.h rename to paddle/fluid/framework/details/all_reduce_op_handle.h index e0d206bd937adfca8a264d7ae3148ba3f1ae44d2..fdd250b0d3eb166249271a95f7592b9fadee5265 100644 --- a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h +++ b/paddle/fluid/framework/details/all_reduce_op_handle.h @@ -28,14 +28,14 @@ namespace paddle { namespace framework { namespace details { -struct NCCLAllReduceOpHandle : public OpHandleBase { +struct AllReduceOpHandle : public OpHandleBase { #ifdef PADDLE_WITH_CUDA - NCCLAllReduceOpHandle(const std::vector &local_scopes, - const std::vector &places, - const platform::NCCLContextMap *ctxs); + AllReduceOpHandle(const std::vector &local_scopes, + const std::vector &places, + const platform::NCCLContextMap *ctxs); #else - NCCLAllReduceOpHandle(const std::vector &local_scopes, - const std::vector &places); + AllReduceOpHandle(const std::vector &local_scopes, + const std::vector &places); #endif std::string Name() const override; diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 8a5f171ce5a0d77d4bec002535e5116035843dd4..35c5d0433fe231630c3f919c2b896f1e27729701 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -17,10 +17,10 @@ #include #include +#include "paddle/fluid/framework/details/all_reduce_op_handle.h" #include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/multi_devices_graph_builder.h" -#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h" #include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/rpc_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" @@ -283,6 +283,19 @@ bool MultiDevSSAGraphBuilder::IsSparseGradient( return false; } +void MultiDevSSAGraphBuilder::SetCommunicationContext( + OpHandleBase *op_handle, const platform::Place &p) const { +#ifdef PADDLE_WITH_CUDA + if (nccl_ctxs_ == nullptr) { + op_handle->SetDeviceContext(p, + platform::DeviceContextPool::Instance().Get(p)); + } +#else + op_handle->SetDeviceContext(p, + platform::DeviceContextPool::Instance().Get(p)); +#endif +} + void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result, const std::string &p_name, size_t src_dev_id) const { @@ -306,19 +319,6 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result, } } -void MultiDevSSAGraphBuilder::SetCommunicationContext( - OpHandleBase *op_handle, const platform::Place &p) const { -#ifdef PADDLE_WITH_CUDA - if (nccl_ctxs_ == nullptr) { - op_handle->SetDeviceContext(p, - platform::DeviceContextPool::Instance().Get(p)); - } -#else - op_handle->SetDeviceContext(p, - platform::DeviceContextPool::Instance().Get(p)); -#endif -} - void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result, const OpDesc &op, int dev_id) const { @@ -331,9 +331,9 @@ void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp( SSAGraph *result, const std::string &og) const { #ifdef PADDLE_WITH_CUDA result->ops_.emplace_back( - new NCCLAllReduceOpHandle(local_scopes_, places_, nccl_ctxs_)); + new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_)); #else - result->ops_.emplace_back(new NCCLAllReduceOpHandle(local_scopes_, places_)); + result->ops_.emplace_back(new AllReduceOpHandle(local_scopes_, places_)); #endif auto *op_handle = result->ops_.back().get();