提交 7dcb217e 编写于 作者: Y Yu Yang

Refine allreduce op

上级 c0c2e159
...@@ -41,7 +41,7 @@ void NCCLAllReduceOpHandle::RunImpl() { ...@@ -41,7 +41,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
int dtype = -1; int dtype = -1;
size_t numel = 0; size_t numel = 0;
platform::NCCLGroupGuard guard; std::vector<std::function<void()>> all_reduce_calls;
for (size_t i = 0; i < local_scopes_.size(); ++i) { for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto &p = places_[i]; auto &p = places_[i];
...@@ -58,10 +58,20 @@ void NCCLAllReduceOpHandle::RunImpl() { ...@@ -58,10 +58,20 @@ void NCCLAllReduceOpHandle::RunImpl() {
if (numel == 0) { if (numel == 0) {
numel = static_cast<size_t>(lod_tensor.numel()); numel = static_cast<size_t>(lod_tensor.numel());
} }
auto &nccl_ctx = nccl_ctxs_.at(dev_id); auto &nccl_ctx = nccl_ctxs_.at(dev_id);
auto stream = nccl_ctx.stream();
auto comm = nccl_ctx.comm_;
all_reduce_calls.emplace_back([=] {
PADDLE_ENFORCE(platform::dynload::ncclAllReduce( PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum, buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
nccl_ctx.comm_, nccl_ctx.stream())); comm, stream));
});
}
platform::NCCLGroupGuard guard;
for (auto &call : all_reduce_calls) {
call();
} }
} }
} }
......
...@@ -36,10 +36,12 @@ inline ncclDataType_t ToNCCLDataType(std::type_index type) { ...@@ -36,10 +36,12 @@ inline ncclDataType_t ToNCCLDataType(std::type_index type) {
class NCCLGroupGuard { class NCCLGroupGuard {
public: public:
inline NCCLGroupGuard() { PADDLE_ENFORCE(dynload::ncclGroupStart()); } inline NCCLGroupGuard() {
mutex().lock();
PADDLE_ENFORCE(dynload::ncclGroupStart());
}
inline ~NCCLGroupGuard() { inline ~NCCLGroupGuard() {
mutex().lock();
PADDLE_ENFORCE(dynload::ncclGroupEnd()); PADDLE_ENFORCE(dynload::ncclGroupEnd());
mutex().unlock(); mutex().unlock();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册