diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc index fbc97a0a929c48c4eba3baa881061654dd802b62..1f17a741f190941d352e9ad6346dfdbeca671b50 100644 --- a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc @@ -22,6 +22,61 @@ namespace paddle { namespace framework { namespace ir { +MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() { + AddOpCompat(OpCompat("matmul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("alpha") // unconstrained. can be any float value. + .End() + .AddAttr("transpose_X") // unconstrained. can be any bool value. + .End() + .AddAttr("transpose_Y") // unconstrained. can be any bool value. + .End(); + + AddOpCompat(OpCompat("transpose2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsTensor() + .End() + .AddAttr("axis") // ints + .End() + .AddAttr("data_format") + .IsStringIn({"NHWC", "NCHW", "AnyLayout"}) + .End(); + + AddOpCompat(OpCompat("reshape2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Shape") + .IsTensor() + .IsOptional() + .End() + .AddInput("ShapeTensor") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsTensor() + .End() + .AddAttr("shape") // ints + .End(); +} void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const { PADDLE_ENFORCE_NOT_NULL(graph, platform::errors::InvalidArgument( @@ -37,6 +92,10 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const { int found_matmul_transpose_reshape_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, Graph *g) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + } VLOG(4) << "handle matmul_transpose_reshape fuse"; GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, mtrp); GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, mtrp); diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h index ef469bac40c4edbc524ef4b24c8df932819f0a3a..09cbe9bdf7b2fb5c8fd0c8676730031482f3d6d9 100644 --- a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h @@ -17,8 +17,6 @@ #include #include "paddle/fluid/framework/ir/fuse_pass_base.h" -#include "paddle/fluid/framework/ir/graph.h" -#include "paddle/fluid/framework/ir/graph_pattern_detector.h" namespace paddle { namespace framework { @@ -27,6 +25,7 @@ class Graph; class MatmulTransposeReshapeMKLDNNPass : public FusePassBase { public: + MatmulTransposeReshapeMKLDNNPass(); virtual ~MatmulTransposeReshapeMKLDNNPass() {} protected: diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc index 122a7f802a52972612e2879eaea29d14e5d7c561..ac4e6c383dad9d5cc11e5bbce5f24093f9d60d24 100644 --- a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc @@ -28,6 +28,7 @@ void SetOp(ProgramDesc *prog, const std::string &type, op->SetOutput("Out", {outputs[0]}); if (type == "transpose2") { op->SetAttr("axis", std::vector({0, 2, 1, 3})); + op->SetAttr("data_format", std::string("NCHW")); op->SetOutput("XShape", {outputs[1]}); } if (type == "reshape2") { @@ -38,6 +39,9 @@ void SetOp(ProgramDesc *prog, const std::string &type, if (type == "matmul") { op->SetInput("Y", {inputs[1]}); op->SetAttr("use_mkldnn", true); + op->SetAttr("alpha", 1.0f); + op->SetAttr("transpose_X", true); + op->SetAttr("transpose_Y", true); } } diff --git a/paddle/fluid/operators/compat/reshape2.pbtxt b/paddle/fluid/operators/compat/reshape2.pbtxt index 2ccc83305baca9a2979fcd37420abfd945a35123..d975aed61fa1b7a4f2aba08d353d042d21c2dccb 100644 --- a/paddle/fluid/operators/compat/reshape2.pbtxt +++ b/paddle/fluid/operators/compat/reshape2.pbtxt @@ -3,15 +3,6 @@ def { inputs { name: "X" } - outputs { - name: "Out" - } - attrs { - name: "shape" - type: INTS - } -} -extra { inputs { name: "Shape" } @@ -21,6 +12,15 @@ extra { outputs { name: "XShape" } + outputs { + name: "Out" + } + attrs { + name: "shape" + type: INTS + } +} +extra { attrs { name: "use_quantizer" type: BOOLEAN @@ -50,4 +50,3 @@ extra { type: STRING } } - diff --git a/paddle/fluid/operators/compat/transpose.pdtxt b/paddle/fluid/operators/compat/transpose.pbtxt similarity index 100% rename from paddle/fluid/operators/compat/transpose.pdtxt rename to paddle/fluid/operators/compat/transpose.pbtxt diff --git a/paddle/fluid/operators/compat/transpose2.pdtxt b/paddle/fluid/operators/compat/transpose2.pbtxt similarity index 97% rename from paddle/fluid/operators/compat/transpose2.pdtxt rename to paddle/fluid/operators/compat/transpose2.pbtxt index 34fad62a101e0de1bed8e671cb454396f865b421..19d991a6414d131c4833d5b919e9372b38168864 100644 --- a/paddle/fluid/operators/compat/transpose2.pdtxt +++ b/paddle/fluid/operators/compat/transpose2.pbtxt @@ -1,4 +1,4 @@ -type: "transpose" +type: "transpose2" def { inputs { name: "X" @@ -6,6 +6,9 @@ def { outputs { name: "Out" } + outputs { + name: "XShape" + } attrs { name: "axis" type: INTS @@ -16,9 +19,6 @@ def { } } extra { - outputs { - name: "XShape" - } attrs { name: "use_mkldnn" type: BOOLEAN @@ -52,4 +52,3 @@ extra { type: STRING } } -