From 7dcb217e3147642221b65fd20820010ebe78d316 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 27 Mar 2018 15:54:12 +0800 Subject: [PATCH] Refine allreduce op --- .../details/nccl_all_reduce_op_handle.cc | 18 ++++++++++++++---- paddle/fluid/platform/nccl_helper.h | 6 ++++-- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc index 116b13d33..f77a4b55a 100644 --- a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc @@ -41,7 +41,7 @@ void NCCLAllReduceOpHandle::RunImpl() { int dtype = -1; size_t numel = 0; - platform::NCCLGroupGuard guard; + std::vector> all_reduce_calls; for (size_t i = 0; i < local_scopes_.size(); ++i) { auto &p = places_[i]; @@ -58,10 +58,20 @@ void NCCLAllReduceOpHandle::RunImpl() { if (numel == 0) { numel = static_cast(lod_tensor.numel()); } + auto &nccl_ctx = nccl_ctxs_.at(dev_id); - PADDLE_ENFORCE(platform::dynload::ncclAllReduce( - buffer, buffer, numel, static_cast(dtype), ncclSum, - nccl_ctx.comm_, nccl_ctx.stream())); + auto stream = nccl_ctx.stream(); + auto comm = nccl_ctx.comm_; + all_reduce_calls.emplace_back([=] { + PADDLE_ENFORCE(platform::dynload::ncclAllReduce( + buffer, buffer, numel, static_cast(dtype), ncclSum, + comm, stream)); + }); + } + + platform::NCCLGroupGuard guard; + for (auto &call : all_reduce_calls) { + call(); } } } diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h index ecdd98987..299900432 100644 --- a/paddle/fluid/platform/nccl_helper.h +++ b/paddle/fluid/platform/nccl_helper.h @@ -36,10 +36,12 @@ inline ncclDataType_t ToNCCLDataType(std::type_index type) { class NCCLGroupGuard { public: - inline NCCLGroupGuard() { PADDLE_ENFORCE(dynload::ncclGroupStart()); } + inline NCCLGroupGuard() { + mutex().lock(); + PADDLE_ENFORCE(dynload::ncclGroupStart()); + } inline ~NCCLGroupGuard() { - mutex().lock(); PADDLE_ENFORCE(dynload::ncclGroupEnd()); mutex().unlock(); } -- GitLab