diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 991a0c8238cff60a51fb9a753a713d139064ff58..1823cefe42af3586f1118fb8d205f4e8b3ad13d9 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -300,8 +300,6 @@ class ParallelExecutorPrivate {
   std::unique_ptr<platform::EnforceNotMet> exception_;
 };
 
-static std::mutex g_nccl_mtx_;
-
 struct NCCLAllReduceOpHandle : public OpHandle {
   ParallelExecutorPrivate *member_;
 
@@ -327,9 +325,7 @@ struct NCCLAllReduceOpHandle : public OpHandle {
       int dtype = -1;
       size_t numel = 0;
 
-      std::lock_guard<std::mutex> g(g_nccl_mtx_);
-
-      PADDLE_ENFORCE(platform::dynload::ncclGroupStart());
+      platform::NCCLGroupGuard guard;
 
       for (size_t i = 0; i < member_->local_scopes_.size(); ++i) {
         auto &p = member_->places_[i];
@@ -355,7 +351,6 @@ struct NCCLAllReduceOpHandle : public OpHandle {
             buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
             nccl_ctx.comm, nccl_ctx.stream()));
       }
-      PADDLE_ENFORCE(platform::dynload::ncclGroupEnd());
     }
   }
 };
diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h
index e20f99bc6bc30298dc0ab6bb37adb6b855b6b75e..cceceda8ad83824eba6b50e061f4cd9e03d1b354 100644
--- a/paddle/fluid/platform/nccl_helper.h
+++ b/paddle/fluid/platform/nccl_helper.h
@@ -14,6 +14,7 @@
 
 #pragma once
 
+#include <mutex>
 #include <typeindex>
 #include "paddle/fluid/platform/dynload/nccl.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -33,5 +34,37 @@ inline ncclDataType_t ToNCCLDataType(std::type_index type) {
   }
 }
 
+// Scoped guard around a batch of NCCL calls. Construction takes a mutex shared
+// by all guards and then calls ncclGroupStart(); destruction calls
+// ncclGroupEnd() and then releases the mutex.
+class NCCLGroupGuard {
+ public:
+  // The mutex is owned by a lock_guard member, so it is released even if the
+  // ncclGroupStart() check throws out of the constructor (the destructor body
+  // does not run for a guard whose construction did not complete).
+  inline NCCLGroupGuard() : lock_(mutex()) {
+    PADDLE_ENFORCE(dynload::ncclGroupStart());
+  }
+
+  // NOTE(review): destructors are noexcept by default, so if PADDLE_ENFORCE
+  // reports a failed ncclGroupEnd() by throwing, the process terminates here
+  // instead of unwinding -- confirm this is the intended failure mode.
+  inline ~NCCLGroupGuard() { PADDLE_ENFORCE(dynload::ncclGroupEnd()); }
+
+  // Non-copyable: a copy would call ncclGroupEnd() a second time.
+  NCCLGroupGuard(const NCCLGroupGuard&) = delete;
+  NCCLGroupGuard& operator=(const NCCLGroupGuard&) = delete;
+
+ private:
+  // Returns the one mutex shared by every NCCLGroupGuard.
+  static std::mutex& mutex() {
+    static std::mutex mtx;
+    return mtx;
+  }
+
+  // Destroyed after the destructor body, i.e. unlocked after ncclGroupEnd().
+  std::lock_guard<std::mutex> lock_;
+};
+
 }  // namespace platform
 }  // namespace paddle