未验证 提交 6a706e63 编写于 作者: R Ruibiao Chen 提交者: GitHub

Convert GradMergeAllReduceOpHandle in GraphToBlock (#46544)

* Convert GradMergeAllReduceOpHandle in GraphToBlock

* Set FLAGS_CONVERT_GRAPH_TO_PROGRAM to False
上级 3fc4fa29
......@@ -506,25 +506,27 @@ static OpDesc *ReplaceScaleLossGradOp(const Node &node, OpDesc *desc) {
return desc;
}
static void ReplaceAllReduceOp(const Node &node,
proto::BlockDesc *block,
std::vector<OpDesc> *ops) {
void ReplaceAllReduceOp(const Node &node,
proto::BlockDesc *block,
std::vector<OpDesc> *ops) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
ops->emplace_back();
auto &desc1 = ops->back();
std::string name = "fake_coalesce_" + std::to_string(ops->size());
desc1.SetType("check_memory_continue");
ops->emplace_back();
auto &desc2 = ops->back();
desc2.SetType("c_allreduce_sum");
if (node.IsWrappedBy<details::OpHandleBase>()) {
details::OpHandleBase &op_handler =
const_cast<Node *>(&node)->Wrapper<details::OpHandleBase>();
bool is_fused = (node.Name() == "fused_all_reduce");
details::OpHandleBase &op_handle =
const_cast<Node *>(&node)->Wrapper<details::OpHandleBase>();
std::string all_reduce_var_name;
// If fused, add check_memory_continue OP to fuse inputs
if (is_fused) {
all_reduce_var_name = "fake_coalesce_" + std::to_string(ops->size());
proto::VarDesc var_desc;
var_desc.set_name(all_reduce_var_name);
var_desc.mutable_type()->set_type(proto::VarType::LOD_TENSOR);
block->mutable_vars()->Add()->CopyFrom(var_desc);
VLOG(4) << "add variable for check_memory_continue: "
<< all_reduce_var_name;
// set inputs
auto in_var_handles = op_handler.Inputs();
// get inputs of check_memory_continue
auto in_var_handles = op_handle.Inputs();
std::vector<std::string> in_names;
for (const auto &in : in_var_handles) {
if (dynamic_cast<details::DummyVarHandle *>(in) != nullptr) {
......@@ -532,47 +534,47 @@ static void ReplaceAllReduceOp(const Node &node,
}
in_names.emplace_back(in->Name());
}
desc1.SetInput("X", in_names);
proto::VarDesc var_desc;
var_desc.set_name(name);
var_desc.mutable_type()->set_type(proto::VarType::LOD_TENSOR);
block->mutable_vars()->Add()->CopyFrom(var_desc);
desc1.SetOutput("Out", {name});
desc1.SetOutput("XOut", in_names);
VLOG(4) << "add variable for check_memory_continue: " << name;
desc2.SetInput("X", {name});
// set outputs
auto out_var_handles = op_handler.Outputs();
std::vector<std::string> out_names;
for (const auto &out : out_var_handles) {
if (dynamic_cast<details::DummyVarHandle *>(out) != nullptr) {
continue;
}
out_names.emplace_back(out->Name());
}
desc2.SetOutput("Out", {name});
int ring_id = platform::NCCLCommContext::Instance().GetRingId(
dynamic_cast<details::NCCLOpHandleBase *>(&op_handler)->GetComm());
desc2.SetAttr("ring_id", ring_id);
desc2.SetAttr("use_calc_stream", true);
// handle grad merge
if (dynamic_cast<details::FusedGradMergeAllReduceOpHandle *>(&op_handler)) {
VLOG(4) << "FusedGradMergeAllReduceOpHandle: add cond to c_allreduce_sum";
auto cond_name =
dynamic_cast<details::FusedGradMergeAllReduceOpHandle *>(&op_handler)
->GradMergeCondName();
desc2.SetInput("Cond", {cond_name});
}
ops->emplace_back();
OpDesc &fuse_op_desc = ops->back();
fuse_op_desc.SetType("check_memory_continue");
fuse_op_desc.SetInput("X", in_names);
fuse_op_desc.SetOutput("Out", {all_reduce_var_name});
fuse_op_desc.SetOutput("XOut", in_names);
fuse_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
(static_cast<int>(OpRole::kBackward)));
} else {
all_reduce_var_name = op_handle.Inputs()[0]->Name();
}
desc1.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
(static_cast<int>(OpRole::kBackward)));
desc2.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
(static_cast<int>(OpRole::kBackward)));
// add c_allreduce_sum OP
ops->emplace_back();
OpDesc &all_reduce_op_desc = ops->back();
all_reduce_op_desc.SetType("c_allreduce_sum");
all_reduce_op_desc.SetInput("X", {all_reduce_var_name});
all_reduce_op_desc.SetOutput("Out", {all_reduce_var_name});
int ring_id = platform::NCCLCommContext::Instance().GetRingId(
dynamic_cast<details::NCCLOpHandleBase *>(&op_handle)->GetComm());
all_reduce_op_desc.SetAttr("ring_id", ring_id);
all_reduce_op_desc.SetAttr("use_calc_stream", true);
all_reduce_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
(static_cast<int>(OpRole::kBackward)));
// handle grad merge
if (dynamic_cast<details::FusedGradMergeAllReduceOpHandle *>(&op_handle)) {
VLOG(4) << "FusedGradMergeAllReduceOpHandle: add cond to c_allreduce_sum";
const std::string cond_name =
dynamic_cast<details::FusedGradMergeAllReduceOpHandle *>(&op_handle)
->GradMergeCondName();
all_reduce_op_desc.SetInput("Cond", {cond_name});
} else if (dynamic_cast<details::GradMergeAllReduceOpHandle *>(&op_handle)) {
VLOG(4) << "GradMergeAllReduceOpHandle: add cond to c_allreduce_sum";
const std::string cond_name =
dynamic_cast<details::GradMergeAllReduceOpHandle *>(&op_handle)
->GradMergeCondName();
all_reduce_op_desc.SetInput("Cond", {cond_name});
}
#else
PADDLE_THROW(
platform::errors::Unimplemented("ReplaceAllReduceOp is only implemented "
......@@ -629,15 +631,14 @@ static void GetGraphOpDesc(const std::vector<Node *> &nodes,
for (Node *n : nodes) {
// if node is not Op, skip
if (!n->IsOp()) continue;
// create fill_constant op
if (n->Name() == "scale_loss_grad") {
VLOG(4) << "convert op node scale_loss_grad to desc fill_constant";
ops->emplace_back();
auto &desc = ops->back();
ReplaceScaleLossGradOp(*n, &desc);
} else if (n->Name() == "fused_all_reduce") {
VLOG(4) << "convert op node fused_all_reduce to desc c_allreduce_sum";
} else if (n->Name() == "allreduce" || n->Name() == "fused_all_reduce") {
VLOG(4) << "convert op node " << n->Name() << " to desc c_allreduce_sum";
ReplaceAllReduceOp(*n, block, ops);
VLOG(4) << n->ToString();
} else if (n->Op()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册