From 0d561ef4426cac4c177959e766ed926b6ea77866 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Wed, 29 May 2019 14:07:21 +0800 Subject: [PATCH] fix 2dconn test=develop (#17681) --- paddle/fluid/framework/parallel_executor.cc | 14 +++++++---- .../distributed_ops/gen_nccl_id_op.cc | 21 +++++++++------- paddle/fluid/platform/nccl_helper.h | 24 +++++++++++-------- .../fluid/transpiler/distribute_transpiler.py | 8 +++---- 4 files changed, 41 insertions(+), 26 deletions(-) diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 43b8645fb..0667748c2 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -157,9 +157,14 @@ class ParallelExecutorPrivate { bst.trainer_id_); if (bst.use_hierarchical_allreduce_) { - std::string var_name = platform::GetHierarchicalInterNCCLVarName(); - auto nccl_id_var = scope->FindVar(var_name); - auto inter_nccl_id = nccl_id_var->GetMutable(); + std::vector inter_nccl_ids; + for (int i = 0; i < static_cast(bst.nccl_comm_num_); i++) { + std::string var_name = platform::GetHierarchicalInterNCCLVarName(i); + auto nccl_id_var = scope->FindVar(var_name); + PADDLE_ENFORCE(nccl_id_var, "can't find %s nccl_id_var", var_name); + auto inter_nccl_id = nccl_id_var->GetMutable(); + inter_nccl_ids.push_back(inter_nccl_id); + } std::vector exter_nccl_ids; for (int i = 0; i < static_cast(bst.nccl_comm_num_); i++) { @@ -169,7 +174,8 @@ class ParallelExecutorPrivate { auto nccl_id = nccl_id_var->GetMutable(); exter_nccl_ids.push_back(nccl_id); } - nccl_ctxs_.InitHierarchicalCtxs(places_, inter_nccl_id, exter_nccl_ids, + + nccl_ctxs_.InitHierarchicalCtxs(places_, inter_nccl_ids, exter_nccl_ids, bst.num_trainers_, bst.trainer_id_, bst.hierarchical_allreduce_inter_nranks_, bst.hierarchical_allreduce_exter_nranks_); diff --git a/paddle/fluid/operators/distributed_ops/gen_nccl_id_op.cc b/paddle/fluid/operators/distributed_ops/gen_nccl_id_op.cc index 2446550ff..c33842c06 100644 --- a/paddle/fluid/operators/distributed_ops/gen_nccl_id_op.cc +++ b/paddle/fluid/operators/distributed_ops/gen_nccl_id_op.cc @@ -124,8 +124,11 @@ class GenNCCLIdOp : public framework::OperatorBase { ss << trainers[i]; } VLOG(1) << "Hierarchical inter ring endpoints:" << ss.str(); - std::string nccl_var_name = platform::GetHierarchicalInterNCCLVarName(); - GenerateAndSend(&local_scope, dev_ctx, nccl_var_name, inter_endpoints); + for (int i = 0; i < nccl_comm_num; i++) { + std::string nccl_var_name = + platform::GetHierarchicalInterNCCLVarName(i); + GenerateAndSend(&local_scope, dev_ctx, nccl_var_name, inter_endpoints); + } } // hierarchical exter ncclid @@ -208,12 +211,14 @@ class GenNCCLIdOp : public framework::OperatorBase { if (use_hierarchical_allreduce) { if (inter_trainer_id > 0) { - rpc_service->SetCond(distributed::kRequestSend); - VLOG(3) << "trainer_id:" << trainer_id - << ", inter_trainer_id:" << inter_trainer_id - << " start getting nccl id from inter_trainer 0"; - rpc_service->WaitBarrier(distributed::kRequestSend); - rpc_service->ResetBarrierCounter(); + for (int i = 0; i < nccl_comm_num; i++) { + rpc_service->SetCond(distributed::kRequestSend); + VLOG(3) << "trainer_id:" << trainer_id + << ", inter_trainer_id:" << inter_trainer_id + << " start getting nccl id from inter_trainer:" << i; + rpc_service->WaitBarrier(distributed::kRequestSend); + rpc_service->ResetBarrierCounter(); + } } if (exter_trainer_id > 0) { diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h index 9eb617281..18bc17f5c 100644 --- a/paddle/fluid/platform/nccl_helper.h +++ b/paddle/fluid/platform/nccl_helper.h @@ -171,8 +171,9 @@ inline std::string GetHierarchicalExterNCCLVarName(size_t pos) { return string::Sprintf("Hierarchical_exter_%s_%d", NCCL_ID_VARNAME, static_cast(pos)); } -inline std::string GetHierarchicalInterNCCLVarName() { - return string::Sprintf("Hierarchical_inter_%s", NCCL_ID_VARNAME); +inline std::string GetHierarchicalInterNCCLVarName(size_t pos) { + return string::Sprintf("Hierarchical_inter_%s_%d", NCCL_ID_VARNAME, + static_cast(pos)); } class MultiNCCLContextMap { @@ -224,8 +225,8 @@ class MultiNCCLContextMap { } void InitHierarchicalCtxs(const std::vector &places, - ncclUniqueId *inter_nccl_id, - const std::vector &exter_nccl_id, + const std::vector &inter_nccl_ids, + const std::vector &exter_nccl_ids, size_t trainers_num, size_t trainer_id, size_t inter_trainers_num, size_t exter_trainers_num) { @@ -238,11 +239,14 @@ class MultiNCCLContextMap { inter_trainers_num); int inter_trainer_id = trainer_id % inter_trainers_num; - VLOG(1) << "init inter_trainer_id:" << inter_trainer_id; - auto local = new NCCLContextMap(places, inter_nccl_id, inter_trainers_num, - inter_trainer_id); + for (size_t i = 0; i < inter_nccl_ids.size(); i++) { + VLOG(1) << "init inter_trainer_id:" << inter_trainer_id + << ", comm no:" << i; + auto local = new NCCLContextMap(places, inter_nccl_ids[i], + inter_trainers_num, inter_trainer_id); - h_inter_ctxs_.emplace_back(local); + h_inter_ctxs_.emplace_back(local); + } int exter_trainer_id = -1; if (trainer_id % inter_trainers_num == 0) { @@ -250,8 +254,8 @@ class MultiNCCLContextMap { } if (exter_trainer_id >= 0) { - for (size_t i = 0; i < exter_nccl_id.size(); i++) { - auto ex = new NCCLContextMap(places, exter_nccl_id[i], + for (size_t i = 0; i < exter_nccl_ids.size(); i++) { + auto ex = new NCCLContextMap(places, exter_nccl_ids[i], exter_trainers_num, exter_trainer_id); VLOG(1) << "init exter_trainer_id:" << exter_trainer_id << ", comm no:" << i; diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index d7dde1f1c..1f08d0328 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -278,11 +278,11 @@ class DistributeTranspiler(object): type=core.VarDesc.VarType.RAW) if self.config.use_hierarchical_allreduce: - startup_program.global_block().create_var( - name="Hierarchical_inter_NCCLID", - persistable=True, - type=core.VarDesc.VarType.RAW) for i in range(0, self.config.nccl_comm_num): + startup_program.global_block().create_var( + name="Hierarchical_inter_NCCLID_{}".format(i), + persistable=True, + type=core.VarDesc.VarType.RAW) startup_program.global_block().create_var( name="Hierarchical_exter_NCCLID_{}".format(i), persistable=True, -- GitLab