Unverified commit 6a706e63 authored by Ruibiao Chen, committed by GitHub

Convert GradMergeAllReduceOpHandle in GraphToBlock (#46544)

* Convert GradMergeAllReduceOpHandle in GraphToBlock

* Set FLAGS_CONVERT_GRAPH_TO_PROGRAM to False
Parent 3fc4fa29
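For readers skimming the diff below: after this change, a plain `allreduce` (including `GradMergeAllReduceOpHandle`) node is converted to a single `c_allreduce_sum` op that reads and writes its input variable in place, while a `fused_all_reduce` node first gets a `check_memory_continue` op that coalesces its inputs into a fake variable; grad-merge handles additionally wire their condition variable into the `Cond` input. The standalone sketch below mirrors that dispatch only for illustration; it is not Paddle source, and `SimpleOp`, `NodeInfo`, `ConvertAllReduceNode`, and the variable names in `main` are hypothetical stand-ins for Paddle's `OpDesc`, `ir::Node`, and `ReplaceAllReduceOp`.

```cpp
// Standalone sketch (not Paddle code) of the node -> op conversion this commit
// extends to GradMergeAllReduceOpHandle. All types and names are hypothetical.
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct SimpleOp {
  std::string type;
  std::map<std::string, std::vector<std::string>> inputs;
  std::map<std::string, std::vector<std::string>> outputs;
};

struct NodeInfo {
  std::string name;              // "allreduce" or "fused_all_reduce"
  std::vector<std::string> ins;  // gradient variables fed to the node
  std::string grad_merge_cond;   // non-empty for grad-merge handles
};

// Fused nodes first get a check_memory_continue op that coalesces the inputs
// into a fake variable; plain nodes reuse their single input variable.
// Grad-merge handles additionally get a Cond input on c_allreduce_sum.
std::vector<SimpleOp> ConvertAllReduceNode(const NodeInfo &node) {
  std::vector<SimpleOp> ops;
  std::string var;
  if (node.name == "fused_all_reduce") {
    var = "fake_coalesce_" + std::to_string(ops.size());
    ops.push_back({"check_memory_continue",
                   {{"X", node.ins}},
                   {{"Out", {var}}, {"XOut", node.ins}}});
  } else {
    var = node.ins.front();  // non-fused: all-reduce the input in place
  }
  SimpleOp all_reduce{"c_allreduce_sum", {{"X", {var}}}, {{"Out", {var}}}};
  if (!node.grad_merge_cond.empty()) {
    all_reduce.inputs["Cond"] = {node.grad_merge_cond};
  }
  ops.push_back(all_reduce);
  return ops;
}

int main() {
  // A non-fused grad-merge all-reduce node (hypothetical variable names),
  // the case newly handled by this commit.
  NodeInfo node{"allreduce", {"fc_0.w_0@GRAD@MERGED"}, "fc_0.w_0@GRAD@COND"};
  for (const auto &op : ConvertAllReduceNode(node)) {
    std::cout << op.type << " -> " << op.outputs.at("Out").front() << "\n";
  }
  return 0;
}
```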
@@ -506,25 +506,27 @@ static OpDesc *ReplaceScaleLossGradOp(const Node &node, OpDesc *desc) {
   return desc;
 }
 
-static void ReplaceAllReduceOp(const Node &node,
-                               proto::BlockDesc *block,
-                               std::vector<OpDesc> *ops) {
+void ReplaceAllReduceOp(const Node &node,
+                        proto::BlockDesc *block,
+                        std::vector<OpDesc> *ops) {
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-  ops->emplace_back();
-  auto &desc1 = ops->back();
-  std::string name = "fake_coalesce_" + std::to_string(ops->size());
-  desc1.SetType("check_memory_continue");
+  bool is_fused = (node.Name() == "fused_all_reduce");
+  details::OpHandleBase &op_handle =
+      const_cast<Node *>(&node)->Wrapper<details::OpHandleBase>();
 
-  ops->emplace_back();
-  auto &desc2 = ops->back();
-  desc2.SetType("c_allreduce_sum");
-
-  if (node.IsWrappedBy<details::OpHandleBase>()) {
-    details::OpHandleBase &op_handler =
-        const_cast<Node *>(&node)->Wrapper<details::OpHandleBase>();
-
-    // set inputs
-    auto in_var_handles = op_handler.Inputs();
+  std::string all_reduce_var_name;
+  // If fused, add check_memory_continue OP to fuse inputs
+  if (is_fused) {
+    all_reduce_var_name = "fake_coalesce_" + std::to_string(ops->size());
+    proto::VarDesc var_desc;
+    var_desc.set_name(all_reduce_var_name);
+    var_desc.mutable_type()->set_type(proto::VarType::LOD_TENSOR);
+    block->mutable_vars()->Add()->CopyFrom(var_desc);
+    VLOG(4) << "add variable for check_memory_continue: "
+            << all_reduce_var_name;
+
+    // get inputs of check_memory_continue
+    auto in_var_handles = op_handle.Inputs();
     std::vector<std::string> in_names;
     for (const auto &in : in_var_handles) {
       if (dynamic_cast<details::DummyVarHandle *>(in) != nullptr) {
@@ -532,47 +534,47 @@ static void ReplaceAllReduceOp(const Node &node,
       }
       in_names.emplace_back(in->Name());
     }
-    desc1.SetInput("X", in_names);
 
-    proto::VarDesc var_desc;
-    var_desc.set_name(name);
-    var_desc.mutable_type()->set_type(proto::VarType::LOD_TENSOR);
-    block->mutable_vars()->Add()->CopyFrom(var_desc);
-    desc1.SetOutput("Out", {name});
-    desc1.SetOutput("XOut", in_names);
-    VLOG(4) << "add variable for check_memory_continue: " << name;
-
-    desc2.SetInput("X", {name});
-    // set outputs
-    auto out_var_handles = op_handler.Outputs();
-    std::vector<std::string> out_names;
-    for (const auto &out : out_var_handles) {
-      if (dynamic_cast<details::DummyVarHandle *>(out) != nullptr) {
-        continue;
-      }
-      out_names.emplace_back(out->Name());
-    }
-    desc2.SetOutput("Out", {name});
-
-    int ring_id = platform::NCCLCommContext::Instance().GetRingId(
-        dynamic_cast<details::NCCLOpHandleBase *>(&op_handler)->GetComm());
-    desc2.SetAttr("ring_id", ring_id);
-    desc2.SetAttr("use_calc_stream", true);
-
-    // handle grad merge
-    if (dynamic_cast<details::FusedGradMergeAllReduceOpHandle *>(&op_handler)) {
-      VLOG(4) << "FusedGradMergeAllReduceOpHandle: add cond to c_allreduce_sum";
-      auto cond_name =
-          dynamic_cast<details::FusedGradMergeAllReduceOpHandle *>(&op_handler)
-              ->GradMergeCondName();
-      desc2.SetInput("Cond", {cond_name});
-    }
+    ops->emplace_back();
+    OpDesc &fuse_op_desc = ops->back();
+    fuse_op_desc.SetType("check_memory_continue");
+    fuse_op_desc.SetInput("X", in_names);
+    fuse_op_desc.SetOutput("Out", {all_reduce_var_name});
+    fuse_op_desc.SetOutput("XOut", in_names);
+    fuse_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+                         (static_cast<int>(OpRole::kBackward)));
+  } else {
+    all_reduce_var_name = op_handle.Inputs()[0]->Name();
   }
 
-  desc1.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
-                (static_cast<int>(OpRole::kBackward)));
-  desc2.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
-                (static_cast<int>(OpRole::kBackward)));
+  // add c_allreduce_sum OP
+  ops->emplace_back();
+  OpDesc &all_reduce_op_desc = ops->back();
+  all_reduce_op_desc.SetType("c_allreduce_sum");
+  all_reduce_op_desc.SetInput("X", {all_reduce_var_name});
+  all_reduce_op_desc.SetOutput("Out", {all_reduce_var_name});
+
+  int ring_id = platform::NCCLCommContext::Instance().GetRingId(
+      dynamic_cast<details::NCCLOpHandleBase *>(&op_handle)->GetComm());
+  all_reduce_op_desc.SetAttr("ring_id", ring_id);
+  all_reduce_op_desc.SetAttr("use_calc_stream", true);
+  all_reduce_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+                             (static_cast<int>(OpRole::kBackward)));
+
+  // handle grad merge
+  if (dynamic_cast<details::FusedGradMergeAllReduceOpHandle *>(&op_handle)) {
+    VLOG(4) << "FusedGradMergeAllReduceOpHandle: add cond to c_allreduce_sum";
+    const std::string cond_name =
+        dynamic_cast<details::FusedGradMergeAllReduceOpHandle *>(&op_handle)
+            ->GradMergeCondName();
+    all_reduce_op_desc.SetInput("Cond", {cond_name});
+  } else if (dynamic_cast<details::GradMergeAllReduceOpHandle *>(&op_handle)) {
+    VLOG(4) << "GradMergeAllReduceOpHandle: add cond to c_allreduce_sum";
+    const std::string cond_name =
+        dynamic_cast<details::GradMergeAllReduceOpHandle *>(&op_handle)
+            ->GradMergeCondName();
+    all_reduce_op_desc.SetInput("Cond", {cond_name});
+  }
 #else
   PADDLE_THROW(
       platform::errors::Unimplemented("ReplaceAllReduceOp is only implemented "
@@ -629,15 +631,14 @@ static void GetGraphOpDesc(const std::vector<Node *> &nodes,
   for (Node *n : nodes) {
     // if node is not Op, skip
     if (!n->IsOp()) continue;
-
     // create fill_constant op
     if (n->Name() == "scale_loss_grad") {
       VLOG(4) << "convert op node scale_loss_grad to desc fill_constant";
       ops->emplace_back();
       auto &desc = ops->back();
       ReplaceScaleLossGradOp(*n, &desc);
-    } else if (n->Name() == "fused_all_reduce") {
-      VLOG(4) << "convert op node fused_all_reduce to desc c_allreduce_sum";
+    } else if (n->Name() == "allreduce" || n->Name() == "fused_all_reduce") {
+      VLOG(4) << "convert op node " << n->Name() << " to desc c_allreduce_sum";
       ReplaceAllReduceOp(*n, block, ops);
       VLOG(4) << n->ToString();
     } else if (n->Op()) {
...