From 8af57706e216131937b26ddbd83338883de0d5d1 Mon Sep 17 00:00:00 2001
From: Yu Yang <yuyang18@baidu.com>
Date: Mon, 19 Mar 2018 19:44:31 +0800
Subject: [PATCH] Only wait same device

---
 paddle/fluid/framework/parallel_executor.cc | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index f7dc8339371..1d9584939fc 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -315,19 +315,19 @@ ncclDataType_t ToNCCLDataType(std::type_index type) {
 struct NCCLAllReduceOpHandle : public OpHandle {
   ParallelExecutorPrivate *member_;
-  std::vector<cudaEvent_t> events_;
+  std::unordered_map<int, cudaEvent_t> events_;
 
   explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member)
       : member_(member) {
-    events_.resize(member_->places_.size());
-    for (auto &ev : events_) {
-      cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);
+    for (auto &nccl : member_->communication_streams_) {
+      cudaEventCreate(&events_[nccl.second.device_id()],
+                      cudaEventDisableTiming);
     }
   }
 
   ~NCCLAllReduceOpHandle() {
     for (auto &ev : events_) {
-      cudaEventDestroy(ev);
+      cudaEventDestroy(ev.second);
     }
   }
 
@@ -362,7 +362,7 @@ struct NCCLAllReduceOpHandle : public OpHandle {
         platform::dynload::ncclAllReduce(
             buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
             nccl_ctx.comm, nccl_ctx.stream());
-        cudaEventRecord(events_[i], nccl_ctx.stream());
+        cudaEventRecord(events_[dev_id], nccl_ctx.stream());
       }
 
       platform::dynload::ncclGroupEnd();
@@ -377,11 +377,11 @@ struct NCCLAllReduceOpHandle : public OpHandle {
       }
     } else {
       if (events_.size() > 1) {
+        int dev_id =
+            boost::get<platform::CUDAPlace>(waited_dev->GetPlace()).device;
         auto stream =
             static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
-        for (auto &ev : events_) {
-          cudaStreamWaitEvent(stream, ev, 0);
-        }
+        cudaStreamWaitEvent(stream, events_[dev_id], 0);
       }
     }
   }
--
GitLab