Unverified commit 6a706e63, authored by Ruibiao Chen, committed by GitHub

Convert GradMergeAllReduceOpHandle in GraphToBlock (#46544)

* Convert GradMergeAllReduceOpHandle in GraphToBlock

* Set FLAGS_CONVERT_GRAPH_TO_PROGRAM to False
Parent 3fc4fa29
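Note: the diff below extends GetGraphOpDesc so that plain "allreduce" nodes (not only "fused_all_reduce") are converted into c_allreduce_sum op descs, and so that a node wrapping a GradMergeAllReduceOpHandle or FusedGradMergeAllReduceOpHandle gets its gradient-merge condition variable wired in as a "Cond" input. The standalone sketch that follows only illustrates that conversion rule; it is not PaddlePaddle code, and FakeOpDesc / ConvertAllReduceNode are made-up names used purely for illustration.

    // Standalone sketch (not PaddlePaddle code): models the conversion rule this
    // commit implements in GetGraphOpDesc / ReplaceAllReduceOp.
    #include <iostream>
    #include <string>
    #include <vector>

    struct FakeOpDesc {               // simplified stand-in for framework::OpDesc
      std::string type;
      std::vector<std::string> inputs;
    };

    // Fused nodes first get a check_memory_continue op that coalesces their inputs;
    // every allreduce-like node then becomes c_allreduce_sum, and a grad-merge
    // handle additionally contributes a "Cond" input gating the communication.
    std::vector<FakeOpDesc> ConvertAllReduceNode(const std::string &node_name,
                                                 const std::string &grad_merge_cond) {
      std::vector<FakeOpDesc> ops;
      bool is_fused = (node_name == "fused_all_reduce");
      if (is_fused) {
        ops.push_back({"check_memory_continue", {}});
      }
      FakeOpDesc all_reduce{"c_allreduce_sum", {}};
      if (!grad_merge_cond.empty()) {
        all_reduce.inputs.push_back("Cond:" + grad_merge_cond);
      }
      ops.push_back(all_reduce);
      return ops;
    }

    int main() {
      // A plain (non-fused) allreduce node carrying a grad-merge condition,
      // which is the case this commit adds support for.
      for (const auto &op : ConvertAllReduceNode("allreduce", "grad_merge_cond_var")) {
        std::cout << op.type;
        for (const auto &in : op.inputs) std::cout << " " << in;
        std::cout << "\n";
      }
      return 0;
    }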
@@ -506,25 +506,27 @@ static OpDesc *ReplaceScaleLossGradOp(const Node &node, OpDesc *desc) {
   return desc;
 }
 
-static void ReplaceAllReduceOp(const Node &node,
-                               proto::BlockDesc *block,
-                               std::vector<OpDesc> *ops) {
+void ReplaceAllReduceOp(const Node &node,
+                        proto::BlockDesc *block,
+                        std::vector<OpDesc> *ops) {
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-  ops->emplace_back();
-  auto &desc1 = ops->back();
-  std::string name = "fake_coalesce_" + std::to_string(ops->size());
-  desc1.SetType("check_memory_continue");
-  ops->emplace_back();
-  auto &desc2 = ops->back();
-  desc2.SetType("c_allreduce_sum");
-  if (node.IsWrappedBy<details::OpHandleBase>()) {
-    details::OpHandleBase &op_handler =
-        const_cast<Node *>(&node)->Wrapper<details::OpHandleBase>();
-    // set inputs
-    auto in_var_handles = op_handler.Inputs();
+  bool is_fused = (node.Name() == "fused_all_reduce");
+  details::OpHandleBase &op_handle =
+      const_cast<Node *>(&node)->Wrapper<details::OpHandleBase>();
+  std::string all_reduce_var_name;
+  // If fused, add check_memory_continue OP to fuse inputs
+  if (is_fused) {
+    all_reduce_var_name = "fake_coalesce_" + std::to_string(ops->size());
+    proto::VarDesc var_desc;
+    var_desc.set_name(all_reduce_var_name);
+    var_desc.mutable_type()->set_type(proto::VarType::LOD_TENSOR);
+    block->mutable_vars()->Add()->CopyFrom(var_desc);
+    VLOG(4) << "add variable for check_memory_continue: "
+            << all_reduce_var_name;
+    // get inputs of check_memory_continue
+    auto in_var_handles = op_handle.Inputs();
     std::vector<std::string> in_names;
     for (const auto &in : in_var_handles) {
       if (dynamic_cast<details::DummyVarHandle *>(in) != nullptr) {
@@ -532,47 +534,47 @@ static void ReplaceAllReduceOp(const Node &node,
       }
       in_names.emplace_back(in->Name());
     }
-    desc1.SetInput("X", in_names);
-    proto::VarDesc var_desc;
-    var_desc.set_name(name);
-    var_desc.mutable_type()->set_type(proto::VarType::LOD_TENSOR);
-    block->mutable_vars()->Add()->CopyFrom(var_desc);
-    desc1.SetOutput("Out", {name});
-    desc1.SetOutput("XOut", in_names);
-    VLOG(4) << "add variable for check_memory_continue: " << name;
-    desc2.SetInput("X", {name});
-    // set outputs
-    auto out_var_handles = op_handler.Outputs();
-    std::vector<std::string> out_names;
-    for (const auto &out : out_var_handles) {
-      if (dynamic_cast<details::DummyVarHandle *>(out) != nullptr) {
-        continue;
-      }
-      out_names.emplace_back(out->Name());
-    }
-    desc2.SetOutput("Out", {name});
+    ops->emplace_back();
+    OpDesc &fuse_op_desc = ops->back();
+    fuse_op_desc.SetType("check_memory_continue");
+    fuse_op_desc.SetInput("X", in_names);
+    fuse_op_desc.SetOutput("Out", {all_reduce_var_name});
+    fuse_op_desc.SetOutput("XOut", in_names);
+    fuse_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+                         (static_cast<int>(OpRole::kBackward)));
+  } else {
+    all_reduce_var_name = op_handle.Inputs()[0]->Name();
+  }
+
+  // add c_allreduce_sum OP
+  ops->emplace_back();
+  OpDesc &all_reduce_op_desc = ops->back();
+  all_reduce_op_desc.SetType("c_allreduce_sum");
+  all_reduce_op_desc.SetInput("X", {all_reduce_var_name});
+  all_reduce_op_desc.SetOutput("Out", {all_reduce_var_name});
   int ring_id = platform::NCCLCommContext::Instance().GetRingId(
-      dynamic_cast<details::NCCLOpHandleBase *>(&op_handler)->GetComm());
-  desc2.SetAttr("ring_id", ring_id);
-  desc2.SetAttr("use_calc_stream", true);
+      dynamic_cast<details::NCCLOpHandleBase *>(&op_handle)->GetComm());
+  all_reduce_op_desc.SetAttr("ring_id", ring_id);
+  all_reduce_op_desc.SetAttr("use_calc_stream", true);
+  all_reduce_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+                             (static_cast<int>(OpRole::kBackward)));
   // handle grad merge
-  if (dynamic_cast<details::FusedGradMergeAllReduceOpHandle *>(&op_handler)) {
+  if (dynamic_cast<details::FusedGradMergeAllReduceOpHandle *>(&op_handle)) {
     VLOG(4) << "FusedGradMergeAllReduceOpHandle: add cond to c_allreduce_sum";
-    auto cond_name =
-        dynamic_cast<details::FusedGradMergeAllReduceOpHandle *>(&op_handler)
+    const std::string cond_name =
+        dynamic_cast<details::FusedGradMergeAllReduceOpHandle *>(&op_handle)
            ->GradMergeCondName();
-    desc2.SetInput("Cond", {cond_name});
-  }
+    all_reduce_op_desc.SetInput("Cond", {cond_name});
+  } else if (dynamic_cast<details::GradMergeAllReduceOpHandle *>(&op_handle)) {
+    VLOG(4) << "GradMergeAllReduceOpHandle: add cond to c_allreduce_sum";
+    const std::string cond_name =
+        dynamic_cast<details::GradMergeAllReduceOpHandle *>(&op_handle)
+            ->GradMergeCondName();
+    all_reduce_op_desc.SetInput("Cond", {cond_name});
   }
-  desc1.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
-                (static_cast<int>(OpRole::kBackward)));
-  desc2.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
-                (static_cast<int>(OpRole::kBackward)));
 #else
   PADDLE_THROW(
       platform::errors::Unimplemented("ReplaceAllReduceOp is only implemented "
@@ -629,15 +631,14 @@ static void GetGraphOpDesc(const std::vector<Node *> &nodes,
   for (Node *n : nodes) {
     // if node is not Op, skip
    if (!n->IsOp()) continue;
     // create fill_constant op
     if (n->Name() == "scale_loss_grad") {
       VLOG(4) << "convert op node scale_loss_grad to desc fill_constant";
       ops->emplace_back();
       auto &desc = ops->back();
       ReplaceScaleLossGradOp(*n, &desc);
-    } else if (n->Name() == "fused_all_reduce") {
-      VLOG(4) << "convert op node fused_all_reduce to desc c_allreduce_sum";
+    } else if (n->Name() == "allreduce" || n->Name() == "fused_all_reduce") {
+      VLOG(4) << "convert op node " << n->Name() << " to desc c_allreduce_sum";
       ReplaceAllReduceOp(*n, block, ops);
       VLOG(4) << n->ToString();
     } else if (n->Op()) {
...