From dd4cd352c7d928e643ab1c2c2e3bb82062bf34a1 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Sat, 8 Jun 2019 21:31:42 +0800 Subject: [PATCH] Fix sync_batch_norm_op ncclallreduce error! (#17918) --- paddle/fluid/framework/parallel_executor.cc | 25 +++++++++++++++++++-- paddle/fluid/framework/parallel_executor.h | 8 ++++++- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index cc2b95735ef..f5ab5d6ee5d 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -281,6 +281,27 @@ bool ParallelExecutor::NeedCreateLocalExeScope() { return executor && executor->NeedCreateLocalExeScope(); } +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +/* + * When NCCL initializes communicators with ncclCommInitAll, errors occur if the + * allreduce op handle and sync_batch_norm_op issue ncclAllReduce in parallel. So + * create a new NCCL communicator for sync_batch_norm_op. This code should be + * polished with a unified NCCL management. + */ +platform::NCCLContextMap *ParallelExecutor::GetNCCLContextForSyncbatchNomrOp( + framework::Scope *scope) { + auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME); + if (nccl_id_var != nullptr) { + return member_->nccl_ctxs_.DefaultFlatCtx(); + } + + if (dev_nccl_ctxs_.get() == nullptr) { + dev_nccl_ctxs_.reset(new platform::NCCLContextMap(member_->places_)); + } + return dev_nccl_ctxs_.get(); +} +#endif + ParallelExecutor::ParallelExecutor(const std::vector &places, const std::vector &bcast_vars, const std::string &loss_var_name, @@ -357,13 +378,13 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, // NOTE: NCCL group-calls and non-group-calls can not use the same // NCCL communicator, so for ParallelGraph and Multi-Process mode, re-use // same communicators. 
+ auto *nccl_ctxs = GetNCCLContextForSyncbatchNomrOp(scope); for (size_t dev_id = 0; dev_id < member_->places_.size(); ++dev_id) { platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto *dev_ctx = static_cast( pool.Get(member_->places_[dev_id])); - auto &nccl_ctx = - member_->nccl_ctxs_.DefaultFlatCtx()->at(member_->places_[dev_id]); + auto &nccl_ctx = nccl_ctxs->at(member_->places_[dev_id]); dev_ctx->set_nccl_comm(nccl_ctx.comm()); } #endif diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index bdd323bd160..89a48b303dd 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -87,7 +87,13 @@ class ParallelExecutor { ParallelExecutorPrivate *member_; std::vector> async_graphs_; -}; +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + // used for compatibility with the sync batch norm op + std::unique_ptr dev_nccl_ctxs_; + platform::NCCLContextMap *GetNCCLContextForSyncbatchNomrOp( + framework::Scope *scope); +#endif +}; } // namespace framework } // namespace paddle -- GitLab