未验证 提交 0d561ef4 编写于 作者: G gongweibao 提交者: GitHub

fix 2dconn test=develop (#17681)

上级 ccf9e232
......@@ -157,9 +157,14 @@ class ParallelExecutorPrivate {
bst.trainer_id_);
if (bst.use_hierarchical_allreduce_) {
std::string var_name = platform::GetHierarchicalInterNCCLVarName();
std::vector<ncclUniqueId *> inter_nccl_ids;
for (int i = 0; i < static_cast<int>(bst.nccl_comm_num_); i++) {
std::string var_name = platform::GetHierarchicalInterNCCLVarName(i);
auto nccl_id_var = scope->FindVar(var_name);
PADDLE_ENFORCE(nccl_id_var, "can't find %s nccl_id_var", var_name);
auto inter_nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
inter_nccl_ids.push_back(inter_nccl_id);
}
std::vector<ncclUniqueId *> exter_nccl_ids;
for (int i = 0; i < static_cast<int>(bst.nccl_comm_num_); i++) {
......@@ -169,7 +174,8 @@ class ParallelExecutorPrivate {
auto nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
exter_nccl_ids.push_back(nccl_id);
}
nccl_ctxs_.InitHierarchicalCtxs(places_, inter_nccl_id, exter_nccl_ids,
nccl_ctxs_.InitHierarchicalCtxs(places_, inter_nccl_ids, exter_nccl_ids,
bst.num_trainers_, bst.trainer_id_,
bst.hierarchical_allreduce_inter_nranks_,
bst.hierarchical_allreduce_exter_nranks_);
......
......@@ -124,9 +124,12 @@ class GenNCCLIdOp : public framework::OperatorBase {
ss << trainers[i];
}
VLOG(1) << "Hierarchical inter ring endpoints:" << ss.str();
std::string nccl_var_name = platform::GetHierarchicalInterNCCLVarName();
for (int i = 0; i < nccl_comm_num; i++) {
std::string nccl_var_name =
platform::GetHierarchicalInterNCCLVarName(i);
GenerateAndSend(&local_scope, dev_ctx, nccl_var_name, inter_endpoints);
}
}
// hierarchical exter ncclid
if (exter_trainer_id == 0) {
......@@ -208,13 +211,15 @@ class GenNCCLIdOp : public framework::OperatorBase {
if (use_hierarchical_allreduce) {
if (inter_trainer_id > 0) {
for (int i = 0; i < nccl_comm_num; i++) {
rpc_service->SetCond(distributed::kRequestSend);
VLOG(3) << "trainer_id:" << trainer_id
<< ", inter_trainer_id:" << inter_trainer_id
<< " start getting nccl id from inter_trainer 0";
<< " start getting nccl id from inter_trainer:" << i;
rpc_service->WaitBarrier(distributed::kRequestSend);
rpc_service->ResetBarrierCounter();
}
}
if (exter_trainer_id > 0) {
for (int i = 0; i < nccl_comm_num; i++) {
......
......@@ -171,8 +171,9 @@ inline std::string GetHierarchicalExterNCCLVarName(size_t pos) {
return string::Sprintf("Hierarchical_exter_%s_%d", NCCL_ID_VARNAME,
static_cast<int>(pos));
}
inline std::string GetHierarchicalInterNCCLVarName() {
return string::Sprintf("Hierarchical_inter_%s", NCCL_ID_VARNAME);
// Builds the name of the variable that holds the NCCL unique id for the
// hierarchical "inter" ring with communicator index `pos`.
//
// The result has the form "Hierarchical_inter_<NCCL_ID_VARNAME>_<pos>".
// `pos` is narrowed to int because it is formatted with "%d".
inline std::string GetHierarchicalInterNCCLVarName(size_t pos) {
  const int comm_idx = static_cast<int>(pos);
  return string::Sprintf("Hierarchical_inter_%s_%d", NCCL_ID_VARNAME,
                         comm_idx);
}
class MultiNCCLContextMap {
......@@ -224,8 +225,8 @@ class MultiNCCLContextMap {
}
void InitHierarchicalCtxs(const std::vector<platform::Place> &places,
ncclUniqueId *inter_nccl_id,
const std::vector<ncclUniqueId *> &exter_nccl_id,
const std::vector<ncclUniqueId *> &inter_nccl_ids,
const std::vector<ncclUniqueId *> &exter_nccl_ids,
size_t trainers_num, size_t trainer_id,
size_t inter_trainers_num,
size_t exter_trainers_num) {
......@@ -238,11 +239,14 @@ class MultiNCCLContextMap {
inter_trainers_num);
int inter_trainer_id = trainer_id % inter_trainers_num;
VLOG(1) << "init inter_trainer_id:" << inter_trainer_id;
auto local = new NCCLContextMap(places, inter_nccl_id, inter_trainers_num,
inter_trainer_id);
for (size_t i = 0; i < inter_nccl_ids.size(); i++) {
VLOG(1) << "init inter_trainer_id:" << inter_trainer_id
<< ", comm no:" << i;
auto local = new NCCLContextMap(places, inter_nccl_ids[i],
inter_trainers_num, inter_trainer_id);
h_inter_ctxs_.emplace_back(local);
}
int exter_trainer_id = -1;
if (trainer_id % inter_trainers_num == 0) {
......@@ -250,8 +254,8 @@ class MultiNCCLContextMap {
}
if (exter_trainer_id >= 0) {
for (size_t i = 0; i < exter_nccl_id.size(); i++) {
auto ex = new NCCLContextMap(places, exter_nccl_id[i],
for (size_t i = 0; i < exter_nccl_ids.size(); i++) {
auto ex = new NCCLContextMap(places, exter_nccl_ids[i],
exter_trainers_num, exter_trainer_id);
VLOG(1) << "init exter_trainer_id:" << exter_trainer_id
<< ", comm no:" << i;
......
......@@ -278,11 +278,11 @@ class DistributeTranspiler(object):
type=core.VarDesc.VarType.RAW)
if self.config.use_hierarchical_allreduce:
for i in range(0, self.config.nccl_comm_num):
startup_program.global_block().create_var(
name="Hierarchical_inter_NCCLID",
name="Hierarchical_inter_NCCLID_{}".format(i),
persistable=True,
type=core.VarDesc.VarType.RAW)
for i in range(0, self.config.nccl_comm_num):
startup_program.global_block().create_var(
name="Hierarchical_exter_NCCLID_{}".format(i),
persistable=True,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册