From 8c4c0725336a65ece054b95741d76aa205c65fa9 Mon Sep 17 00:00:00 2001 From: feng_shuai Date: Tue, 29 Jun 2021 18:50:52 +0800 Subject: [PATCH] Scale matmul fuse pass (#33803) * scale_matmul_fuse_pass_init * enhance scale_matmul_fuse_pass * change scale_matmul_fuse_pass unittest --- .../ir/mkldnn/scale_matmul_fuse_pass.cc | 48 +++++++++++++++++++ .../ir/mkldnn/scale_matmul_fuse_pass.h | 1 + .../mkldnn/scale_matmul_fuse_pass_tester.cc | 2 + 3 files changed, 51 insertions(+) diff --git a/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.cc index a552e42619f..13f1fa50d08 100644 --- a/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.cc @@ -28,6 +28,45 @@ namespace ir { class Graph; using string::PrettyLogDetail; +ScaleMatmulFusePass::ScaleMatmulFusePass() { + AddOpCompat(OpCompat("matmul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("alpha") + .IsNumGT(0.0f) + .End() + .AddAttr("transpose_X") + .IsType() + .End() + .AddAttr("transpose_Y") + .IsType() + .End(); + + AddOpCompat(OpCompat("scale")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("scale") + .IsNumGT(0.0f) + .End() + .AddAttr("bias") + .IsNumEQ(0.0f) + .End() + .AddAttr("bias_after_scale") + .IsOptional() + .IsType() + .End(); +} void ScaleMatmulFusePass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL(graph, @@ -43,6 +82,10 @@ void ScaleMatmulFusePass::ApplyImpl(ir::Graph* graph) const { int found_scale_matmul_fuse_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + } GET_IR_NODE_FROM_SUBGRAPH(scale_in, scale_in, scale_matmul_pattern); GET_IR_NODE_FROM_SUBGRAPH(scale_op, scale_op, scale_matmul_pattern); GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, scale_matmul_pattern); @@ -75,6 +118,11 @@ void ScaleMatmulFusePass::ApplyImpl(ir::Graph* graph) const { matmul_op->Op()->SetInput(matmul_op_input_name, std::vector({scale_in->Name()})); IR_NODE_LINK_TO(scale_in, matmul_op); + + if (!IsCompat(*matmul_op->Op())) { + LOG(WARNING) << "scale_matmul_fuse_pass in out fc op compat failed."; + return; + } GraphSafeRemoveNodes(graph, {scale_op, scale_out}); found_scale_matmul_fuse_count++; } diff --git a/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.h index 32ff78d9a73..acea8ba563d 100644 --- a/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.h @@ -24,6 +24,7 @@ class Graph; class ScaleMatmulFusePass : public FusePassBase { public: + ScaleMatmulFusePass(); virtual ~ScaleMatmulFusePass() {} protected: diff --git a/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass_tester.cc index d37d014a87b..60f844ffc80 100644 --- a/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass_tester.cc @@ -31,6 +31,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, op->SetAttr("scale", scale); op->SetAttr("bias", bias); } else if (type == "matmul") { + op->SetAttr("transpose_X", false); + op->SetAttr("transpose_Y", false); op->SetInput("X", {inputs[0]}); if (inputs.size() > 1) op->SetInput("Y", {inputs[1]}); op->SetAttr("alpha", scale); -- GitLab