提交 82693e72 编写于 作者: Y Yu Yang

Wait nccl all reduce

上级 eb0a580e
......@@ -348,6 +348,11 @@ struct NCCLAllReduceOpHandle : public OpHandle {
explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member)
: member_(member) {}
// Emits a verbose (level 3) log line and then forwards the wait request,
// unchanged, to the base OpHandle implementation.
void Wait(platform::DeviceContext *waited_dev) override {
  VLOG(3) << "Wait nccl all reduce op";
  // Qualified call: explicitly runs the base-class version, not this override.
  this->OpHandle::Wait(waited_dev);
}
protected:
void RunImpl() override {
if (this->inputs_.size() == 1) {
......@@ -381,7 +386,6 @@ struct NCCLAllReduceOpHandle : public OpHandle {
if (numel == 0) {
numel = static_cast<size_t>(lod_tensor.numel());
}
auto &nccl_ctx = member_->communication_streams_.at(dev_id);
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册