From cfadd6b6cb3a3a5431cca1df66fc8ee9af7ddda5 Mon Sep 17 00:00:00 2001 From: wz1qqx <55830058+wz1qqx@users.noreply.github.com> Date: Mon, 29 May 2023 10:52:04 +0800 Subject: [PATCH] fix ssdet (#54136) --- .../framework/ir/xpu/link_xpu_op_max_pass.cc | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc b/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc index 31e8618e077..d6938421428 100644 --- a/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc +++ b/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc @@ -133,10 +133,9 @@ void LinkXPUOpMaxPass::LinkAddActMax(ir::Graph* graph) const { } } } - auto* ele_y_pre_op = ele_y->inputs[0]->Op(); if (ele_y->inputs.size() > 0 && ele_y->inputs[0]->IsOp() && - ele_y_pre_op->HasOutput("out_max")) { - auto preop_max_var_name = ele_y_pre_op->Output("out_max"); + ele_y->inputs[0]->Op()->HasOutput("out_max")) { + auto preop_max_var_name = ele_y->inputs[0]->Op()->Output("out_max"); for (auto max_node : ele_y->inputs[0]->outputs) { if (preop_max_var_name[0] == max_node->Name()) { fusion_op_desc->SetInput("y_max", {max_node->Name()}); @@ -166,6 +165,13 @@ void LinkXPUOpMaxPass::LinkConv2dMax(ir::Graph* graph, bool with_branch) const { GET_IR_NODE(x); GET_IR_NODE(branch); auto* fusion_op_desc = fusion_op->Op(); + if (fusion_op_desc->HasAttr("has_branch")) { + bool fusion_op_branch = + PADDLE_GET_CONST(bool, fusion_op_desc->GetAttr("has_branch")); + if (fusion_op_branch != with_branch) { + return; + } + } auto* x_pre_op = x->inputs[0]->Op(); if (x->inputs.size() > 0 && x->inputs[0]->IsOp() && x_pre_op->HasOutput("out_max")) { @@ -178,10 +184,9 @@ void LinkXPUOpMaxPass::LinkConv2dMax(ir::Graph* graph, bool with_branch) const { } } if (with_branch) { - auto* branch_pre_op = branch->inputs[0]->Op(); if (branch->inputs.size() > 0 && branch->inputs[0]->IsOp() && - branch_pre_op->HasOutput("out_max")) { - auto preop_max_var_name = branch_pre_op->Output("out_max"); + branch->inputs[0]->Op()->HasOutput("out_max")) { + auto preop_max_var_name = branch->inputs[0]->Op()->Output("out_max"); for (auto max_node : branch->inputs[0]->outputs) { if (preop_max_var_name[0] == max_node->Name()) { fusion_op_desc->SetInput("branch_max", {max_node->Name()}); -- GitLab