diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index b8751662c36626c4bde293727ce4eb0af5528211..8ee2e573241315a086500442bdbf7ee6e0185718 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -348,6 +348,11 @@ struct NCCLAllReduceOpHandle : public OpHandle { explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member) : member_(member) {} + /* Wait override: emits a VLOG(3) trace message, then forwards waited_dev unchanged to the base-class OpHandle::Wait. */ void Wait(platform::DeviceContext *waited_dev) override { + VLOG(3) << "Wait nccl all reduce op"; + OpHandle::Wait(waited_dev); + } + protected: void RunImpl() override { if (this->inputs_.size() == 1) { @@ -381,7 +386,6 @@ struct NCCLAllReduceOpHandle : public OpHandle { if (numel == 0) { numel = static_cast(lod_tensor.numel()); } - auto &nccl_ctx = member_->communication_streams_.at(dev_id); PADDLE_ENFORCE(platform::dynload::ncclAllReduce( buffer, buffer, numel, static_cast(dtype), ncclSum,