提交 41ad6323 编写于 作者: Y Yu Yang

Add NCCL Group Guard

上级 99fe83a0
...@@ -300,8 +300,6 @@ class ParallelExecutorPrivate { ...@@ -300,8 +300,6 @@ class ParallelExecutorPrivate {
std::unique_ptr<platform::EnforceNotMet> exception_; std::unique_ptr<platform::EnforceNotMet> exception_;
}; };
static std::mutex g_nccl_mtx_;
struct NCCLAllReduceOpHandle : public OpHandle { struct NCCLAllReduceOpHandle : public OpHandle {
ParallelExecutorPrivate *member_; ParallelExecutorPrivate *member_;
...@@ -327,9 +325,7 @@ struct NCCLAllReduceOpHandle : public OpHandle { ...@@ -327,9 +325,7 @@ struct NCCLAllReduceOpHandle : public OpHandle {
int dtype = -1; int dtype = -1;
size_t numel = 0; size_t numel = 0;
std::lock_guard<std::mutex> g(g_nccl_mtx_); platform::NCCLGroupGuard guard;
PADDLE_ENFORCE(platform::dynload::ncclGroupStart());
for (size_t i = 0; i < member_->local_scopes_.size(); ++i) { for (size_t i = 0; i < member_->local_scopes_.size(); ++i) {
auto &p = member_->places_[i]; auto &p = member_->places_[i];
...@@ -355,7 +351,6 @@ struct NCCLAllReduceOpHandle : public OpHandle { ...@@ -355,7 +351,6 @@ struct NCCLAllReduceOpHandle : public OpHandle {
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum, buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
nccl_ctx.comm, nccl_ctx.stream())); nccl_ctx.comm, nccl_ctx.stream()));
} }
PADDLE_ENFORCE(platform::dynload::ncclGroupEnd());
} }
} }
}; };
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <mutex>
#include <thread>
#include <typeindex> #include <typeindex>
#include "paddle/fluid/platform/dynload/nccl.h" #include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -33,5 +34,24 @@ inline ncclDataType_t ToNCCLDataType(std::type_index type) { ...@@ -33,5 +34,24 @@ inline ncclDataType_t ToNCCLDataType(std::type_index type) {
} }
} }
// Scoped guard that wraps a batch of NCCL calls in a
// ncclGroupStart() / ncclGroupEnd() pair while holding a single
// function-local static mutex shared by every guard instance, so only one
// guard can be alive (and therefore only one NCCL group open through this
// class) at any time.
//
// Usage:
//   {
//     platform::NCCLGroupGuard guard;
//     ... issue NCCL calls ...
//   }  // ncclGroupEnd() is called, then the mutex is released.
class NCCLGroupGuard {
 public:
  // Acquires the shared mutex, then opens the NCCL group.
  //
  // The lock is owned by the `lock_` member rather than taken with a manual
  // mutex().lock(): if PADDLE_ENFORCE raises on a failed ncclGroupStart(),
  // the destructor of this class is never run, but the already-constructed
  // `lock_` member is destroyed and releases the mutex. With a manual lock
  // the mutex would stay locked forever and block every later guard.
  inline NCCLGroupGuard() : lock_(mutex()) {
    PADDLE_ENFORCE(dynload::ncclGroupStart());
  }

  // Closes the NCCL group. `lock_` is destroyed after this body finishes, so
  // the mutex is still held while ncclGroupEnd() runs and is released
  // afterwards -- the same order as an explicit unlock() placed last.
  // NOTE(review): if PADDLE_ENFORCE throws here, the exception leaves a
  // destructor; confirm that aborting is the intended failure mode.
  inline ~NCCLGroupGuard() {
    PADDLE_ENFORCE(dynload::ncclGroupEnd());
  }

  // A copy would end the group and release the mutex a second time.
  NCCLGroupGuard(const NCCLGroupGuard&) = delete;
  NCCLGroupGuard& operator=(const NCCLGroupGuard&) = delete;

 private:
  // Returns the mutex shared by all guards. It is a function-local static so
  // that this header-only class needs no out-of-line definition.
  static std::mutex& mutex() {
    static std::mutex mtx;
    return mtx;
  }

  // Holds the shared mutex for the whole lifetime of the guard.
  std::lock_guard<std::mutex> lock_;
};
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册