From 41ad63234181e2c6dcec464db51c08270c18ac3c Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 21 Mar 2018 12:35:39 +0800 Subject: [PATCH] Add NCCL Group Guard --- paddle/fluid/framework/parallel_executor.cc | 7 +------ paddle/fluid/platform/nccl_helper.h | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 991a0c823..1823cefe4 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -300,8 +300,6 @@ class ParallelExecutorPrivate { std::unique_ptr exception_; }; -static std::mutex g_nccl_mtx_; - struct NCCLAllReduceOpHandle : public OpHandle { ParallelExecutorPrivate *member_; @@ -327,9 +325,7 @@ struct NCCLAllReduceOpHandle : public OpHandle { int dtype = -1; size_t numel = 0; - std::lock_guard g(g_nccl_mtx_); - - PADDLE_ENFORCE(platform::dynload::ncclGroupStart()); + platform::NCCLGroupGuard guard; for (size_t i = 0; i < member_->local_scopes_.size(); ++i) { auto &p = member_->places_[i]; @@ -355,7 +351,6 @@ struct NCCLAllReduceOpHandle : public OpHandle { buffer, buffer, numel, static_cast(dtype), ncclSum, nccl_ctx.comm, nccl_ctx.stream())); } - PADDLE_ENFORCE(platform::dynload::ncclGroupEnd()); } } }; diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h index e20f99bc6..cceceda8a 100644 --- a/paddle/fluid/platform/nccl_helper.h +++ b/paddle/fluid/platform/nccl_helper.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include "paddle/fluid/platform/dynload/nccl.h" #include "paddle/fluid/platform/enforce.h" @@ -33,5 +34,24 @@ inline ncclDataType_t ToNCCLDataType(std::type_index type) { } } +class NCCLGroupGuard { + public: + inline NCCLGroupGuard() { + mutex().lock(); + PADDLE_ENFORCE(dynload::ncclGroupStart()); + } + + inline ~NCCLGroupGuard() { + PADDLE_ENFORCE(dynload::ncclGroupEnd()); + mutex().unlock(); + } + + private: + static std::mutex& mutex() { + static std::mutex mtx; + return mtx; + } +}; + } // namespace platform } // namespace paddle -- GitLab