提交 e8a7e5d1 编写于 作者: Y Yu Yang

Update

上级 8f0590e7
......@@ -250,6 +250,8 @@ struct NCCLAllReduceOpHandle : public OpHandle {
int dtype = -1;
size_t numel = 0;
platform::dynload::ncclGroupStart();
for (auto &p : member_->places_) {
int dev_id = boost::get<platform::CUDAPlace>(p).device;
......@@ -266,11 +268,12 @@ struct NCCLAllReduceOpHandle : public OpHandle {
auto &nccl_ctx = member_->communication_streams_.at(dev_id);
ncclAllReduce(buffer, buffer, numel, static_cast<ncclDataType_t>(dtype),
ncclSum, nccl_ctx.comm, nccl_ctx.stream());
platform::dynload::ncclAllReduce(
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
nccl_ctx.comm, nccl_ctx.stream());
}
ncclGroupEnd();
platform::dynload::ncclGroupEnd();
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册