提交 e8a7e5d1 编写于 作者: Y Yu Yang

Update

上级 8f0590e7
...@@ -250,6 +250,8 @@ struct NCCLAllReduceOpHandle : public OpHandle { ...@@ -250,6 +250,8 @@ struct NCCLAllReduceOpHandle : public OpHandle {
int dtype = -1; int dtype = -1;
size_t numel = 0; size_t numel = 0;
platform::dynload::ncclGroupStart();
for (auto &p : member_->places_) { for (auto &p : member_->places_) {
int dev_id = boost::get<platform::CUDAPlace>(p).device; int dev_id = boost::get<platform::CUDAPlace>(p).device;
...@@ -266,11 +268,12 @@ struct NCCLAllReduceOpHandle : public OpHandle { ...@@ -266,11 +268,12 @@ struct NCCLAllReduceOpHandle : public OpHandle {
auto &nccl_ctx = member_->communication_streams_.at(dev_id); auto &nccl_ctx = member_->communication_streams_.at(dev_id);
ncclAllReduce(buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), platform::dynload::ncclAllReduce(
ncclSum, nccl_ctx.comm, nccl_ctx.stream()); buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
nccl_ctx.comm, nccl_ctx.stream());
} }
ncclGroupEnd(); platform::dynload::ncclGroupEnd();
} }
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册