diff --git a/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.cc index 0fc458723ffe43040aa376e2389c950bd26c4c98..60d661f7740d02b5964976410832009cc25d3c51 100644 --- a/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.cc @@ -91,6 +91,10 @@ void ScaleMatmulFusePass::ApplyImpl(ir::Graph* graph) const { GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, scale_matmul_pattern); GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, scale_matmul_pattern); + if ((scale_out->outputs).size() != 1) { + return; + } + if (scale_op->Op()->GetAttrIfExists("bias") == 0.0) { auto matmul_alpha = matmul_op->Op()->GetAttrIfExists("alpha"); auto scale_scale = scale_op->Op()->GetAttrIfExists("scale");