diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.cc b/paddle/fluid/framework/details/all_reduce_op_handle.cc
index 04ab58947af8f992714fd9e8e12a7a275696250b..2f001e54d4f668537953bbaeb14aa21e6745009f 100644
--- a/paddle/fluid/framework/details/all_reduce_op_handle.cc
+++ b/paddle/fluid/framework/details/all_reduce_op_handle.cc
@@ -35,7 +35,7 @@ namespace details {
 AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
                                      const std::vector<Scope *> &local_scopes,
                                      const std::vector<platform::Place> &places,
-                                     const platform::MultiNCCLContextMap *ctxs)
+                                     const platform::NCCLCommunicator *ctxs)
     : NCCLOpHandleBase(node, places, ctxs), local_scopes_(local_scopes) {
   PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
 }
diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.h b/paddle/fluid/framework/details/all_reduce_op_handle.h
index 5ccf4291da6071768bee8b269e5e4c62f0798b71..f206f5fea5c41536a07143e707c53f135b287035 100644
--- a/paddle/fluid/framework/details/all_reduce_op_handle.h
+++ b/paddle/fluid/framework/details/all_reduce_op_handle.h
@@ -34,7 +34,7 @@ class AllReduceOpHandle : public NCCLOpHandleBase {
  public:
   AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
                     const std::vector<platform::Place> &places,
-                    const platform::MultiNCCLContextMap *ctxs);
+                    const platform::NCCLCommunicator *ctxs);
 #else
 class AllReduceOpHandle : public OpHandleBase {
  public:
diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc
index 10cead16ea044e73c63ebba5b57915ed023ca777..3b57a099c8afeeca05f9fa45eda78e20197dc798 100644
--- a/paddle/fluid/framework/details/build_strategy.cc
+++ b/paddle/fluid/framework/details/build_strategy.cc
@@ -266,14 +266,16 @@ bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const {
   return framework::ir::MultiDevSSAGraphBuilder().count(pass_name) > 0;
 }
 
-ir::Graph *BuildStrategy::Apply(
-    ir::Graph *graph, const std::vector<platform::Place> &places,
-    const std::string &loss_var_name, const std::vector<Scope *> &local_scopes,
-    const size_t &nranks,
+ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
+                                const std::vector<platform::Place> &places,
+                                const std::string &loss_var_name,
+                                const std::vector<Scope *> &local_scopes,
+                                const size_t &nranks,
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-    const bool use_cuda, platform::MultiNCCLContextMap *nccl_ctxs) const {
+                                const bool use_cuda,
+                                platform::NCCLCommunicator *nccl_ctxs) const {
 #else
-    const bool use_cuda) const {
+                                const bool use_cuda) const {
 #endif
   VLOG(3) << "apply all passes";
   // Create a default one if not finalized by user.
@@ -293,9 +295,9 @@ ir::Graph *BuildStrategy::Apply(
       pass->Set(ir::kNRanks, new size_t(nranks));
 
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-      platform::MultiNCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
+      platform::NCCLCommunicator *nctx = use_cuda ? nccl_ctxs : nullptr;
       pass->Erase(kNCCLCtxs);
-      pass->SetNotOwned<platform::MultiNCCLContextMap>(kNCCLCtxs, nctx);
+      pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx);
 #endif
     } else if (pass->Type() == "alloc_continuous_space_for_grad_pass" ||
                pass->Type() == "fuse_adam_op_pass" ||
@@ -309,9 +311,9 @@ ir::Graph *BuildStrategy::Apply(
                                                       &local_scopes);
       if (pass->Type() == "fuse_all_reduce_op_pass") {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-        platform::MultiNCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
+        platform::NCCLCommunicator *nctx = use_cuda ? nccl_ctxs : nullptr;
         pass->Erase(kNCCLCtxs);
-        pass->SetNotOwned<platform::MultiNCCLContextMap>(kNCCLCtxs, nctx);
+        pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx);
         pass->Erase(kUseHierarchicalAllReduce);
         pass->Set(kUseHierarchicalAllReduce,
                   new bool(use_hierarchical_allreduce_));
@@ -328,9 +330,9 @@ ir::Graph *BuildStrategy::Apply(
               << enable_sequential_execution_;
     } else if (pass->Type() == "all_reduce_deps_pass") {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-      platform::MultiNCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
+      platform::NCCLCommunicator *nctx = use_cuda ? nccl_ctxs : nullptr;
       pass->Erase(kNCCLCtxs);
-      pass->SetNotOwned<platform::MultiNCCLContextMap>(kNCCLCtxs, nctx);
+      pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx);
       pass->Erase(kUseHierarchicalAllReduce);
       pass->Set(kUseHierarchicalAllReduce,
                 new bool(use_hierarchical_allreduce_));
diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h
index bf698edaff5151819a4953ce288f60da6466153b..8eaace17bb1a59bc5033e632511886c7630d0cd2 100644
--- a/paddle/fluid/framework/details/build_strategy.h
+++ b/paddle/fluid/framework/details/build_strategy.h
@@ -149,7 +149,7 @@ struct BuildStrategy {
                    const size_t &nranks,
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
                    const bool use_cuda,
-                   platform::MultiNCCLContextMap *nccl_ctxs) const;
+                   platform::NCCLCommunicator *nccl_ctxs) const;
 #else
                    const bool use_cuda) const;
 #endif
diff --git a/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc b/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc
index 4f27b7acff63170958f2d2e83399279ca7b340b2..4d96d820a1d161e76945a1c87e1832d95a8a802e 100644
--- a/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc
+++ b/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc
@@ -44,7 +44,7 @@ typedef std::vector<std::vector<std::pair<int64_t, const void *>>>
 FusedAllReduceOpHandle::FusedAllReduceOpHandle(
     ir::Node *node, const std::vector<Scope *> &local_scopes,
     const std::vector<platform::Place> &places, const size_t num_of_all_reduce,
-    const platform::MultiNCCLContextMap *ctxs)
+    const platform::NCCLCommunicator *ctxs)
     : NCCLOpHandleBase(node, places, ctxs),
       local_scopes_(local_scopes),
       num_of_all_reduce_(num_of_all_reduce) {
diff --git a/paddle/fluid/framework/details/fused_all_reduce_op_handle.h b/paddle/fluid/framework/details/fused_all_reduce_op_handle.h
index 00730f107595bdd02f918eb5088efcd0de11964b..e0b9123c5b7e40f7d96ef3ea4061c2822aca7eef 100644
--- a/paddle/fluid/framework/details/fused_all_reduce_op_handle.h
+++ b/paddle/fluid/framework/details/fused_all_reduce_op_handle.h
@@ -35,7 +35,7 @@ struct FusedAllReduceOpHandle : public NCCLOpHandleBase {
                          const std::vector<Scope *> &local_scopes,
                          const std::vector<platform::Place> &places,
                          const size_t num_of_all_reduce,
-                         const platform::MultiNCCLContextMap *ctxs);
+                         const platform::NCCLCommunicator *ctxs);
 #else
 struct FusedAllReduceOpHandle : public OpHandleBase {
   FusedAllReduceOpHandle(ir::Node *node,
diff --git a/paddle/fluid/framework/details/nccl_op_handle.h b/paddle/fluid/framework/details/nccl_op_handle.h
index 7f9de6e2f012ea6b81833721cf75666b323fe9f7..2f425372234898860521570da8884497c995e9e2 100644
--- a/paddle/fluid/framework/details/nccl_op_handle.h
+++ b/paddle/fluid/framework/details/nccl_op_handle.h
@@ -33,7 +33,7 @@ namespace details {
 class NCCLOpHandleBase : public OpHandleBase {
  public:
   NCCLOpHandleBase(ir::Node* node, const std::vector<platform::Place>& places,
-                   const platform::MultiNCCLContextMap* nccl_ctxs)
+                   const platform::NCCLCommunicator* nccl_ctxs)
       : OpHandleBase(node), places_(places), nccl_ctxs_(nccl_ctxs) {
     if (nccl_ctxs == nullptr) {
       return;
     }
@@ -215,7 +215,7 @@ class NCCLOpHandleBase : public OpHandleBase {
  protected:
   std::vector<platform::Place> places_;
-  const platform::MultiNCCLContextMap* nccl_ctxs_{nullptr};
+  const platform::NCCLCommunicator* nccl_ctxs_{nullptr};
   // When multi trainer call collective function, they need run the same order.
   // Or the program will hang.So we use allreduce_deps_pass to set this
   // run_order_.
diff --git a/paddle/fluid/framework/details/sparse_all_reduce_op_handle.cc b/paddle/fluid/framework/details/sparse_all_reduce_op_handle.cc
index 5c7d6db30410231472fda4688692146c3560e521..cc3493d849eccbecf3d039dc7b2fc18575fcf9d0 100644
--- a/paddle/fluid/framework/details/sparse_all_reduce_op_handle.cc
+++ b/paddle/fluid/framework/details/sparse_all_reduce_op_handle.cc
@@ -30,7 +30,7 @@ namespace details {
 SparseAllReduceOpHandle::SparseAllReduceOpHandle(
     ir::Node *node, const std::vector<Scope *> &local_scopes,
     const std::vector<platform::Place> &places,
-    const platform::MultiNCCLContextMap *ctxs, bool is_encoded, int nranks)
+    const platform::NCCLCommunicator *ctxs, bool is_encoded, int nranks)
     : AllReduceOpHandle(node, local_scopes, places, ctxs),
       is_encoded_(is_encoded),
       nranks_(nranks) {
diff --git a/paddle/fluid/framework/details/sparse_all_reduce_op_handle.h b/paddle/fluid/framework/details/sparse_all_reduce_op_handle.h
index b3ff6cd392453e54e7e0f85fe417ed428ca19a95..9802f8dba7e05aec424f48d50992d065015179c9 100644
--- a/paddle/fluid/framework/details/sparse_all_reduce_op_handle.h
+++ b/paddle/fluid/framework/details/sparse_all_reduce_op_handle.h
@@ -32,7 +32,7 @@ class SparseAllReduceOpHandle : public AllReduceOpHandle {
   SparseAllReduceOpHandle(ir::Node *node,
                           const std::vector<Scope *> &local_scopes,
                           const std::vector<platform::Place> &places,
-                          const platform::MultiNCCLContextMap *ctxs,
+                          const platform::NCCLCommunicator *ctxs,
                           bool is_encoded = false, int nranks = -1);
   std::string Name() const override;
diff --git a/paddle/fluid/framework/ir/multi_devices_graph_pass/fuse_all_reduce_op_pass.cc b/paddle/fluid/framework/ir/multi_devices_graph_pass/fuse_all_reduce_op_pass.cc
index a7c492f0ce9a8953558c5e6236602a312badba79..abfaf1b8d201450ca211911fe4b527948b4ac7e4 100644
--- a/paddle/fluid/framework/ir/multi_devices_graph_pass/fuse_all_reduce_op_pass.cc
+++ b/paddle/fluid/framework/ir/multi_devices_graph_pass/fuse_all_reduce_op_pass.cc
@@ -35,7 +35,7 @@ class FuseAllReduceOpPass : public ir::Pass {
     auto &local_scopes = Get<const std::vector<Scope *>>(details::kLocalScopes);
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
     auto *multi_nccl_ctxs =
-        &Get<platform::MultiNCCLContextMap>(details::kNCCLCtxs);
+        &Get<platform::NCCLCommunicator>(details::kNCCLCtxs);
 #endif
 
     std::unordered_set<std::string> grads;
@@ -103,14 +103,14 @@ class FuseAllReduceOpPass : public ir::Pass {
     }
   }
 
-  void InsertFusedAllReduce(
-      const std::vector<platform::Place> &places,
-      const std::vector<Scope *> &local_scopes, const size_t num_of_all_reduce,
-      const std::vector<ir::Node *> &all_reduce_ops,
+  void InsertFusedAllReduce(const std::vector<platform::Place> &places,
+                            const std::vector<Scope *> &local_scopes,
+                            const size_t num_of_all_reduce,
+                            const std::vector<ir::Node *> &all_reduce_ops,
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-      const platform::MultiNCCLContextMap *multi_nccl_ctxs,
+                            const platform::NCCLCommunicator *multi_nccl_ctxs,
 #endif
-      ir::Graph *result) const {
+                            ir::Graph *result) const {
     std::vector<details::VarHandleBase *> inputs;
     std::vector<details::VarHandleBase *> outputs;
     for (auto &op : all_reduce_ops) {
@@ -151,7 +151,7 @@ class FuseAllReduceOpPass : public ir::Pass {
                               const std::vector<platform::Place> &places,
                               const std::vector<Scope *> &local_scopes,
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-                              const platform::MultiNCCLContextMap *multi_nccl_ctxs,
+                              const platform::NCCLCommunicator *multi_nccl_ctxs,
 #endif
                               ir::Graph *result) const {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
diff --git a/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc b/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc
index 6127f6ac23822f6ad93952ae503f69eb8c6fec96..d6d9c8bb891807e0a229959b00479482fe544e7a 100644
--- a/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc
+++ b/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc
@@ -157,7 +157,7 @@ void MultiDevSSAGraphBuilderBase::Init() const {
   local_scopes_ = Get<const std::vector<Scope *>>(details::kLocalScopes);
   strategy_ = Get<const details::BuildStrategy>(kStrategy);
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-  multi_nccl_ctxs_ = &Get<platform::MultiNCCLContextMap>(details::kNCCLCtxs);
+  multi_nccl_ctxs_ = &Get<platform::NCCLCommunicator>(details::kNCCLCtxs);
   nccl_ctxs_ = nullptr;
   if (multi_nccl_ctxs_) {
     nccl_ctxs_ = multi_nccl_ctxs_->DefaultFlatCtx();
diff --git a/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h b/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h
index 278621bf6f443f9f3b1e90beff261d89a48abc62..9b36d231081d4922419881fd115b3ca347d7d064 100644
--- a/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h
+++ b/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h
@@ -97,7 +97,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
 
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   mutable platform::NCCLContextMap *nccl_ctxs_{nullptr};
-  mutable platform::MultiNCCLContextMap *multi_nccl_ctxs_{nullptr};
+  mutable platform::NCCLCommunicator *multi_nccl_ctxs_{nullptr};
 #endif
 
   mutable std::string loss_var_name_;
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index f5ab5d6ee5dc800632febc38184850b1fbb52284..07479f6782857f5b7a32743e2baa8b7509e63a3a 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -111,8 +111,8 @@ class ParallelExecutorPrivate {
     std::vector<ncclUniqueId> flat_nccl_ids;
     if (nranks_ == 1) {
       // FIXME(gongwb): need not to create ncclid when nranks==1
-      nccl_ctxs_.InitFlatCtxs(places_, flat_nccl_ids, bst.num_trainers_,
-                              bst.trainer_id_);
+      nccl_ctxs_->InitFlatCtxs(places_, flat_nccl_ids, bst.num_trainers_,
+                               bst.trainer_id_);
       return;
     }
@@ -132,16 +132,16 @@ class ParallelExecutorPrivate {
       flat_nccl_ids.push_back(nccl_id);
 
-      nccl_ctxs_.InitFlatCtxs(places_, flat_nccl_ids, bst.num_trainers_,
-                              bst.trainer_id_);
+      nccl_ctxs_->InitFlatCtxs(places_, flat_nccl_ids, bst.num_trainers_,
+                               bst.trainer_id_);
       VLOG(1) << "init bst nccl context complete!";
       return;
     }
 
     // num_trainers ==1 && places > 1
     if (bst.num_trainers_ == 1) {
-      nccl_ctxs_.InitFlatCtxs(places_, flat_nccl_ids, bst.num_trainers_,
-                              bst.trainer_id_);
+      nccl_ctxs_->InitFlatCtxs(places_, flat_nccl_ids, bst.num_trainers_,
+                               bst.trainer_id_);
       return;
     }
@@ -153,8 +153,8 @@ class ParallelExecutorPrivate {
       flat_nccl_ids.push_back(nccl_id);
     }
 
-    nccl_ctxs_.InitFlatCtxs(places_, flat_nccl_ids, bst.num_trainers_,
-                            bst.trainer_id_);
+    nccl_ctxs_->InitFlatCtxs(places_, flat_nccl_ids, bst.num_trainers_,
+                             bst.trainer_id_);
 
     if (bst.use_hierarchical_allreduce_) {
       std::vector<ncclUniqueId> inter_nccl_ids;
@@ -175,12 +175,30 @@ class ParallelExecutorPrivate {
         exter_nccl_ids.push_back(nccl_id);
       }
 
-      nccl_ctxs_.InitHierarchicalCtxs(places_, inter_nccl_ids, exter_nccl_ids,
-                                      bst.num_trainers_, bst.trainer_id_,
-                                      bst.hierarchical_allreduce_inter_nranks_,
-                                      bst.hierarchical_allreduce_exter_nranks_);
+      nccl_ctxs_->InitHierarchicalCtxs(
+          places_, inter_nccl_ids, exter_nccl_ids, bst.num_trainers_,
+          bst.trainer_id_, bst.hierarchical_allreduce_inter_nranks_,
+          bst.hierarchical_allreduce_exter_nranks_);
     }
   }
+
+  void InitOrGetNCCLCommunicator(framework::Scope *scope,
+                                 const BuildStrategy &bst) {
+    const std::string var_name = "NCCLCommunicator";
+    auto var = scope->FindVar(var_name);
+    if (var != nullptr) {
+      PADDLE_ENFORCE(var->IsInitialized(),
+                     "if %s exists, it must be initialized", var_name);
+      VLOG(1) << "find " << var_name
+              << " in scope, so use it and does not recreate!";
+      nccl_ctxs_ = var->GetMutable<platform::NCCLCommunicator>();
+      return;
+    }
+
+    VLOG(1) << "not find " << var_name << " in scope, so recreate it!";
+    nccl_ctxs_ = scope->Var(var_name)->GetMutable<platform::NCCLCommunicator>();
+    InitNCCLCtxs(scope, bst);
+  }
 #endif
 
   BuildStrategy build_strategy_;
@@ -190,7 +208,7 @@ class ParallelExecutorPrivate {
   std::unique_ptr<details::SSAGraphExecutor> executor_;
 
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-  platform::MultiNCCLContextMap nccl_ctxs_;
+  platform::NCCLCommunicator *nccl_ctxs_{nullptr};
 #endif
   bool own_local_scope_;
   bool use_cuda_;
@@ -281,27 +299,6 @@ bool ParallelExecutor::NeedCreateLocalExeScope() {
   return executor && executor->NeedCreateLocalExeScope();
 }
 
-#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-/*
- * When nccl inits nccl comm using ncclCommInitAll, it meets error when
- * allreduce ophandle and sync_batch_norm_op use ncclallreduce parallelly. So
- * create a new nccl comm for sync_batch_norm_op. And these codes should be
- * polished with a unified nccl management.
- */
-platform::NCCLContextMap *ParallelExecutor::GetNCCLContextForSyncbatchNomrOp(
-    framework::Scope *scope) {
-  auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
-  if (nccl_id_var != nullptr) {
-    return member_->nccl_ctxs_.DefaultFlatCtx();
-  }
-
-  if (dev_nccl_ctxs_.get() == nullptr) {
-    dev_nccl_ctxs_.reset(new platform::NCCLContextMap(member_->places_));
-  }
-  return dev_nccl_ctxs_.get();
-}
-#endif
-
 ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
                                    const std::vector<std::string> &bcast_vars,
                                    const std::string &loss_var_name,
@@ -369,7 +366,7 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
   if (member_->use_cuda_) {
     // Bcast Parameters to all GPUs
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-    member_->InitNCCLCtxs(scope, build_strategy);
+    member_->InitOrGetNCCLCommunicator(scope, build_strategy);
 
     // Initialize device context's nccl comm, will be used by normal
     // Operators like sync_batch_norm, and collective ops.
@@ -378,7 +375,8 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
     // NOTE: NCCL group-calls and non-group-calls can not use the same
     // NCCL communicator, so for ParallelGraph and Multi-Process mode, re-use
    // same communicators.
-    auto *nccl_ctxs = GetNCCLContextForSyncbatchNomrOp(scope);
+    auto *nccl_ctxs =
+        member_->nccl_ctxs_->GetSyncBatchNormCtx(scope, member_->places_);
     for (size_t dev_id = 0; dev_id < member_->places_.size(); ++dev_id) {
       platform::DeviceContextPool &pool =
           platform::DeviceContextPool::Instance();
@@ -415,18 +413,18 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
     VLOG(3) << "use local async mode";
     graph = build_strategy.Apply(graph, {member_->places_[0]}, loss_var_name,
                                  {member_->local_scopes_[0]}, 1,
-                                 member_->use_cuda_, &member_->nccl_ctxs_);
+                                 member_->use_cuda_, member_->nccl_ctxs_);
    for (size_t i = 1; i < member_->places_.size(); ++i) {
      graphs[i] = build_strategy.Apply(graphs[i], {member_->places_[i]},
                                       loss_var_name,
                                       {member_->local_scopes_[i]}, 1,
-                                      member_->use_cuda_, &member_->nccl_ctxs_);
+                                      member_->use_cuda_, member_->nccl_ctxs_);
       async_graphs[i] = graphs[i];
     }
   } else {
     graph = build_strategy.Apply(graph, member_->places_, loss_var_name,
                                  member_->local_scopes_, member_->nranks_,
-                                 member_->use_cuda_, &member_->nccl_ctxs_);
+                                 member_->use_cuda_, member_->nccl_ctxs_);
   }
 #else
   if (build_strategy.async_mode_) {
@@ -559,7 +557,7 @@ void ParallelExecutor::BCastParamsToDevices(
       PADDLE_ENFORCE_EQ(member_->places_.size(), buffers.size(),
                         "variables' buffer size to bcast NOT equal to places");
       {
-        auto *nccl_ctxs = member_->nccl_ctxs_.DefaultFlatCtx();
+        auto *nccl_ctxs = member_->nccl_ctxs_->DefaultFlatCtx();
         platform::NCCLGroupGuard guard;
         for (size_t i = 0; i < member_->places_.size(); ++i) {
           auto &nccl_ctx = nccl_ctxs->at(member_->places_[i]);
diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h
index 89a48b303dd6bf1c5a60e2baec2c69ff4dd3fc3b..6943fe62b915e0707dfe40ecbda90f61464338cf 100644
--- a/paddle/fluid/framework/parallel_executor.h
+++ b/paddle/fluid/framework/parallel_executor.h
@@ -87,13 +87,6 @@ class ParallelExecutor {
   ParallelExecutorPrivate *member_;
   std::vector<std::unique_ptr<ir::Graph>> async_graphs_;
-
-#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-  // used for compatible with syncbatch norm op
-  std::unique_ptr<platform::NCCLContextMap> dev_nccl_ctxs_;
-  platform::NCCLContextMap *GetNCCLContextForSyncbatchNomrOp(
-      framework::Scope *scope);
-#endif
 };
 
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/var_type_traits.cc b/paddle/fluid/framework/var_type_traits.cc
index a37b1fbab8cfd0642beaf725c02941002b2176b3..7cc2b3b42258942e6016486f7cf7ecfcae92b91c 100644
--- a/paddle/fluid/framework/var_type_traits.cc
+++ b/paddle/fluid/framework/var_type_traits.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/var_type_traits.h"
+#include
 #include "paddle/fluid/framework/lod_rank_table.h"
 #include "paddle/fluid/framework/reader.h"
 #include "paddle/fluid/framework/scope.h"
@@ -22,6 +23,7 @@
 #ifdef PADDLE_WITH_CUDA
 #ifndef _WIN32
 #include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
+#include "paddle/fluid/platform/nccl_helper.h"
 #endif
 #include <cudnn.h>
 #include "paddle/fluid/operators/conv_cudnn_op_cache.h"
diff --git a/paddle/fluid/framework/var_type_traits.h b/paddle/fluid/framework/var_type_traits.h
index fa77b96a7bdfa28ed982db022e8e5ecaef0b443c..7147f06233cb9d435d8be62814df0a3891b729fb 100644
--- a/paddle/fluid/framework/var_type_traits.h
+++ b/paddle/fluid/framework/var_type_traits.h
@@ -36,6 +36,7 @@ namespace platform {
 #ifdef PADDLE_WITH_CUDA
 #ifndef _WIN32
 class Communicator;
+class NCCLCommunicator;
 #endif
 #endif
 }  // namespace platform
@@ -140,7 +141,7 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
     std::map<size_t, Tensor>, operators::reader::LoDTensorBlockingQueueHolder,
 #ifdef PADDLE_WITH_CUDA
 #ifndef _WIN32
-    ncclUniqueId, platform::Communicator,
+    ncclUniqueId, platform::Communicator, platform::NCCLCommunicator,
 #endif
     operators::CudnnRNNCache,
 #endif
diff --git a/paddle/fluid/framework/var_type_traits_test.cc b/paddle/fluid/framework/var_type_traits_test.cc
index a47275e1ca25a4f66e67b4986ec78e49ea952a51..67dbfd740ed9b71fa06b684c14720ae2814fe11c 100644
--- a/paddle/fluid/framework/var_type_traits_test.cc
+++ b/paddle/fluid/framework/var_type_traits_test.cc
@@ -26,6 +26,7 @@
 #ifdef PADDLE_WITH_CUDA
 #ifndef _WIN32
 #include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
+#include "paddle/fluid/platform/nccl_helper.h"
 #endif
 #include "paddle/fluid/operators/conv_cudnn_op_cache.h"
 #include "paddle/fluid/operators/cudnn_rnn_cache.h"
diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h
index 18bc17f5c483a6c8907c53ff0c7bda38eebb566b..d79ff6e2b98a3fb3722198b67785b41a83fcb7cd 100644
--- a/paddle/fluid/platform/nccl_helper.h
+++ b/paddle/fluid/platform/nccl_helper.h
@@ -176,10 +176,10 @@ inline std::string GetHierarchicalInterNCCLVarName(size_t pos) {
                          static_cast<int>(pos));
 }
 
-class MultiNCCLContextMap {
+class NCCLCommunicator {
  public:
-  MultiNCCLContextMap() {}
-  virtual ~MultiNCCLContextMap() {}
+  NCCLCommunicator() {}
+  virtual ~NCCLCommunicator() {}
 
   NCCLContextMap *DefaultFlatCtx() const {
     if (flat_ctxs_.size() == 0) {
@@ -206,6 +206,25 @@ class MultiNCCLContextMap {
     return GetHierarchicalInterCtx(run_order);
   }
 
+  /*
+   *When nccl inits nccl comm using ncclCommInitAll, it meets error when
+   *allreduce ophandle and sync_batch_norm_op use ncclallreduce parallelly. So
+   *create a new nccl comm for sync_batch_norm_op. And these codes should be
+   *polished with a unified nccl management.
+   */
+  NCCLContextMap *GetSyncBatchNormCtx(
+      framework::Scope *scope, const std::vector<platform::Place> &places) {
+    auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
+    if (nccl_id_var != nullptr) {
+      return DefaultFlatCtx();
+    }
+
+    if (sync_batch_norm_ctx_.get() == nullptr) {
+      sync_batch_norm_ctx_.reset(new NCCLContextMap(places));
+    }
+    return sync_batch_norm_ctx_.get();
+  }
+
   void InitFlatCtxs(const std::vector<platform::Place> &places,
                     const std::vector<ncclUniqueId> &nccl_ids,
                     size_t trainers_num, size_t trainer_id) {
@@ -290,6 +309,9 @@ class MultiNCCLContextMap {
   // And h_exter_ctxs_ can support multi comm too.
   std::vector<std::unique_ptr<NCCLContextMap>> h_inter_ctxs_;
   std::vector<std::unique_ptr<NCCLContextMap>> h_exter_ctxs_;
+
+  // just used for sync_batch_norm op.
+  std::unique_ptr<NCCLContextMap> sync_batch_norm_ctx_;
 };
 
 }  // namespace platform
diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py
index 985215f9dc08c4ec8ea4f5410b72d24a0138df6d..6b88325d705cf2bade5362b69cf1c5d54f061967 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_base.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_base.py
@@ -167,6 +167,15 @@ class TestDistRunnerBase(object):
             build_strategy=build_stra,
             exec_strategy=exec_strategy)
 
+        if args.use_cuda and args.update_method == "nccl2":
+            # it just for test share_vars_from feature.
+            test_exe = fluid.ParallelExecutor(
+                use_cuda=True,
+                loss_name=avg_cost.name,
+                build_strategy=build_stra,
+                main_program=test_program,
+                share_vars_from=binary._executor)
+
         feed_var_list = [
             var for var in trainer_prog.global_block().vars.values()
             if var.is_data