From 4bf9e11ff7ee16ae3cb1883a698b1548d716dd55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E5=86=AC?= <78149749+winter-wang@users.noreply.github.com> Date: Thu, 17 Jun 2021 18:20:18 +0800 Subject: [PATCH] add compat precondition for matmul_transpose_reshape_fuse_pass, test=develop (#33614) --- .../matmul_transpose_reshape_fuse_pass.cc | 59 +++++++++++++++++++ .../matmul_transpose_reshape_fuse_pass.h | 3 +- ...tmul_transpose_reshape_fuse_pass_tester.cc | 4 ++ paddle/fluid/operators/compat/reshape2.pbtxt | 19 +++--- .../{transpose.pdtxt => transpose.pbtxt} | 0 .../{transpose2.pdtxt => transpose2.pbtxt} | 9 ++- 6 files changed, 77 insertions(+), 17 deletions(-) rename paddle/fluid/operators/compat/{transpose.pdtxt => transpose.pbtxt} (100%) rename paddle/fluid/operators/compat/{transpose2.pdtxt => transpose2.pbtxt} (97%) diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc index fbc97a0a929..1f17a741f19 100644 --- a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc @@ -22,6 +22,61 @@ namespace paddle { namespace framework { namespace ir { +MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() { + AddOpCompat(OpCompat("matmul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("alpha") // unconstrained. can be any float value. + .End() + .AddAttr("transpose_X") // unconstrained. can be any bool value. + .End() + .AddAttr("transpose_Y") // unconstrained. can be any bool value. + .End(); + + AddOpCompat(OpCompat("transpose2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsTensor() + .End() + .AddAttr("axis") // ints + .End() + .AddAttr("data_format") + .IsStringIn({"NHWC", "NCHW", "AnyLayout"}) + .End(); + + AddOpCompat(OpCompat("reshape2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Shape") + .IsTensor() + .IsOptional() + .End() + .AddInput("ShapeTensor") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsTensor() + .End() + .AddAttr("shape") // ints + .End(); +} void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const { PADDLE_ENFORCE_NOT_NULL(graph, platform::errors::InvalidArgument( @@ -37,6 +92,10 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const { int found_matmul_transpose_reshape_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, Graph *g) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + } VLOG(4) << "handle matmul_transpose_reshape fuse"; GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, mtrp); GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, mtrp); diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h index ef469bac40c..09cbe9bdf7b 100644 --- a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h @@ -17,8 +17,6 @@ #include #include "paddle/fluid/framework/ir/fuse_pass_base.h" -#include "paddle/fluid/framework/ir/graph.h" -#include "paddle/fluid/framework/ir/graph_pattern_detector.h" namespace paddle { namespace framework { @@ -27,6 +25,7 @@ class Graph; class MatmulTransposeReshapeMKLDNNPass : public FusePassBase { public: + MatmulTransposeReshapeMKLDNNPass(); virtual ~MatmulTransposeReshapeMKLDNNPass() {} protected: diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc index 122a7f802a5..ac4e6c383da 100644 --- a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc @@ -28,6 +28,7 @@ void SetOp(ProgramDesc *prog, const std::string &type, op->SetOutput("Out", {outputs[0]}); if (type == "transpose2") { op->SetAttr("axis", std::vector({0, 2, 1, 3})); + op->SetAttr("data_format", std::string("NCHW")); op->SetOutput("XShape", {outputs[1]}); } if (type == "reshape2") { @@ -38,6 +39,9 @@ void SetOp(ProgramDesc *prog, const std::string &type, if (type == "matmul") { op->SetInput("Y", {inputs[1]}); op->SetAttr("use_mkldnn", true); + op->SetAttr("alpha", 1.0f); + op->SetAttr("transpose_X", true); + op->SetAttr("transpose_Y", true); } } diff --git a/paddle/fluid/operators/compat/reshape2.pbtxt b/paddle/fluid/operators/compat/reshape2.pbtxt index 2ccc83305ba..d975aed61fa 100644 --- a/paddle/fluid/operators/compat/reshape2.pbtxt +++ b/paddle/fluid/operators/compat/reshape2.pbtxt @@ -3,15 +3,6 @@ def { inputs { name: "X" } - outputs { - name: "Out" - } - attrs { - name: "shape" - type: INTS - } -} -extra { inputs { name: "Shape" } @@ -21,6 +12,15 @@ extra { outputs { name: "XShape" } + outputs { + name: "Out" + } + attrs { + name: "shape" + type: INTS + } +} +extra { attrs { name: "use_quantizer" type: BOOLEAN @@ -50,4 +50,3 @@ extra { type: STRING } } - diff --git a/paddle/fluid/operators/compat/transpose.pdtxt b/paddle/fluid/operators/compat/transpose.pbtxt similarity index 100% rename from paddle/fluid/operators/compat/transpose.pdtxt rename to paddle/fluid/operators/compat/transpose.pbtxt diff --git a/paddle/fluid/operators/compat/transpose2.pdtxt b/paddle/fluid/operators/compat/transpose2.pbtxt similarity index 97% rename from paddle/fluid/operators/compat/transpose2.pdtxt rename to paddle/fluid/operators/compat/transpose2.pbtxt index 34fad62a101..19d991a6414 100644 --- a/paddle/fluid/operators/compat/transpose2.pdtxt +++ b/paddle/fluid/operators/compat/transpose2.pbtxt @@ -1,4 +1,4 @@ -type: "transpose" +type: "transpose2" def { inputs { name: "X" @@ -6,6 +6,9 @@ def { outputs { name: "Out" } + outputs { + name: "XShape" + } attrs { name: "axis" type: INTS @@ -16,9 +19,6 @@ def { } } extra { - outputs { - name: "XShape" - } attrs { name: "use_mkldnn" type: BOOLEAN @@ -52,4 +52,3 @@ extra { type: STRING } } - -- GitLab