未验证 提交 0d561ef4 编写于 作者: G gongweibao 提交者: GitHub

fix 2dconn test=develop (#17681)

上级 ccf9e232
...@@ -157,9 +157,14 @@ class ParallelExecutorPrivate { ...@@ -157,9 +157,14 @@ class ParallelExecutorPrivate {
bst.trainer_id_); bst.trainer_id_);
if (bst.use_hierarchical_allreduce_) { if (bst.use_hierarchical_allreduce_) {
std::string var_name = platform::GetHierarchicalInterNCCLVarName(); std::vector<ncclUniqueId *> inter_nccl_ids;
auto nccl_id_var = scope->FindVar(var_name); for (int i = 0; i < static_cast<int>(bst.nccl_comm_num_); i++) {
auto inter_nccl_id = nccl_id_var->GetMutable<ncclUniqueId>(); std::string var_name = platform::GetHierarchicalInterNCCLVarName(i);
auto nccl_id_var = scope->FindVar(var_name);
PADDLE_ENFORCE(nccl_id_var, "can't find %s nccl_id_var", var_name);
auto inter_nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
inter_nccl_ids.push_back(inter_nccl_id);
}
std::vector<ncclUniqueId *> exter_nccl_ids; std::vector<ncclUniqueId *> exter_nccl_ids;
for (int i = 0; i < static_cast<int>(bst.nccl_comm_num_); i++) { for (int i = 0; i < static_cast<int>(bst.nccl_comm_num_); i++) {
...@@ -169,7 +174,8 @@ class ParallelExecutorPrivate { ...@@ -169,7 +174,8 @@ class ParallelExecutorPrivate {
auto nccl_id = nccl_id_var->GetMutable<ncclUniqueId>(); auto nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
exter_nccl_ids.push_back(nccl_id); exter_nccl_ids.push_back(nccl_id);
} }
nccl_ctxs_.InitHierarchicalCtxs(places_, inter_nccl_id, exter_nccl_ids,
nccl_ctxs_.InitHierarchicalCtxs(places_, inter_nccl_ids, exter_nccl_ids,
bst.num_trainers_, bst.trainer_id_, bst.num_trainers_, bst.trainer_id_,
bst.hierarchical_allreduce_inter_nranks_, bst.hierarchical_allreduce_inter_nranks_,
bst.hierarchical_allreduce_exter_nranks_); bst.hierarchical_allreduce_exter_nranks_);
......
...@@ -124,8 +124,11 @@ class GenNCCLIdOp : public framework::OperatorBase { ...@@ -124,8 +124,11 @@ class GenNCCLIdOp : public framework::OperatorBase {
ss << trainers[i]; ss << trainers[i];
} }
VLOG(1) << "Hierarchical inter ring endpoints:" << ss.str(); VLOG(1) << "Hierarchical inter ring endpoints:" << ss.str();
std::string nccl_var_name = platform::GetHierarchicalInterNCCLVarName(); for (int i = 0; i < nccl_comm_num; i++) {
GenerateAndSend(&local_scope, dev_ctx, nccl_var_name, inter_endpoints); std::string nccl_var_name =
platform::GetHierarchicalInterNCCLVarName(i);
GenerateAndSend(&local_scope, dev_ctx, nccl_var_name, inter_endpoints);
}
} }
// hierarchical exter ncclid // hierarchical exter ncclid
...@@ -208,12 +211,14 @@ class GenNCCLIdOp : public framework::OperatorBase { ...@@ -208,12 +211,14 @@ class GenNCCLIdOp : public framework::OperatorBase {
if (use_hierarchical_allreduce) { if (use_hierarchical_allreduce) {
if (inter_trainer_id > 0) { if (inter_trainer_id > 0) {
rpc_service->SetCond(distributed::kRequestSend); for (int i = 0; i < nccl_comm_num; i++) {
VLOG(3) << "trainer_id:" << trainer_id rpc_service->SetCond(distributed::kRequestSend);
<< ", inter_trainer_id:" << inter_trainer_id VLOG(3) << "trainer_id:" << trainer_id
<< " start getting nccl id from inter_trainer 0"; << ", inter_trainer_id:" << inter_trainer_id
rpc_service->WaitBarrier(distributed::kRequestSend); << " start getting nccl id from inter_trainer:" << i;
rpc_service->ResetBarrierCounter(); rpc_service->WaitBarrier(distributed::kRequestSend);
rpc_service->ResetBarrierCounter();
}
} }
if (exter_trainer_id > 0) { if (exter_trainer_id > 0) {
......
...@@ -171,8 +171,9 @@ inline std::string GetHierarchicalExterNCCLVarName(size_t pos) { ...@@ -171,8 +171,9 @@ inline std::string GetHierarchicalExterNCCLVarName(size_t pos) {
return string::Sprintf("Hierarchical_exter_%s_%d", NCCL_ID_VARNAME, return string::Sprintf("Hierarchical_exter_%s_%d", NCCL_ID_VARNAME,
static_cast<int>(pos)); static_cast<int>(pos));
} }
inline std::string GetHierarchicalInterNCCLVarName() { inline std::string GetHierarchicalInterNCCLVarName(size_t pos) {
return string::Sprintf("Hierarchical_inter_%s", NCCL_ID_VARNAME); return string::Sprintf("Hierarchical_inter_%s_%d", NCCL_ID_VARNAME,
static_cast<int>(pos));
} }
class MultiNCCLContextMap { class MultiNCCLContextMap {
...@@ -224,8 +225,8 @@ class MultiNCCLContextMap { ...@@ -224,8 +225,8 @@ class MultiNCCLContextMap {
} }
void InitHierarchicalCtxs(const std::vector<platform::Place> &places, void InitHierarchicalCtxs(const std::vector<platform::Place> &places,
ncclUniqueId *inter_nccl_id, const std::vector<ncclUniqueId *> &inter_nccl_ids,
const std::vector<ncclUniqueId *> &exter_nccl_id, const std::vector<ncclUniqueId *> &exter_nccl_ids,
size_t trainers_num, size_t trainer_id, size_t trainers_num, size_t trainer_id,
size_t inter_trainers_num, size_t inter_trainers_num,
size_t exter_trainers_num) { size_t exter_trainers_num) {
...@@ -238,11 +239,14 @@ class MultiNCCLContextMap { ...@@ -238,11 +239,14 @@ class MultiNCCLContextMap {
inter_trainers_num); inter_trainers_num);
int inter_trainer_id = trainer_id % inter_trainers_num; int inter_trainer_id = trainer_id % inter_trainers_num;
VLOG(1) << "init inter_trainer_id:" << inter_trainer_id; for (size_t i = 0; i < inter_nccl_ids.size(); i++) {
auto local = new NCCLContextMap(places, inter_nccl_id, inter_trainers_num, VLOG(1) << "init inter_trainer_id:" << inter_trainer_id
inter_trainer_id); << ", comm no:" << i;
auto local = new NCCLContextMap(places, inter_nccl_ids[i],
inter_trainers_num, inter_trainer_id);
h_inter_ctxs_.emplace_back(local); h_inter_ctxs_.emplace_back(local);
}
int exter_trainer_id = -1; int exter_trainer_id = -1;
if (trainer_id % inter_trainers_num == 0) { if (trainer_id % inter_trainers_num == 0) {
...@@ -250,8 +254,8 @@ class MultiNCCLContextMap { ...@@ -250,8 +254,8 @@ class MultiNCCLContextMap {
} }
if (exter_trainer_id >= 0) { if (exter_trainer_id >= 0) {
for (size_t i = 0; i < exter_nccl_id.size(); i++) { for (size_t i = 0; i < exter_nccl_ids.size(); i++) {
auto ex = new NCCLContextMap(places, exter_nccl_id[i], auto ex = new NCCLContextMap(places, exter_nccl_ids[i],
exter_trainers_num, exter_trainer_id); exter_trainers_num, exter_trainer_id);
VLOG(1) << "init exter_trainer_id:" << exter_trainer_id VLOG(1) << "init exter_trainer_id:" << exter_trainer_id
<< ", comm no:" << i; << ", comm no:" << i;
......
...@@ -278,11 +278,11 @@ class DistributeTranspiler(object): ...@@ -278,11 +278,11 @@ class DistributeTranspiler(object):
type=core.VarDesc.VarType.RAW) type=core.VarDesc.VarType.RAW)
if self.config.use_hierarchical_allreduce: if self.config.use_hierarchical_allreduce:
startup_program.global_block().create_var(
name="Hierarchical_inter_NCCLID",
persistable=True,
type=core.VarDesc.VarType.RAW)
for i in range(0, self.config.nccl_comm_num): for i in range(0, self.config.nccl_comm_num):
startup_program.global_block().create_var(
name="Hierarchical_inter_NCCLID_{}".format(i),
persistable=True,
type=core.VarDesc.VarType.RAW)
startup_program.global_block().create_var( startup_program.global_block().create_var(
name="Hierarchical_exter_NCCLID_{}".format(i), name="Hierarchical_exter_NCCLID_{}".format(i),
persistable=True, persistable=True,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册