Summary: queue each per-device ncclAllReduce call as a closure while iterating
over the local scopes, then run all queued closures under one NCCLGroupGuard
created after the loop. NCCLGroupGuard now locks mutex() in its constructor,
before ncclGroupStart(), instead of in its destructor before ncclGroupEnd().

NOTE(review): if PADDLE_ENFORCE throws inside the NCCLGroupGuard constructor,
the destructor will not run and mutex() would presumably stay locked - confirm.
NOTE(review): std::function / std::vector are used by the new code; verify that
<functional> and <vector> are included by nccl_all_reduce_op_handle.cc.

diff --git a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
index 116b13d3301d780875580867a8d53e91b2781145..f77a4b55a172d740514bd97db2b11fdb371d6af6 100644
--- a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
+++ b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
@@ -41,7 +41,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
     int dtype = -1;
     size_t numel = 0;
 
-    platform::NCCLGroupGuard guard;
+    std::vector<std::function<void()>> all_reduce_calls;
 
     for (size_t i = 0; i < local_scopes_.size(); ++i) {
       auto &p = places_[i];
@@ -58,10 +58,20 @@ void NCCLAllReduceOpHandle::RunImpl() {
       if (numel == 0) {
         numel = static_cast<size_t>(lod_tensor.numel());
       }
+
       auto &nccl_ctx = nccl_ctxs_.at(dev_id);
-      PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
-          buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
-          nccl_ctx.comm_, nccl_ctx.stream()));
+      auto stream = nccl_ctx.stream();
+      auto comm = nccl_ctx.comm_;
+      all_reduce_calls.emplace_back([=] {
+        PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
+            buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
+            comm, stream));
+      });
+    }
+
+    platform::NCCLGroupGuard guard;
+    for (auto &call : all_reduce_calls) {
+      call();
     }
   }
 }
diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h
index ecdd98987dd41985c7ac77a7fceb356e15f6cd3b..29990043206509e4192bfff84832f09ef127d9dd 100644
--- a/paddle/fluid/platform/nccl_helper.h
+++ b/paddle/fluid/platform/nccl_helper.h
@@ -36,10 +36,12 @@ inline ncclDataType_t ToNCCLDataType(std::type_index type) {
 
 class NCCLGroupGuard {
  public:
-  inline NCCLGroupGuard() { PADDLE_ENFORCE(dynload::ncclGroupStart()); }
+  inline NCCLGroupGuard() {
+    mutex().lock();
+    PADDLE_ENFORCE(dynload::ncclGroupStart());
+  }
 
   inline ~NCCLGroupGuard() {
-    mutex().lock();
     PADDLE_ENFORCE(dynload::ncclGroupEnd());
     mutex().unlock();
   }