diff --git a/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.cc
index a552e42619f368c2e8e2a51213ac10d9317151cf..13f1fa50d080a33d837ebb63984cd4e5c3c1c350 100644
--- a/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.cc
@@ -28,6 +28,45 @@ namespace ir {
 class Graph;
 
 using string::PrettyLogDetail;
 
+ScaleMatmulFusePass::ScaleMatmulFusePass() {
+  AddOpCompat(OpCompat("matmul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("alpha")
+      .IsNumGT(0.0f)
+      .End()
+      .AddAttr("transpose_X")
+      .IsType<bool>()
+      .End()
+      .AddAttr("transpose_Y")
+      .IsType<bool>()
+      .End();
+
+  AddOpCompat(OpCompat("scale"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("scale")
+      .IsNumGT(0.0f)
+      .End()
+      .AddAttr("bias")
+      .IsNumEQ(0.0f)
+      .End()
+      .AddAttr("bias_after_scale")
+      .IsOptional()
+      .IsType<bool>()
+      .End();
+}
 void ScaleMatmulFusePass::ApplyImpl(ir::Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(graph,
@@ -43,6 +82,10 @@ void ScaleMatmulFusePass::ApplyImpl(ir::Graph* graph) const {
   int found_scale_matmul_fuse_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING) << "Pass in op compat failed.";
+      return;
+    }
     GET_IR_NODE_FROM_SUBGRAPH(scale_in, scale_in, scale_matmul_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(scale_op, scale_op, scale_matmul_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, scale_matmul_pattern);
@@ -75,6 +118,11 @@ void ScaleMatmulFusePass::ApplyImpl(ir::Graph* graph) const {
       matmul_op->Op()->SetInput(matmul_op_input_name,
                                 std::vector<std::string>({scale_in->Name()}));
       IR_NODE_LINK_TO(scale_in, matmul_op);
+
+      if (!IsCompat(*matmul_op->Op())) {
+        LOG(WARNING) << "scale_matmul_fuse_pass in out fc op compat failed.";
+        return;
+      }
       GraphSafeRemoveNodes(graph, {scale_op, scale_out});
       found_scale_matmul_fuse_count++;
     }
diff --git a/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.h
index 32ff78d9a73683c700ceb31a1505538ff7ee6119..acea8ba563dc05ae1fb7b63afa0479cc27f74a31 100644
--- a/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.h
+++ b/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.h
@@ -24,6 +24,7 @@ class Graph;
 
 class ScaleMatmulFusePass : public FusePassBase {
  public:
+  ScaleMatmulFusePass();
   virtual ~ScaleMatmulFusePass() {}
 
  protected:
diff --git a/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass_tester.cc
index d37d014a87b66076ec94ad69b381c6a73c7bca19..60f844ffc80cea2bd1fefca31435575936f5bdf5 100644
--- a/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass_tester.cc
+++ b/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass_tester.cc
@@ -31,6 +31,8 @@ void SetOp(ProgramDesc* prog, const std::string& type,
     op->SetAttr("scale", scale);
     op->SetAttr("bias", bias);
   } else if (type == "matmul") {
+    op->SetAttr("transpose_X", false);
+    op->SetAttr("transpose_Y", false);
     op->SetInput("X", {inputs[0]});
     if (inputs.size() > 1) op->SetInput("Y", {inputs[1]});
     op->SetAttr("alpha", scale);
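
For context on what this pass rewrites: a scale op feeding a matmul is removed, and its scale factor is folded into matmul's alpha attribute (the alpha update itself happens in handler code outside these hunks). Below is a minimal standalone C++ sketch of the algebraic identity the fusion relies on, matmul(scale(X, s), Y, alpha) == matmul(X, Y, alpha * s). It uses no Paddle APIs, and every name in it is illustrative only.

// Standalone sketch (not Paddle code; all names here are illustrative).
// It demonstrates the identity the fuse pass relies on:
//   matmul(scale(X, s), Y, alpha) == matmul(X, Y, alpha * s)
// which holds only when the scale op adds no bias (hence IsNumEQ(0.0f)).
#include <array>
#include <cassert>
#include <cstdio>

// 2x2 matrix multiply with a scalar multiplier, mirroring matmul's alpha.
std::array<float, 4> Matmul2x2(const std::array<float, 4>& a,
                               const std::array<float, 4>& b, float alpha) {
  return {alpha * (a[0] * b[0] + a[1] * b[2]),
          alpha * (a[0] * b[1] + a[1] * b[3]),
          alpha * (a[2] * b[0] + a[3] * b[2]),
          alpha * (a[2] * b[1] + a[3] * b[3])};
}

int main() {
  const std::array<float, 4> x{1.f, 2.f, 3.f, 4.f};
  const std::array<float, 4> y{5.f, 6.f, 7.f, 8.f};
  const float s = 0.5f;     // the scale op's "scale" attribute
  const float alpha = 2.f;  // the matmul op's "alpha" attribute

  // Unfused graph: scale X first, then matmul with the original alpha.
  const std::array<float, 4> sx{s * x[0], s * x[1], s * x[2], s * x[3]};
  const auto unfused = Matmul2x2(sx, y, alpha);

  // Fused graph: feed X directly and fold the scale factor into alpha.
  const auto fused = Matmul2x2(x, y, alpha * s);

  for (int i = 0; i < 4; ++i) assert(unfused[i] == fused[i]);
  std::printf("scale->matmul fold matches the unfused result\n");
  return 0;
}

The identity holds only for a pure multiplicative scale, which is why the new constructor pins the scale op's bias attribute to 0.0f via IsNumEQ: a nonzero bias cannot be folded into alpha.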
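
On the two IsCompat call sites the patch adds: the first, IsCompat(subgraph, g), validates every op in the matched subgraph against the constraints declared in the new constructor before any rewriting starts; the second, IsCompat(*matmul_op->Op()), re-checks just the matmul op after the handler has rewired its input (and, in code outside these hunks, folded the scale factor into alpha). Since both alpha and scale are constrained positive by IsNumGT(0.0f), their product still satisfies the alpha constraint at the second check.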