Unverified commit dd4cd352, authored by gongweibao, committed by GitHub

Fix sync_batch_norm_op ncclallreduce error! (#17918)

Parent f3e5a5cf
@@ -281,6 +281,27 @@ bool ParallelExecutor::NeedCreateLocalExeScope() {
  return executor && executor->NeedCreateLocalExeScope();
}
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
/*
 * When NCCL initializes its communicators with ncclCommInitAll, an error
 * occurs if the allreduce op handle and sync_batch_norm_op issue
 * ncclAllReduce on them in parallel. So create a separate NCCL communicator
 * for sync_batch_norm_op. This code should be cleaned up once there is a
 * unified NCCL management.
 */
platform::NCCLContextMap *ParallelExecutor::GetNCCLContextForSyncbatchNomrOp(
    framework::Scope *scope) {
  auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
  if (nccl_id_var != nullptr) {
    return member_->nccl_ctxs_.DefaultFlatCtx();
  }
  if (dev_nccl_ctxs_.get() == nullptr) {
    dev_nccl_ctxs_.reset(new platform::NCCLContextMap(member_->places_));
  }
  return dev_nccl_ctxs_.get();
}
#endif
ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
                                   const std::vector<std::string> &bcast_vars,
                                   const std::string &loss_var_name,
@@ -357,13 +378,13 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
    // NOTE: NCCL group-calls and non-group-calls can not use the same
    // NCCL communicator, so for ParallelGraph and Multi-Process mode, re-use
    // same communicators.
    auto *nccl_ctxs = GetNCCLContextForSyncbatchNomrOp(scope);
    for (size_t dev_id = 0; dev_id < member_->places_.size(); ++dev_id) {
      platform::DeviceContextPool &pool =
          platform::DeviceContextPool::Instance();
      auto *dev_ctx = static_cast<platform::CUDADeviceContext *>(
          pool.Get(member_->places_[dev_id]));
-     auto &nccl_ctx =
-         member_->nccl_ctxs_.DefaultFlatCtx()->at(member_->places_[dev_id]);
+     auto &nccl_ctx = nccl_ctxs->at(member_->places_[dev_id]);
      dev_ctx->set_nccl_comm(nccl_ctx.comm());
    }
#endif
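The comments above carry the rationale; to make it concrete, the standalone sketch below shows only the underlying NCCL calls and is not Paddle code. The device count, device ids, and the omitted error handling are assumptions for illustration. It builds two independent communicator sets over the same GPUs with ncclCommInitAll (the call named in the comment above), which is essentially what reserving a separate NCCLContextMap for sync_batch_norm_op achieves: grouped and non-grouped ncclAllReduce calls never land on the same communicator.

// Standalone sketch, not Paddle code: two independent NCCL communicator
// sets over the same GPUs, mirroring how this patch reserves a separate
// NCCLContextMap for sync_batch_norm_op instead of sharing
// member_->nccl_ctxs_. The device count/ids and the missing error checks
// are simplifying assumptions.
#include <nccl.h>

#include <vector>

int main() {
  const int ndev = 2;           // assumed: two visible GPUs
  int dev_ids[ndev] = {0, 1};

  // Communicators used by the regular allreduce op handles.
  std::vector<ncclComm_t> allreduce_comms(ndev);
  ncclCommInitAll(allreduce_comms.data(), ndev, dev_ids);

  // A second, independent set reserved for sync_batch_norm-style
  // collectives. Because the two sets share no communicator, a grouped
  // allreduce on one set can never collide with a plain ncclAllReduce
  // issued concurrently on the other.
  std::vector<ncclComm_t> sync_bn_comms(ndev);
  ncclCommInitAll(sync_bn_comms.data(), ndev, dev_ids);

  // ... enqueue collectives on either set here ...

  for (int i = 0; i < ndev; ++i) {
    ncclCommDestroy(allreduce_comms[i]);
    ncclCommDestroy(sync_bn_comms[i]);
  }
  return 0;
}

In the patch, the second set corresponds to the lazily created dev_nccl_ctxs_ used on the single-process path; whichever map GetNCCLContextForSyncbatchNomrOp returns is the one attached to each CUDADeviceContext through set_nccl_comm() in the loop above.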
@@ -87,7 +87,13 @@ class ParallelExecutor {
  ParallelExecutorPrivate *member_;
  std::vector<std::unique_ptr<ir::Graph>> async_graphs_;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
  // Used for compatibility with sync_batch_norm_op.
  std::unique_ptr<platform::NCCLContextMap> dev_nccl_ctxs_;
  platform::NCCLContextMap *GetNCCLContextForSyncbatchNomrOp(
      framework::Scope *scope);
};
}  // namespace framework
}  // namespace paddle
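A final illustrative note on the dev_nccl_ctxs_ member declared above: it is only materialized on the fallback path of GetNCCLContextForSyncbatchNomrOp, i.e. when no NCCL_ID_VARNAME variable is found in the scope (the single-process case), while multi-process runs keep returning the existing DefaultFlatCtx(). The sketch below restates that create-on-first-use idiom with hypothetical names (FakeNCCLContextMap, Executor, GetSyncBnCtxs); it is not Paddle's implementation.

// Minimal sketch with hypothetical names, not Paddle code: the communicator
// map is built only when it is first requested, cached in a unique_ptr
// member, and callers receive a raw, non-owning pointer.
#include <memory>
#include <utility>
#include <vector>

struct FakeNCCLContextMap {  // stands in for platform::NCCLContextMap
  explicit FakeNCCLContextMap(const std::vector<int> &devices) {
    // Real code would call ncclCommInitAll() for `devices` here.
    (void)devices;
  }
};

class Executor {
 public:
  explicit Executor(std::vector<int> devices) : devices_(std::move(devices)) {}

  // Mirrors the fallback branch of GetNCCLContextForSyncbatchNomrOp: build
  // the map on first use, then keep handing out the same instance.
  FakeNCCLContextMap *GetSyncBnCtxs() {
    if (dev_nccl_ctxs_ == nullptr) {
      dev_nccl_ctxs_.reset(new FakeNCCLContextMap(devices_));
    }
    return dev_nccl_ctxs_.get();
  }

 private:
  std::vector<int> devices_;
  std::unique_ptr<FakeNCCLContextMap> dev_nccl_ctxs_;
};

Keeping the map in a std::unique_ptr member ties its lifetime to the executor, while callers only ever see a raw, non-owning pointer, just as GetNCCLContextForSyncbatchNomrOp hands out.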