Commit 7dcb217e authored by Yu Yang

Refine allreduce op

Collect the per-device ncclAllReduce calls into closures and issue them together under a single NCCLGroupGuard, and take the guard's mutex in the constructor (instead of only around ncclGroupEnd in the destructor) so the whole ncclGroupStart/ncclGroupEnd span is serialized across threads.

Parent c0c2e159
@@ -41,7 +41,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
     int dtype = -1;
     size_t numel = 0;
 
-    platform::NCCLGroupGuard guard;
+    std::vector<std::function<void()>> all_reduce_calls;
 
     for (size_t i = 0; i < local_scopes_.size(); ++i) {
       auto &p = places_[i];
@@ -58,10 +58,20 @@ void NCCLAllReduceOpHandle::RunImpl() {
       if (numel == 0) {
         numel = static_cast<size_t>(lod_tensor.numel());
       }
+
       auto &nccl_ctx = nccl_ctxs_.at(dev_id);
-      PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
-          buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
-          nccl_ctx.comm_, nccl_ctx.stream()));
+      auto stream = nccl_ctx.stream();
+      auto comm = nccl_ctx.comm_;
+      all_reduce_calls.emplace_back([=] {
+        PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
+            buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
+            comm, stream));
+      });
     }
+
+    platform::NCCLGroupGuard guard;
+    for (auto &call : all_reduce_calls) {
+      call();
+    }
   }
 }
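For context, the pattern this hunk adopts is NCCL's grouped-call idiom: record one all-reduce per device, then issue them all between ncclGroupStart() and ncclGroupEnd() so NCCL treats the batch as a unit and no other thread's group can interleave with it. Below is a minimal standalone sketch of that idiom using the raw NCCL/CUDA API rather than Paddle's platform::dynload wrappers and PADDLE_ENFORCE; the helper name GroupedAllReduce and its parameters are illustrative, not part of the commit.

```cpp
#include <cuda_runtime.h>
#include <nccl.h>

#include <functional>
#include <vector>

// Sketch only: issues one in-place float sum all-reduce per device,
// all inside a single NCCL group. Error handling is omitted.
void GroupedAllReduce(const std::vector<ncclComm_t> &comms,
                      const std::vector<cudaStream_t> &streams,
                      const std::vector<float *> &bufs, size_t numel) {
  // 1. Record the calls without issuing them (mirrors all_reduce_calls).
  std::vector<std::function<void()>> calls;
  for (size_t i = 0; i < comms.size(); ++i) {
    float *buf = bufs[i];
    ncclComm_t comm = comms[i];
    cudaStream_t stream = streams[i];
    calls.emplace_back([=] {
      ncclAllReduce(buf, buf, numel, ncclFloat, ncclSum, comm, stream);
    });
  }
  // 2. Issue them inside one group; NCCL fuses the batched calls.
  ncclGroupStart();
  for (auto &call : calls) {
    call();
  }
  ncclGroupEnd();
}
```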
@@ -36,10 +36,12 @@ inline ncclDataType_t ToNCCLDataType(std::type_index type) {
 class NCCLGroupGuard {
  public:
-  inline NCCLGroupGuard() { PADDLE_ENFORCE(dynload::ncclGroupStart()); }
+  inline NCCLGroupGuard() {
+    mutex().lock();
+    PADDLE_ENFORCE(dynload::ncclGroupStart());
+  }
 
   inline ~NCCLGroupGuard() {
-    mutex().lock();
     PADDLE_ENFORCE(dynload::ncclGroupEnd());
     mutex().unlock();
   }
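The net effect of this hunk, stripped of the dynload wrappers and PADDLE_ENFORCE, is the classic RAII shape sketched below: the static mutex (assumed to live in the truncated part of nccl_helper.h, since the diff calls mutex()) is now held across the entire ncclGroupStart()/ncclGroupEnd() span instead of only around ncclGroupEnd().

```cpp
#include <nccl.h>

#include <mutex>

// Sketch of the guard after this commit: one process-wide mutex is held
// for the lifetime of the guard, so concurrent threads cannot interleave
// their NCCL groups. Return codes are ignored here for brevity.
class NCCLGroupGuard {
 public:
  NCCLGroupGuard() {
    mutex().lock();    // taken before the group opens (the fix)
    ncclGroupStart();
  }
  ~NCCLGroupGuard() {
    ncclGroupEnd();    // flush the batched NCCL calls
    mutex().unlock();  // only now may another thread start a group
  }

 private:
  static std::mutex &mutex() {
    static std::mutex m;
    return m;
  }
};
```

With this ordering, the race the old code allowed, where a second thread could call ncclGroupStart() while the first thread's group was still open, can no longer occur.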