diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index fa6763b5b58e37e99858661739133e6476698928..6777aec488d72d2d75872098eb69a42972d94c57 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -315,9 +315,21 @@ ncclDataType_t ToNCCLDataType(std::type_index type) {
 struct NCCLAllReduceOpHandle : public OpHandle {
   ParallelExecutorPrivate *member_;
+  std::vector<cudaEvent_t> events_;
 
   explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member)
-      : member_(member) {}
+      : member_(member) {
+    events_.resize(member_->places_.size());
+    for (auto &ev : events_) {
+      cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);
+    }
+  }
+
+  ~NCCLAllReduceOpHandle() {
+    for (auto &ev : events_) {
+      cudaEventDestroy(ev);
+    }
+  }
 
   void Run() override {
     if (this->inputs_.size() == 1) {
@@ -350,6 +362,7 @@ struct NCCLAllReduceOpHandle : public OpHandle {
         platform::dynload::ncclAllReduce(
             buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
             nccl_ctx.comm, nccl_ctx.stream());
+        cudaEventRecord(events_[i], nccl_ctx.stream());
       }
       platform::dynload::ncclGroupEnd();
 
@@ -357,8 +370,19 @@ struct NCCLAllReduceOpHandle : public OpHandle {
   }
 
   void Wait(platform::DeviceContext *waited_dev) override {
-    for (auto &pair : member_->communication_streams_) {
-      pair.second.ctx_->Wait();
+    if (platform::is_cpu_place(
+            waited_dev->GetPlace())) {  // Wait by CPU, just sync stream
+      for (auto &pair : member_->communication_streams_) {
+        pair.second.ctx_->Wait();
+      }
+    } else {
+      if (events_.size() > 1) {
+        auto stream =
+            static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
+        for (auto &ev : events_) {
+          cudaStreamWaitEvent(stream, ev, 0);
+        }
+      }
     }
   }
 };