From 6f0107126a21b2e5e5df3be131531eeba33d7ef3 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Tue, 26 Jun 2018 20:16:24 +0800 Subject: [PATCH] fix broadcast bug --- paddle/fluid/framework/parallel_executor.cc | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index dfebf36d04..485d89aa56 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -153,7 +153,6 @@ void ParallelExecutor::BCastParamsToGPUs( if (main_var == nullptr || !main_var->IsType<LoDTensor>()) { continue; } - VLOG(3) << "run broadcast " << var << " " << var_dev_id; auto &main_tensor = main_var->Get<LoDTensor>(); auto &dims = main_tensor.dims(); @@ -184,8 +183,16 @@ void ParallelExecutor::BCastParamsToGPUs( platform::NCCLGroupGuard guard; for (size_t i = 0; i < member_->places_.size(); ++i) { auto &nccl_ctx = member_->nccl_ctxs_->at(member_->places_[i]); - platform::dynload::ncclBcast(buffers[i], numel, data_type, 0, - nccl_ctx.comm_, nccl_ctx.stream()); + if (initializing) { + platform::dynload::ncclBcast(buffers[i], numel, data_type, 0, + nccl_ctx.comm_, nccl_ctx.stream()); + } else { + if (static_cast<bool>(var_dev_id)) { + platform::dynload::ncclBcast(buffers[i], numel, data_type, + var_dev_id, nccl_ctx.comm_, + nccl_ctx.stream()); + } + } } member_->nccl_ctxs_->WaitAll(); } -- GitLab