From 82693e72273599da5a0ffc8e21790665279d4a4b Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 20 Mar 2018 19:14:27 +0800 Subject: [PATCH] Wait nccl all reduce --- paddle/fluid/framework/parallel_executor.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index b8751662c3..8ee2e57324 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -348,6 +348,11 @@ struct NCCLAllReduceOpHandle : public OpHandle { explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member) : member_(member) {} + void Wait(platform::DeviceContext *waited_dev) override { + VLOG(3) << "Wait nccl all reduce op"; + OpHandle::Wait(waited_dev); + } + protected: void RunImpl() override { if (this->inputs_.size() == 1) { @@ -381,7 +386,6 @@ struct NCCLAllReduceOpHandle : public OpHandle { if (numel == 0) { numel = static_cast(lod_tensor.numel()); } - auto &nccl_ctx = member_->communication_streams_.at(dev_id); PADDLE_ENFORCE(platform::dynload::ncclAllReduce( buffer, buffer, numel, static_cast(dtype), ncclSum, -- GitLab