未验证 提交 d27f15ed 编写于 作者: S sprouteer 提交者: GitHub

[XPU] Fix the out_max of the branch in xpu_conv2d op(#53343)

上级 b02de1b6
......@@ -87,7 +87,7 @@ class LinkXPUOpMaxPass : public FusePassBase {
const std::string& op_type,
bool with_branch) const;
const std::string name_scope_{"multi_encoder_xpu_slice_fuse_pass"};
const std::string name_scope_{"link_xpu_op_max_pass"};
// ops with x_max/out_max
std::set<std::string> op_types_{"fc_xpu", "conv2d_xpu"};
};
......@@ -157,7 +157,15 @@ void LinkXPUOpMaxPass::ApplyImpl(ir::Graph* graph,
GET_IR_NODE(branch);
auto* fusion_op_desc = fusion_op->Op();
if (input->inputs[0]->Op()->HasOutput("out_max")) {
if (fusion_op_desc->HasAttr("has_branch")) {
bool fusion_op_branch =
PADDLE_GET_CONST(bool, fusion_op_desc->GetAttr("has_branch"));
if (fusion_op_branch != with_branch) {
return;
}
}
if (input->inputs.size() > 0 && input->inputs[0]->IsOp() &&
input->inputs[0]->Op()->HasOutput("out_max")) {
auto input_max_name = input->inputs[0]->Op()->Output("out_max");
for (auto max_node : input->inputs[0]->outputs) {
if (input_max_name[0] == max_node->Name()) {
......@@ -169,7 +177,8 @@ void LinkXPUOpMaxPass::ApplyImpl(ir::Graph* graph,
}
if (with_branch) {
if (branch->inputs[0]->Op()->HasOutput("out_max")) {
if (branch->inputs.size() > 0 && branch->inputs[0]->IsOp() &&
branch->inputs[0]->Op()->HasOutput("out_max")) {
auto branch_max_name = branch->inputs[0]->Op()->Output("out_max");
for (auto max_node : branch->inputs[0]->outputs) {
if (branch_max_name[0] == max_node->Name()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册