Unverified commit 4bf9e11f — authored by 王明冬, committed via GitHub

add compat precondition for matmul_transpose_reshape_fuse_pass, test=develop (#33614)

Parent commit: c7e3c918
...@@ -22,6 +22,61 @@ namespace paddle { ...@@ -22,6 +22,61 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
// Constructor: registers op-compatibility ("OpCompat") specifications for the
// three operators this pass fuses: matmul -> transpose2 -> reshape2.
// Each spec declares the inputs, outputs, and attributes a matched op must
// carry for the fusion to be considered safe; the IsCompat() check in
// ApplyImpl consults these registrations before rewriting a subgraph.
MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() {
  // matmul: two tensor operands, one tensor output. The numeric/boolean
  // attributes are accepted with any value (no constraint clauses).
  AddOpCompat(OpCompat("matmul"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Y")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("alpha")  // unconstrained. can be any float value.
      .End()
      .AddAttr("transpose_X")  // unconstrained. can be any bool value.
      .End()
      .AddAttr("transpose_Y")  // unconstrained. can be any bool value.
      .End();
  // transpose2: emits both the transposed tensor (Out) and XShape, which
  // records the input shape for the backward pass.
  AddOpCompat(OpCompat("transpose2"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddOutput("XShape")
      .IsTensor()
      .End()
      .AddAttr("axis")  // ints — the permutation applied to the input axes;
                        // unconstrained here.
      .End()
      .AddAttr("data_format")
      .IsStringIn({"NHWC", "NCHW", "AnyLayout"})
      .End();
  // reshape2: the target shape may come from the "shape" attribute or,
  // optionally, from the Shape/ShapeTensor inputs, hence IsOptional().
  AddOpCompat(OpCompat("reshape2"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Shape")
      .IsTensor()
      .IsOptional()
      .End()
      .AddInput("ShapeTensor")
      .IsTensor()
      .IsOptional()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddOutput("XShape")
      .IsTensor()
      .End()
      .AddAttr("shape")  // ints — the static target shape; unconstrained here.
      .End();
}
void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const { void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(graph, PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -37,6 +92,10 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const { ...@@ -37,6 +92,10 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
int found_matmul_transpose_reshape_count = 0; int found_matmul_transpose_reshape_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) { Graph *g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
VLOG(4) << "handle matmul_transpose_reshape fuse"; VLOG(4) << "handle matmul_transpose_reshape fuse";
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, mtrp); GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, mtrp);
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, mtrp); GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, mtrp);
......
...@@ -17,8 +17,6 @@ ...@@ -17,8 +17,6 @@
#include <string> #include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -27,6 +25,7 @@ class Graph; ...@@ -27,6 +25,7 @@ class Graph;
class MatmulTransposeReshapeMKLDNNPass : public FusePassBase { class MatmulTransposeReshapeMKLDNNPass : public FusePassBase {
public: public:
MatmulTransposeReshapeMKLDNNPass();
virtual ~MatmulTransposeReshapeMKLDNNPass() {} virtual ~MatmulTransposeReshapeMKLDNNPass() {}
protected: protected:
......
...@@ -28,6 +28,7 @@ void SetOp(ProgramDesc *prog, const std::string &type, ...@@ -28,6 +28,7 @@ void SetOp(ProgramDesc *prog, const std::string &type,
op->SetOutput("Out", {outputs[0]}); op->SetOutput("Out", {outputs[0]});
if (type == "transpose2") { if (type == "transpose2") {
op->SetAttr("axis", std::vector<int>({0, 2, 1, 3})); op->SetAttr("axis", std::vector<int>({0, 2, 1, 3}));
op->SetAttr("data_format", std::string("NCHW"));
op->SetOutput("XShape", {outputs[1]}); op->SetOutput("XShape", {outputs[1]});
} }
if (type == "reshape2") { if (type == "reshape2") {
...@@ -38,6 +39,9 @@ void SetOp(ProgramDesc *prog, const std::string &type, ...@@ -38,6 +39,9 @@ void SetOp(ProgramDesc *prog, const std::string &type,
if (type == "matmul") { if (type == "matmul") {
op->SetInput("Y", {inputs[1]}); op->SetInput("Y", {inputs[1]});
op->SetAttr("use_mkldnn", true); op->SetAttr("use_mkldnn", true);
op->SetAttr("alpha", 1.0f);
op->SetAttr("transpose_X", true);
op->SetAttr("transpose_Y", true);
} }
} }
......
...@@ -3,15 +3,6 @@ def { ...@@ -3,15 +3,6 @@ def {
inputs { inputs {
name: "X" name: "X"
} }
outputs {
name: "Out"
}
attrs {
name: "shape"
type: INTS
}
}
extra {
inputs { inputs {
name: "Shape" name: "Shape"
} }
...@@ -21,6 +12,15 @@ extra { ...@@ -21,6 +12,15 @@ extra {
outputs { outputs {
name: "XShape" name: "XShape"
} }
outputs {
name: "Out"
}
attrs {
name: "shape"
type: INTS
}
}
extra {
attrs { attrs {
name: "use_quantizer" name: "use_quantizer"
type: BOOLEAN type: BOOLEAN
...@@ -50,4 +50,3 @@ extra { ...@@ -50,4 +50,3 @@ extra {
type: STRING type: STRING
} }
} }
type: "transpose" type: "transpose2"
def { def {
inputs { inputs {
name: "X" name: "X"
...@@ -6,6 +6,9 @@ def { ...@@ -6,6 +6,9 @@ def {
outputs { outputs {
name: "Out" name: "Out"
} }
outputs {
name: "XShape"
}
attrs { attrs {
name: "axis" name: "axis"
type: INTS type: INTS
...@@ -16,9 +19,6 @@ def { ...@@ -16,9 +19,6 @@ def {
} }
} }
extra { extra {
outputs {
name: "XShape"
}
attrs { attrs {
name: "use_mkldnn" name: "use_mkldnn"
type: BOOLEAN type: BOOLEAN
...@@ -52,4 +52,3 @@ extra { ...@@ -52,4 +52,3 @@ extra {
type: STRING type: STRING
} }
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.