提交 41ad6323 编写于 作者: Y Yu Yang

Add NCCL Group Guard

上级 99fe83a0
......@@ -300,8 +300,6 @@ class ParallelExecutorPrivate {
std::unique_ptr<platform::EnforceNotMet> exception_;
};
static std::mutex g_nccl_mtx_;
struct NCCLAllReduceOpHandle : public OpHandle {
ParallelExecutorPrivate *member_;
......@@ -327,9 +325,7 @@ struct NCCLAllReduceOpHandle : public OpHandle {
int dtype = -1;
size_t numel = 0;
std::lock_guard<std::mutex> g(g_nccl_mtx_);
PADDLE_ENFORCE(platform::dynload::ncclGroupStart());
platform::NCCLGroupGuard guard;
for (size_t i = 0; i < member_->local_scopes_.size(); ++i) {
auto &p = member_->places_[i];
......@@ -355,7 +351,6 @@ struct NCCLAllReduceOpHandle : public OpHandle {
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
nccl_ctx.comm, nccl_ctx.stream()));
}
PADDLE_ENFORCE(platform::dynload::ncclGroupEnd());
}
}
};
......
......@@ -14,6 +14,7 @@
#pragma once
#include <mutex>
#include <thread>
#include <typeindex>
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -33,5 +34,24 @@ inline ncclDataType_t ToNCCLDataType(std::type_index type) {
}
}
// Scoped guard around a batch of NCCL calls.
//
// Construction acquires a mutex shared by every NCCLGroupGuard and then calls
// dynload::ncclGroupStart(); destruction calls dynload::ncclGroupEnd() and
// releases the mutex. Because the mutex is held for the guard's whole
// lifetime, at most one guard exists at a time and guards must not be nested
// on the same thread (std::mutex is not recursive).
class NCCLGroupGuard {
 public:
  // Locks the shared mutex, then opens the group.
  //
  // The lock is owned by the lock_ member rather than taken with a manual
  // lock() call: if the enforce below throws, the destructor of this class
  // does not run, but the already-constructed lock_ member is still
  // destroyed, so the mutex is released instead of staying locked forever.
  NCCLGroupGuard() : lock_(mutex()) {
    PADDLE_ENFORCE(dynload::ncclGroupStart());
  }

  // Closes the group; the mutex is released afterwards when lock_ is
  // destroyed, including when the enforce below fails.
  //
  // NOTE(review): PADDLE_ENFORCE presumably reports failure by throwing
  // (platform::EnforceNotMet) -- confirm. Destructors are noexcept by default,
  // which would turn such a failure into std::terminate(); noexcept(false)
  // lets it reach the caller instead. A failure raised while another
  // exception is already unwinding the stack still terminates the program.
  ~NCCLGroupGuard() noexcept(false) {
    PADDLE_ENFORCE(dynload::ncclGroupEnd());
  }

  // A copy would release the mutex and end the group a second time.
  NCCLGroupGuard(const NCCLGroupGuard&) = delete;
  NCCLGroupGuard& operator=(const NCCLGroupGuard&) = delete;

 private:
  // Returns the single mutex shared by all guards; it is a function-local
  // static, so it is created on first use.
  static std::mutex& mutex() {
    static std::mutex mtx;
    return mtx;
  }

  // Holds the shared mutex for the lifetime of the guard.
  std::lock_guard<std::mutex> lock_;
};
} // namespace platform
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册