未验证 提交 f5caf344 编写于 作者: G gongweibao 提交者: GitHub

Fix reinitialized ncclid error! (#18025)

上级 354643d8
...@@ -35,7 +35,7 @@ namespace details { ...@@ -35,7 +35,7 @@ namespace details {
AllReduceOpHandle::AllReduceOpHandle(ir::Node *node, AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const platform::MultiNCCLContextMap *ctxs) const platform::NCCLCommunicator *ctxs)
: NCCLOpHandleBase(node, places, ctxs), local_scopes_(local_scopes) { : NCCLOpHandleBase(node, places, ctxs), local_scopes_(local_scopes) {
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size()); PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
} }
......
...@@ -34,7 +34,7 @@ class AllReduceOpHandle : public NCCLOpHandleBase { ...@@ -34,7 +34,7 @@ class AllReduceOpHandle : public NCCLOpHandleBase {
public: public:
AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes, AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const platform::MultiNCCLContextMap *ctxs); const platform::NCCLCommunicator *ctxs);
#else #else
class AllReduceOpHandle : public OpHandleBase { class AllReduceOpHandle : public OpHandleBase {
public: public:
......
...@@ -266,14 +266,16 @@ bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const { ...@@ -266,14 +266,16 @@ bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const {
return framework::ir::MultiDevSSAGraphBuilder().count(pass_name) > 0; return framework::ir::MultiDevSSAGraphBuilder().count(pass_name) > 0;
} }
ir::Graph *BuildStrategy::Apply( ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
ir::Graph *graph, const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const std::string &loss_var_name, const std::vector<Scope *> &local_scopes, const std::string &loss_var_name,
const size_t &nranks, const std::vector<Scope *> &local_scopes,
const size_t &nranks,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const bool use_cuda, platform::MultiNCCLContextMap *nccl_ctxs) const { const bool use_cuda,
platform::NCCLCommunicator *nccl_ctxs) const {
#else #else
const bool use_cuda) const { const bool use_cuda) const {
#endif #endif
VLOG(3) << "apply all passes"; VLOG(3) << "apply all passes";
// Create a default one if not finalized by user. // Create a default one if not finalized by user.
...@@ -293,9 +295,9 @@ ir::Graph *BuildStrategy::Apply( ...@@ -293,9 +295,9 @@ ir::Graph *BuildStrategy::Apply(
pass->Set<size_t>(ir::kNRanks, new size_t(nranks)); pass->Set<size_t>(ir::kNRanks, new size_t(nranks));
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::MultiNCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr; platform::NCCLCommunicator *nctx = use_cuda ? nccl_ctxs : nullptr;
pass->Erase(kNCCLCtxs); pass->Erase(kNCCLCtxs);
pass->SetNotOwned<platform::MultiNCCLContextMap>(kNCCLCtxs, nctx); pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx);
#endif #endif
} else if (pass->Type() == "alloc_continuous_space_for_grad_pass" || } else if (pass->Type() == "alloc_continuous_space_for_grad_pass" ||
pass->Type() == "fuse_adam_op_pass" || pass->Type() == "fuse_adam_op_pass" ||
...@@ -309,9 +311,9 @@ ir::Graph *BuildStrategy::Apply( ...@@ -309,9 +311,9 @@ ir::Graph *BuildStrategy::Apply(
&local_scopes); &local_scopes);
if (pass->Type() == "fuse_all_reduce_op_pass") { if (pass->Type() == "fuse_all_reduce_op_pass") {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::MultiNCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr; platform::NCCLCommunicator *nctx = use_cuda ? nccl_ctxs : nullptr;
pass->Erase(kNCCLCtxs); pass->Erase(kNCCLCtxs);
pass->SetNotOwned<platform::MultiNCCLContextMap>(kNCCLCtxs, nctx); pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx);
pass->Erase(kUseHierarchicalAllReduce); pass->Erase(kUseHierarchicalAllReduce);
pass->Set<bool>(kUseHierarchicalAllReduce, pass->Set<bool>(kUseHierarchicalAllReduce,
new bool(use_hierarchical_allreduce_)); new bool(use_hierarchical_allreduce_));
...@@ -328,9 +330,9 @@ ir::Graph *BuildStrategy::Apply( ...@@ -328,9 +330,9 @@ ir::Graph *BuildStrategy::Apply(
<< enable_sequential_execution_; << enable_sequential_execution_;
} else if (pass->Type() == "all_reduce_deps_pass") { } else if (pass->Type() == "all_reduce_deps_pass") {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::MultiNCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr; platform::NCCLCommunicator *nctx = use_cuda ? nccl_ctxs : nullptr;
pass->Erase(kNCCLCtxs); pass->Erase(kNCCLCtxs);
pass->SetNotOwned<platform::MultiNCCLContextMap>(kNCCLCtxs, nctx); pass->SetNotOwned<platform::NCCLCommunicator>(kNCCLCtxs, nctx);
pass->Erase(kUseHierarchicalAllReduce); pass->Erase(kUseHierarchicalAllReduce);
pass->Set<bool>(kUseHierarchicalAllReduce, pass->Set<bool>(kUseHierarchicalAllReduce,
new bool(use_hierarchical_allreduce_)); new bool(use_hierarchical_allreduce_));
......
...@@ -149,7 +149,7 @@ struct BuildStrategy { ...@@ -149,7 +149,7 @@ struct BuildStrategy {
const size_t &nranks, const size_t &nranks,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const bool use_cuda, const bool use_cuda,
platform::MultiNCCLContextMap *nccl_ctxs) const; platform::NCCLCommunicator *nccl_ctxs) const;
#else #else
const bool use_cuda) const; const bool use_cuda) const;
#endif #endif
......
...@@ -44,7 +44,7 @@ typedef std::vector<std::vector<std::pair<std::string, const LoDTensor *>>> ...@@ -44,7 +44,7 @@ typedef std::vector<std::vector<std::pair<std::string, const LoDTensor *>>>
FusedAllReduceOpHandle::FusedAllReduceOpHandle( FusedAllReduceOpHandle::FusedAllReduceOpHandle(
ir::Node *node, const std::vector<Scope *> &local_scopes, ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const size_t num_of_all_reduce, const std::vector<platform::Place> &places, const size_t num_of_all_reduce,
const platform::MultiNCCLContextMap *ctxs) const platform::NCCLCommunicator *ctxs)
: NCCLOpHandleBase(node, places, ctxs), : NCCLOpHandleBase(node, places, ctxs),
local_scopes_(local_scopes), local_scopes_(local_scopes),
num_of_all_reduce_(num_of_all_reduce) { num_of_all_reduce_(num_of_all_reduce) {
......
...@@ -35,7 +35,7 @@ struct FusedAllReduceOpHandle : public NCCLOpHandleBase { ...@@ -35,7 +35,7 @@ struct FusedAllReduceOpHandle : public NCCLOpHandleBase {
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const size_t num_of_all_reduce, const size_t num_of_all_reduce,
const platform::MultiNCCLContextMap *ctxs); const platform::NCCLCommunicator *ctxs);
#else #else
struct FusedAllReduceOpHandle : public OpHandleBase { struct FusedAllReduceOpHandle : public OpHandleBase {
FusedAllReduceOpHandle(ir::Node *node, FusedAllReduceOpHandle(ir::Node *node,
......
...@@ -33,7 +33,7 @@ namespace details { ...@@ -33,7 +33,7 @@ namespace details {
class NCCLOpHandleBase : public OpHandleBase { class NCCLOpHandleBase : public OpHandleBase {
public: public:
NCCLOpHandleBase(ir::Node* node, const std::vector<platform::Place>& places, NCCLOpHandleBase(ir::Node* node, const std::vector<platform::Place>& places,
const platform::MultiNCCLContextMap* nccl_ctxs) const platform::NCCLCommunicator* nccl_ctxs)
: OpHandleBase(node), places_(places), nccl_ctxs_(nccl_ctxs) { : OpHandleBase(node), places_(places), nccl_ctxs_(nccl_ctxs) {
if (nccl_ctxs == nullptr) { if (nccl_ctxs == nullptr) {
return; return;
...@@ -215,7 +215,7 @@ class NCCLOpHandleBase : public OpHandleBase { ...@@ -215,7 +215,7 @@ class NCCLOpHandleBase : public OpHandleBase {
protected: protected:
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
const platform::MultiNCCLContextMap* nccl_ctxs_{nullptr}; const platform::NCCLCommunicator* nccl_ctxs_{nullptr};
// When multi trainer call collective function, they need run the same order. // When multi trainer call collective function, they need run the same order.
// Or the program will hang.So we use allreduce_deps_pass to set this // Or the program will hang.So we use allreduce_deps_pass to set this
// run_order_. // run_order_.
......
...@@ -30,7 +30,7 @@ namespace details { ...@@ -30,7 +30,7 @@ namespace details {
SparseAllReduceOpHandle::SparseAllReduceOpHandle( SparseAllReduceOpHandle::SparseAllReduceOpHandle(
ir::Node *node, const std::vector<Scope *> &local_scopes, ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const platform::MultiNCCLContextMap *ctxs, bool is_encoded, int nranks) const platform::NCCLCommunicator *ctxs, bool is_encoded, int nranks)
: AllReduceOpHandle(node, local_scopes, places, ctxs), : AllReduceOpHandle(node, local_scopes, places, ctxs),
is_encoded_(is_encoded), is_encoded_(is_encoded),
nranks_(nranks) { nranks_(nranks) {
......
...@@ -32,7 +32,7 @@ class SparseAllReduceOpHandle : public AllReduceOpHandle { ...@@ -32,7 +32,7 @@ class SparseAllReduceOpHandle : public AllReduceOpHandle {
SparseAllReduceOpHandle(ir::Node *node, SparseAllReduceOpHandle(ir::Node *node,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const platform::MultiNCCLContextMap *ctxs, const platform::NCCLCommunicator *ctxs,
bool is_encoded = false, int nranks = -1); bool is_encoded = false, int nranks = -1);
std::string Name() const override; std::string Name() const override;
......
...@@ -35,7 +35,7 @@ class FuseAllReduceOpPass : public ir::Pass { ...@@ -35,7 +35,7 @@ class FuseAllReduceOpPass : public ir::Pass {
auto &local_scopes = Get<const std::vector<Scope *>>(details::kLocalScopes); auto &local_scopes = Get<const std::vector<Scope *>>(details::kLocalScopes);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto *multi_nccl_ctxs = auto *multi_nccl_ctxs =
&Get<platform::MultiNCCLContextMap>(details::kNCCLCtxs); &Get<platform::NCCLCommunicator>(details::kNCCLCtxs);
#endif #endif
std::unordered_set<std::string> grads; std::unordered_set<std::string> grads;
...@@ -103,14 +103,14 @@ class FuseAllReduceOpPass : public ir::Pass { ...@@ -103,14 +103,14 @@ class FuseAllReduceOpPass : public ir::Pass {
} }
} }
void InsertFusedAllReduce( void InsertFusedAllReduce(const std::vector<platform::Place> &places,
const std::vector<platform::Place> &places, const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_scopes, const size_t num_of_all_reduce, const size_t num_of_all_reduce,
const std::vector<ir::Node *> &all_reduce_ops, const std::vector<ir::Node *> &all_reduce_ops,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const platform::MultiNCCLContextMap *multi_nccl_ctxs, const platform::NCCLCommunicator *multi_nccl_ctxs,
#endif #endif
ir::Graph *result) const { ir::Graph *result) const {
std::vector<details::VarHandleBase *> inputs; std::vector<details::VarHandleBase *> inputs;
std::vector<details::VarHandleBase *> outputs; std::vector<details::VarHandleBase *> outputs;
for (auto &op : all_reduce_ops) { for (auto &op : all_reduce_ops) {
...@@ -151,7 +151,7 @@ class FuseAllReduceOpPass : public ir::Pass { ...@@ -151,7 +151,7 @@ class FuseAllReduceOpPass : public ir::Pass {
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const platform::MultiNCCLContextMap *multi_nccl_ctxs, const platform::NCCLCommunicator *multi_nccl_ctxs,
#endif #endif
ir::Graph *result) const { ir::Graph *result) const {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
......
...@@ -157,7 +157,7 @@ void MultiDevSSAGraphBuilderBase::Init() const { ...@@ -157,7 +157,7 @@ void MultiDevSSAGraphBuilderBase::Init() const {
local_scopes_ = Get<const std::vector<Scope *>>(details::kLocalScopes); local_scopes_ = Get<const std::vector<Scope *>>(details::kLocalScopes);
strategy_ = Get<const details::BuildStrategy>(kStrategy); strategy_ = Get<const details::BuildStrategy>(kStrategy);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
multi_nccl_ctxs_ = &Get<platform::MultiNCCLContextMap>(details::kNCCLCtxs); multi_nccl_ctxs_ = &Get<platform::NCCLCommunicator>(details::kNCCLCtxs);
nccl_ctxs_ = nullptr; nccl_ctxs_ = nullptr;
if (multi_nccl_ctxs_) { if (multi_nccl_ctxs_) {
nccl_ctxs_ = multi_nccl_ctxs_->DefaultFlatCtx(); nccl_ctxs_ = multi_nccl_ctxs_->DefaultFlatCtx();
......
...@@ -97,7 +97,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass { ...@@ -97,7 +97,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
mutable platform::NCCLContextMap *nccl_ctxs_{nullptr}; mutable platform::NCCLContextMap *nccl_ctxs_{nullptr};
mutable platform::MultiNCCLContextMap *multi_nccl_ctxs_{nullptr}; mutable platform::NCCLCommunicator *multi_nccl_ctxs_{nullptr};
#endif #endif
mutable std::string loss_var_name_; mutable std::string loss_var_name_;
......
...@@ -111,8 +111,8 @@ class ParallelExecutorPrivate { ...@@ -111,8 +111,8 @@ class ParallelExecutorPrivate {
std::vector<ncclUniqueId *> flat_nccl_ids; std::vector<ncclUniqueId *> flat_nccl_ids;
if (nranks_ == 1) { if (nranks_ == 1) {
// FIXME(gongwb): need not to create ncclid when nranks==1 // FIXME(gongwb): need not to create ncclid when nranks==1
nccl_ctxs_.InitFlatCtxs(places_, flat_nccl_ids, bst.num_trainers_, nccl_ctxs_->InitFlatCtxs(places_, flat_nccl_ids, bst.num_trainers_,
bst.trainer_id_); bst.trainer_id_);
return; return;
} }
...@@ -132,16 +132,16 @@ class ParallelExecutorPrivate { ...@@ -132,16 +132,16 @@ class ParallelExecutorPrivate {
flat_nccl_ids.push_back(nccl_id); flat_nccl_ids.push_back(nccl_id);
nccl_ctxs_.InitFlatCtxs(places_, flat_nccl_ids, bst.num_trainers_, nccl_ctxs_->InitFlatCtxs(places_, flat_nccl_ids, bst.num_trainers_,
bst.trainer_id_); bst.trainer_id_);
VLOG(1) << "init bst nccl context complete!"; VLOG(1) << "init bst nccl context complete!";
return; return;
} }
// num_trainers ==1 && places > 1 // num_trainers ==1 && places > 1
if (bst.num_trainers_ == 1) { if (bst.num_trainers_ == 1) {
nccl_ctxs_.InitFlatCtxs(places_, flat_nccl_ids, bst.num_trainers_, nccl_ctxs_->InitFlatCtxs(places_, flat_nccl_ids, bst.num_trainers_,
bst.trainer_id_); bst.trainer_id_);
return; return;
} }
...@@ -153,8 +153,8 @@ class ParallelExecutorPrivate { ...@@ -153,8 +153,8 @@ class ParallelExecutorPrivate {
flat_nccl_ids.push_back(nccl_id); flat_nccl_ids.push_back(nccl_id);
} }
nccl_ctxs_.InitFlatCtxs(places_, flat_nccl_ids, bst.num_trainers_, nccl_ctxs_->InitFlatCtxs(places_, flat_nccl_ids, bst.num_trainers_,
bst.trainer_id_); bst.trainer_id_);
if (bst.use_hierarchical_allreduce_) { if (bst.use_hierarchical_allreduce_) {
std::vector<ncclUniqueId *> inter_nccl_ids; std::vector<ncclUniqueId *> inter_nccl_ids;
...@@ -175,12 +175,30 @@ class ParallelExecutorPrivate { ...@@ -175,12 +175,30 @@ class ParallelExecutorPrivate {
exter_nccl_ids.push_back(nccl_id); exter_nccl_ids.push_back(nccl_id);
} }
nccl_ctxs_.InitHierarchicalCtxs(places_, inter_nccl_ids, exter_nccl_ids, nccl_ctxs_->InitHierarchicalCtxs(
bst.num_trainers_, bst.trainer_id_, places_, inter_nccl_ids, exter_nccl_ids, bst.num_trainers_,
bst.hierarchical_allreduce_inter_nranks_, bst.trainer_id_, bst.hierarchical_allreduce_inter_nranks_,
bst.hierarchical_allreduce_exter_nranks_); bst.hierarchical_allreduce_exter_nranks_);
} }
} }
void InitOrGetNCCLCommunicator(framework::Scope *scope,
const BuildStrategy &bst) {
const std::string var_name = "NCCLCommunicator";
auto var = scope->FindVar(var_name);
if (var != nullptr) {
PADDLE_ENFORCE(var->IsInitialized(),
"if %s exists, it must be initialized", var_name);
VLOG(1) << "find " << var_name
<< " in scope, so use it and does not recreate!";
nccl_ctxs_ = var->GetMutable<platform::NCCLCommunicator>();
return;
}
VLOG(1) << "not find " << var_name << " in scope, so recreate it!";
nccl_ctxs_ = scope->Var(var_name)->GetMutable<platform::NCCLCommunicator>();
InitNCCLCtxs(scope, bst);
}
#endif #endif
BuildStrategy build_strategy_; BuildStrategy build_strategy_;
...@@ -190,7 +208,7 @@ class ParallelExecutorPrivate { ...@@ -190,7 +208,7 @@ class ParallelExecutorPrivate {
std::unique_ptr<details::SSAGraphExecutor> executor_; std::unique_ptr<details::SSAGraphExecutor> executor_;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::MultiNCCLContextMap nccl_ctxs_; platform::NCCLCommunicator *nccl_ctxs_{nullptr};
#endif #endif
bool own_local_scope_; bool own_local_scope_;
bool use_cuda_; bool use_cuda_;
...@@ -281,27 +299,6 @@ bool ParallelExecutor::NeedCreateLocalExeScope() { ...@@ -281,27 +299,6 @@ bool ParallelExecutor::NeedCreateLocalExeScope() {
return executor && executor->NeedCreateLocalExeScope(); return executor && executor->NeedCreateLocalExeScope();
} }
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
/*
* When nccl inits nccl comm using ncclCommInitAll, it meets error when
* allreduce ophandle and sync_batch_norm_op use ncclallreduce parallelly. So
* create a new nccl comm for sync_batch_norm_op. And these codes should be
* polished with a unified nccl management.
*/
platform::NCCLContextMap *ParallelExecutor::GetNCCLContextForSyncbatchNomrOp(
framework::Scope *scope) {
auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
if (nccl_id_var != nullptr) {
return member_->nccl_ctxs_.DefaultFlatCtx();
}
if (dev_nccl_ctxs_.get() == nullptr) {
dev_nccl_ctxs_.reset(new platform::NCCLContextMap(member_->places_));
}
return dev_nccl_ctxs_.get();
}
#endif
ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places, ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
const std::vector<std::string> &bcast_vars, const std::vector<std::string> &bcast_vars,
const std::string &loss_var_name, const std::string &loss_var_name,
...@@ -375,7 +372,7 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places, ...@@ -375,7 +372,7 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
if (member_->use_cuda_) { if (member_->use_cuda_) {
// Bcast Parameters to all GPUs // Bcast Parameters to all GPUs
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
member_->InitNCCLCtxs(scope, build_strategy); member_->InitOrGetNCCLCommunicator(scope, build_strategy);
// Initialize device context's nccl comm, will be used by normal // Initialize device context's nccl comm, will be used by normal
// Operators like sync_batch_norm, and collective ops. // Operators like sync_batch_norm, and collective ops.
...@@ -384,7 +381,8 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places, ...@@ -384,7 +381,8 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
// NOTE: NCCL group-calls and non-group-calls can not use the same // NOTE: NCCL group-calls and non-group-calls can not use the same
// NCCL communicator, so for ParallelGraph and Multi-Process mode, re-use // NCCL communicator, so for ParallelGraph and Multi-Process mode, re-use
// same communicators. // same communicators.
auto *nccl_ctxs = GetNCCLContextForSyncbatchNomrOp(scope); auto *nccl_ctxs =
member_->nccl_ctxs_->GetSyncBatchNormCtx(scope, member_->places_);
for (size_t dev_id = 0; dev_id < member_->places_.size(); ++dev_id) { for (size_t dev_id = 0; dev_id < member_->places_.size(); ++dev_id) {
platform::DeviceContextPool &pool = platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance(); platform::DeviceContextPool::Instance();
...@@ -421,18 +419,18 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places, ...@@ -421,18 +419,18 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
VLOG(3) << "use local async mode"; VLOG(3) << "use local async mode";
graph = build_strategy.Apply(graph, {member_->places_[0]}, loss_var_name, graph = build_strategy.Apply(graph, {member_->places_[0]}, loss_var_name,
{member_->local_scopes_[0]}, 1, {member_->local_scopes_[0]}, 1,
member_->use_cuda_, &member_->nccl_ctxs_); member_->use_cuda_, member_->nccl_ctxs_);
for (size_t i = 1; i < member_->places_.size(); ++i) { for (size_t i = 1; i < member_->places_.size(); ++i) {
graphs[i] = graphs[i] =
build_strategy.Apply(graphs[i], {member_->places_[i]}, loss_var_name, build_strategy.Apply(graphs[i], {member_->places_[i]}, loss_var_name,
{member_->local_scopes_[i]}, 1, {member_->local_scopes_[i]}, 1,
member_->use_cuda_, &member_->nccl_ctxs_); member_->use_cuda_, member_->nccl_ctxs_);
async_graphs[i] = graphs[i]; async_graphs[i] = graphs[i];
} }
} else { } else {
graph = build_strategy.Apply(graph, member_->places_, loss_var_name, graph = build_strategy.Apply(graph, member_->places_, loss_var_name,
member_->local_scopes_, member_->nranks_, member_->local_scopes_, member_->nranks_,
member_->use_cuda_, &member_->nccl_ctxs_); member_->use_cuda_, member_->nccl_ctxs_);
} }
#else #else
if (build_strategy.async_mode_) { if (build_strategy.async_mode_) {
...@@ -565,7 +563,7 @@ void ParallelExecutor::BCastParamsToDevices( ...@@ -565,7 +563,7 @@ void ParallelExecutor::BCastParamsToDevices(
PADDLE_ENFORCE_EQ(member_->places_.size(), buffers.size(), PADDLE_ENFORCE_EQ(member_->places_.size(), buffers.size(),
"variables' buffer size to bcast NOT equal to places"); "variables' buffer size to bcast NOT equal to places");
{ {
auto *nccl_ctxs = member_->nccl_ctxs_.DefaultFlatCtx(); auto *nccl_ctxs = member_->nccl_ctxs_->DefaultFlatCtx();
platform::NCCLGroupGuard guard; platform::NCCLGroupGuard guard;
for (size_t i = 0; i < member_->places_.size(); ++i) { for (size_t i = 0; i < member_->places_.size(); ++i) {
auto &nccl_ctx = nccl_ctxs->at(member_->places_[i]); auto &nccl_ctx = nccl_ctxs->at(member_->places_[i]);
......
...@@ -87,13 +87,6 @@ class ParallelExecutor { ...@@ -87,13 +87,6 @@ class ParallelExecutor {
ParallelExecutorPrivate *member_; ParallelExecutorPrivate *member_;
std::vector<std::unique_ptr<ir::Graph>> async_graphs_; std::vector<std::unique_ptr<ir::Graph>> async_graphs_;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
// used for compatible with syncbatch norm op
std::unique_ptr<platform::NCCLContextMap> dev_nccl_ctxs_;
platform::NCCLContextMap *GetNCCLContextForSyncbatchNomrOp(
framework::Scope *scope);
#endif
}; };
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/var_type_traits.h" #include "paddle/fluid/framework/var_type_traits.h"
#include <unordered_map>
#include "paddle/fluid/framework/lod_rank_table.h" #include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
...@@ -22,6 +23,7 @@ ...@@ -22,6 +23,7 @@
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#ifndef _WIN32 #ifndef _WIN32
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h" #include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif #endif
#include <cudnn.h> #include <cudnn.h>
#include "paddle/fluid/operators/conv_cudnn_op_cache.h" #include "paddle/fluid/operators/conv_cudnn_op_cache.h"
......
...@@ -36,6 +36,7 @@ namespace platform { ...@@ -36,6 +36,7 @@ namespace platform {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#ifndef _WIN32 #ifndef _WIN32
class Communicator; class Communicator;
class NCCLCommunicator;
#endif #endif
#endif #endif
} // namespace platform } // namespace platform
...@@ -140,7 +141,7 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl< ...@@ -140,7 +141,7 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
std::map<size_t, Tensor>, operators::reader::LoDTensorBlockingQueueHolder, std::map<size_t, Tensor>, operators::reader::LoDTensorBlockingQueueHolder,
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#ifndef _WIN32 #ifndef _WIN32
ncclUniqueId, platform::Communicator, ncclUniqueId, platform::Communicator, platform::NCCLCommunicator,
#endif #endif
operators::CudnnRNNCache, operators::CudnnRNNCache,
#endif #endif
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#ifndef _WIN32 #ifndef _WIN32
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h" #include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif #endif
#include "paddle/fluid/operators/conv_cudnn_op_cache.h" #include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/operators/cudnn_rnn_cache.h" #include "paddle/fluid/operators/cudnn_rnn_cache.h"
......
...@@ -176,10 +176,10 @@ inline std::string GetHierarchicalInterNCCLVarName(size_t pos) { ...@@ -176,10 +176,10 @@ inline std::string GetHierarchicalInterNCCLVarName(size_t pos) {
static_cast<int>(pos)); static_cast<int>(pos));
} }
class MultiNCCLContextMap { class NCCLCommunicator {
public: public:
MultiNCCLContextMap() {} NCCLCommunicator() {}
virtual ~MultiNCCLContextMap() {} virtual ~NCCLCommunicator() {}
NCCLContextMap *DefaultFlatCtx() const { NCCLContextMap *DefaultFlatCtx() const {
if (flat_ctxs_.size() == 0) { if (flat_ctxs_.size() == 0) {
...@@ -206,6 +206,25 @@ class MultiNCCLContextMap { ...@@ -206,6 +206,25 @@ class MultiNCCLContextMap {
return GetHierarchicalInterCtx(run_order); return GetHierarchicalInterCtx(run_order);
} }
/*
*When nccl inits nccl comm using ncclCommInitAll, it meets error when
*allreduce ophandle and sync_batch_norm_op use ncclallreduce parallelly. So
*create a new nccl comm for sync_batch_norm_op. And these codes should be
*polished with a unified nccl management.
*/
NCCLContextMap *GetSyncBatchNormCtx(
framework::Scope *scope, const std::vector<platform::Place> &places) {
auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
if (nccl_id_var != nullptr) {
return DefaultFlatCtx();
}
if (sync_batch_norm_ctx_.get() == nullptr) {
sync_batch_norm_ctx_.reset(new NCCLContextMap(places));
}
return sync_batch_norm_ctx_.get();
}
void InitFlatCtxs(const std::vector<platform::Place> &places, void InitFlatCtxs(const std::vector<platform::Place> &places,
const std::vector<ncclUniqueId *> &nccl_ids, const std::vector<ncclUniqueId *> &nccl_ids,
size_t trainers_num, size_t trainer_id) { size_t trainers_num, size_t trainer_id) {
...@@ -290,6 +309,9 @@ class MultiNCCLContextMap { ...@@ -290,6 +309,9 @@ class MultiNCCLContextMap {
// And h_exter_ctxs_ can support multi comm too. // And h_exter_ctxs_ can support multi comm too.
std::vector<std::unique_ptr<NCCLContextMap>> h_inter_ctxs_; std::vector<std::unique_ptr<NCCLContextMap>> h_inter_ctxs_;
std::vector<std::unique_ptr<NCCLContextMap>> h_exter_ctxs_; std::vector<std::unique_ptr<NCCLContextMap>> h_exter_ctxs_;
// just used for sync_batch_norm op.
std::unique_ptr<NCCLContextMap> sync_batch_norm_ctx_;
}; };
} // namespace platform } // namespace platform
......
...@@ -167,6 +167,15 @@ class TestDistRunnerBase(object): ...@@ -167,6 +167,15 @@ class TestDistRunnerBase(object):
build_strategy=build_stra, build_strategy=build_stra,
exec_strategy=exec_strategy) exec_strategy=exec_strategy)
if args.use_cuda and args.update_method == "nccl2":
# it just for test share_vars_from feature.
test_exe = fluid.ParallelExecutor(
use_cuda=True,
loss_name=avg_cost.name,
build_strategy=build_stra,
main_program=test_program,
share_vars_from=binary._executor)
feed_var_list = [ feed_var_list = [
var for var in trainer_prog.global_block().vars.values() var for var in trainer_prog.global_block().vars.values()
if var.is_data if var.is_data
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册