diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc
index 41dabfbd36bc61b4e939dca3d93139a82586308a..37dbfe27582f7cc6309a20f521344c985658eca8 100644
--- a/paddle/fluid/framework/ir/graph_helper.cc
+++ b/paddle/fluid/framework/ir/graph_helper.cc
@@ -506,25 +506,27 @@ static OpDesc *ReplaceScaleLossGradOp(const Node &node, OpDesc *desc) {
   return desc;
 }
 
-static void ReplaceAllReduceOp(const Node &node,
-                               proto::BlockDesc *block,
-                               std::vector<OpDesc> *ops) {
+void ReplaceAllReduceOp(const Node &node,
+                        proto::BlockDesc *block,
+                        std::vector<OpDesc> *ops) {
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-  ops->emplace_back();
-  auto &desc1 = ops->back();
-  std::string name = "fake_coalesce_" + std::to_string(ops->size());
-  desc1.SetType("check_memory_continue");
-
-  ops->emplace_back();
-  auto &desc2 = ops->back();
-  desc2.SetType("c_allreduce_sum");
-
-  if (node.IsWrappedBy<details::OpHandleBase>()) {
-    details::OpHandleBase &op_handler =
-        const_cast<Node *>(&node)->Wrapper<details::OpHandleBase>();
+  bool is_fused = (node.Name() == "fused_all_reduce");
+  details::OpHandleBase &op_handle =
+      const_cast<Node *>(&node)->Wrapper<details::OpHandleBase>();
+
+  std::string all_reduce_var_name;
+  // If fused, add check_memory_continue OP to fuse inputs
+  if (is_fused) {
+    all_reduce_var_name = "fake_coalesce_" + std::to_string(ops->size());
+    proto::VarDesc var_desc;
+    var_desc.set_name(all_reduce_var_name);
+    var_desc.mutable_type()->set_type(proto::VarType::LOD_TENSOR);
+    block->mutable_vars()->Add()->CopyFrom(var_desc);
+    VLOG(4) << "add variable for check_memory_continue: "
+            << all_reduce_var_name;
 
-    // set inputs
-    auto in_var_handles = op_handler.Inputs();
+    // get inputs of check_memory_continue
+    auto in_var_handles = op_handle.Inputs();
     std::vector<std::string> in_names;
     for (const auto &in : in_var_handles) {
       if (dynamic_cast<details::DummyVarHandle *>(in) != nullptr) {
@@ -532,47 +534,47 @@ static void ReplaceAllReduceOp(const Node &node,
       }
       in_names.emplace_back(in->Name());
     }
-    desc1.SetInput("X", in_names);
-
-    proto::VarDesc var_desc;
-    var_desc.set_name(name);
-    var_desc.mutable_type()->set_type(proto::VarType::LOD_TENSOR);
-    block->mutable_vars()->Add()->CopyFrom(var_desc);
-    desc1.SetOutput("Out", {name});
-    desc1.SetOutput("XOut", in_names);
-    VLOG(4) << "add variable for check_memory_continue: " << name;
-
-    desc2.SetInput("X", {name});
-    // set outputs
-    auto out_var_handles = op_handler.Outputs();
-    std::vector<std::string> out_names;
-    for (const auto &out : out_var_handles) {
-      if (dynamic_cast<details::DummyVarHandle *>(out) != nullptr) {
-        continue;
-      }
-      out_names.emplace_back(out->Name());
-    }
-    desc2.SetOutput("Out", {name});
-
-    int ring_id = platform::NCCLCommContext::Instance().GetRingId(
-        dynamic_cast<details::NCCLOpHandleBase *>(&op_handler)->GetComm());
-    desc2.SetAttr("ring_id", ring_id);
-    desc2.SetAttr("use_calc_stream", true);
-    // handle grad merge
-    if (dynamic_cast<details::FusedGradMergeAllReduceOpHandle *>(&op_handler)) {
-      VLOG(4) << "FusedGradMergeAllReduceOpHandle: add cond to c_allreduce_sum";
-      auto cond_name =
-          dynamic_cast<details::FusedGradMergeAllReduceOpHandle *>(&op_handler)
-              ->GradMergeCondName();
-      desc2.SetInput("Cond", {cond_name});
-    }
+    ops->emplace_back();
+    OpDesc &fuse_op_desc = ops->back();
+    fuse_op_desc.SetType("check_memory_continue");
+    fuse_op_desc.SetInput("X", in_names);
+    fuse_op_desc.SetOutput("Out", {all_reduce_var_name});
+    fuse_op_desc.SetOutput("XOut", in_names);
+    fuse_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+                         (static_cast<int>(OpRole::kBackward)));
+  } else {
+    all_reduce_var_name = op_handle.Inputs()[0]->Name();
   }
-  desc1.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
-                (static_cast<int>(OpRole::kBackward)));
-
-  desc2.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
-                (static_cast<int>(OpRole::kBackward)));
+  // add c_allreduce_sum OP
+  ops->emplace_back();
+  OpDesc &all_reduce_op_desc = ops->back();
+  all_reduce_op_desc.SetType("c_allreduce_sum");
+  all_reduce_op_desc.SetInput("X", {all_reduce_var_name});
+  all_reduce_op_desc.SetOutput("Out", {all_reduce_var_name});
+
+  int ring_id = platform::NCCLCommContext::Instance().GetRingId(
+      dynamic_cast<details::NCCLOpHandleBase *>(&op_handle)->GetComm());
+  all_reduce_op_desc.SetAttr("ring_id", ring_id);
+  all_reduce_op_desc.SetAttr("use_calc_stream", true);
+  all_reduce_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+                             (static_cast<int>(OpRole::kBackward)));
+
+  // handle grad merge
+  if (dynamic_cast<details::FusedGradMergeAllReduceOpHandle *>(&op_handle)) {
+    VLOG(4) << "FusedGradMergeAllReduceOpHandle: add cond to c_allreduce_sum";
+    const std::string cond_name =
+        dynamic_cast<details::FusedGradMergeAllReduceOpHandle *>(&op_handle)
+            ->GradMergeCondName();
+    all_reduce_op_desc.SetInput("Cond", {cond_name});
+  } else if (dynamic_cast<details::GradMergeAllReduceOpHandle *>(&op_handle)) {
+    VLOG(4) << "GradMergeAllReduceOpHandle: add cond to c_allreduce_sum";
+    const std::string cond_name =
+        dynamic_cast<details::GradMergeAllReduceOpHandle *>(&op_handle)
+            ->GradMergeCondName();
+    all_reduce_op_desc.SetInput("Cond", {cond_name});
+  }
 #else
   PADDLE_THROW(
       platform::errors::Unimplemented("ReplaceAllReduceOp is only implemented "
@@ -629,15 +631,14 @@ static void GetGraphOpDesc(const std::vector<Node *> &nodes,
   for (Node *n : nodes) {
     // if node is not Op, skip
     if (!n->IsOp()) continue;
-    // create fill_constant op
     if (n->Name() == "scale_loss_grad") {
      VLOG(4) << "convert op node scale_loss_grad to desc fill_constant";
       ops->emplace_back();
       auto &desc = ops->back();
       ReplaceScaleLossGradOp(*n, &desc);
-    } else if (n->Name() == "fused_all_reduce") {
-      VLOG(4) << "convert op node fused_all_reduce to desc c_allreduce_sum";
+    } else if (n->Name() == "allreduce" || n->Name() == "fused_all_reduce") {
+      VLOG(4) << "convert op node " << n->Name() << " to desc c_allreduce_sum";
       ReplaceAllReduceOp(*n, block, ops);
       VLOG(4) << n->ToString();
     } else if (n->Op()) {
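
For reference, below is a minimal standalone sketch of the conversion logic this patch introduces; FakeOpDesc, ReplaceAllReduceSketch, and the gradient names are hypothetical stand-ins rather than Paddle's real API. It shows the two paths: a fused_all_reduce node first gets a check_memory_continue op that coalesces its inputs into a fake_coalesce_* variable, while a plain allreduce node reuses its single input; both then feed a c_allreduce_sum op that reduces that variable in place.

// Standalone sketch (hypothetical types, not Paddle's API) of the op
// sequence emitted by ReplaceAllReduceOp after this change.
#include <iostream>
#include <string>
#include <vector>

struct FakeOpDesc {  // stand-in for paddle::framework::OpDesc
  std::string type;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

void ReplaceAllReduceSketch(bool is_fused,
                            const std::vector<std::string> &in_names,
                            std::vector<FakeOpDesc> *ops) {
  std::string all_reduce_var_name;
  if (is_fused) {
    // Fused path: coalesce all gradient inputs into one contiguous variable.
    all_reduce_var_name = "fake_coalesce_" + std::to_string(ops->size());
    ops->push_back({"check_memory_continue", in_names, {all_reduce_var_name}});
  } else {
    // Plain allreduce: reduce the single gradient variable in place.
    all_reduce_var_name = in_names.front();
  }
  ops->push_back(
      {"c_allreduce_sum", {all_reduce_var_name}, {all_reduce_var_name}});
}

int main() {
  std::vector<FakeOpDesc> ops;
  ReplaceAllReduceSketch(true, {"grad0", "grad1"}, &ops);  // fused_all_reduce
  ReplaceAllReduceSketch(false, {"grad2"}, &ops);          // allreduce
  for (const auto &op : ops) {
    std::cout << op.type << " -> " << op.outputs.front() << "\n";
  }
  return 0;
}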