未验证 提交 cfadd6b6 编写于 作者: W wz1qqx 提交者: GitHub

fix ssdet (#54136)

上级 64661927
......@@ -133,10 +133,9 @@ void LinkXPUOpMaxPass::LinkAddActMax(ir::Graph* graph) const {
}
}
}
auto* ele_y_pre_op = ele_y->inputs[0]->Op();
if (ele_y->inputs.size() > 0 && ele_y->inputs[0]->IsOp() &&
ele_y_pre_op->HasOutput("out_max")) {
auto preop_max_var_name = ele_y_pre_op->Output("out_max");
ele_y->inputs[0]->Op()->HasOutput("out_max")) {
auto preop_max_var_name = ele_y->inputs[0]->Op()->Output("out_max");
for (auto max_node : ele_y->inputs[0]->outputs) {
if (preop_max_var_name[0] == max_node->Name()) {
fusion_op_desc->SetInput("y_max", {max_node->Name()});
......@@ -166,6 +165,13 @@ void LinkXPUOpMaxPass::LinkConv2dMax(ir::Graph* graph, bool with_branch) const {
GET_IR_NODE(x);
GET_IR_NODE(branch);
auto* fusion_op_desc = fusion_op->Op();
if (fusion_op_desc->HasAttr("has_branch")) {
bool fusion_op_branch =
PADDLE_GET_CONST(bool, fusion_op_desc->GetAttr("has_branch"));
if (fusion_op_branch != with_branch) {
return;
}
}
auto* x_pre_op = x->inputs[0]->Op();
if (x->inputs.size() > 0 && x->inputs[0]->IsOp() &&
x_pre_op->HasOutput("out_max")) {
......@@ -178,10 +184,9 @@ void LinkXPUOpMaxPass::LinkConv2dMax(ir::Graph* graph, bool with_branch) const {
}
}
if (with_branch) {
auto* branch_pre_op = branch->inputs[0]->Op();
if (branch->inputs.size() > 0 && branch->inputs[0]->IsOp() &&
branch_pre_op->HasOutput("out_max")) {
auto preop_max_var_name = branch_pre_op->Output("out_max");
branch->inputs[0]->Op()->HasOutput("out_max")) {
auto preop_max_var_name = branch->inputs[0]->Op()->Output("out_max");
for (auto max_node : branch->inputs[0]->outputs) {
if (preop_max_var_name[0] == max_node->Name()) {
fusion_op_desc->SetInput("branch_max", {max_node->Name()});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册