From d7badb3ed2d4fdcc42a81dffedf68e131daf5fdb Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Mon, 19 Mar 2018 19:33:35 +0800
Subject: [PATCH] Use event to sync stream

---
 paddle/fluid/framework/parallel_executor.cc | 30 ++++++++++++++++++---
 1 file changed, 27 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index fa6763b5b5..6777aec488 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -315,9 +315,21 @@ ncclDataType_t ToNCCLDataType(std::type_index type) {
 
 struct NCCLAllReduceOpHandle : public OpHandle {
   ParallelExecutorPrivate *member_;
+  std::vector<cudaEvent_t> events_;
 
   explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member)
-      : member_(member) {}
+      : member_(member) {
+    events_.resize(member_->places_.size());
+    for (auto &ev : events_) {
+      cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);
+    }
+  }
+
+  ~NCCLAllReduceOpHandle() {
+    for (auto &ev : events_) {
+      cudaEventDestroy(ev);
+    }
+  }
 
   void Run() override {
     if (this->inputs_.size() == 1) {
@@ -350,6 +362,7 @@ struct NCCLAllReduceOpHandle : public OpHandle {
         platform::dynload::ncclAllReduce(
             buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
             nccl_ctx.comm, nccl_ctx.stream());
+        cudaEventRecord(events_[i], nccl_ctx.stream());
       }
 
       platform::dynload::ncclGroupEnd();
@@ -357,8 +370,19 @@ struct NCCLAllReduceOpHandle : public OpHandle {
   }
 
   void Wait(platform::DeviceContext *waited_dev) override {
-    for (auto &pair : member_->communication_streams_) {
-      pair.second.ctx_->Wait();
+    if (platform::is_cpu_place(
+            waited_dev->GetPlace())) {  // Wait by CPU, just sync stream
+      for (auto &pair : member_->communication_streams_) {
+        pair.second.ctx_->Wait();
+      }
+    } else {
+      if (events_.size() > 1) {
+        auto stream =
+            static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
+        for (auto &ev : events_) {
+          cudaStreamWaitEvent(stream, ev, 0);
+        }
+      }
     }
   }
 };
-- 
GitLab