未验证 提交 4bf9e11f 编写于 作者: 王明冬 提交者: GitHub

add compat precondition for matmul_transpose_reshape_fuse_pass, test=develop (#33614)

上级 c7e3c918
......@@ -22,6 +22,61 @@ namespace paddle {
namespace framework {
namespace ir {
MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() {
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha") // unconstrained. can be any float value.
.End()
.AddAttr("transpose_X") // unconstrained. can be any bool value.
.End()
.AddAttr("transpose_Y") // unconstrained. can be any bool value.
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("axis") // ints
.End()
.AddAttr("data_format")
.IsStringIn({"NHWC", "NCHW", "AnyLayout"})
.End();
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Shape")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("shape") // ints
.End();
}
void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::InvalidArgument(
......@@ -37,6 +92,10 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
int found_matmul_transpose_reshape_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
VLOG(4) << "handle matmul_transpose_reshape fuse";
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, mtrp);
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, mtrp);
......
......@@ -17,8 +17,6 @@
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
......@@ -27,6 +25,7 @@ class Graph;
class MatmulTransposeReshapeMKLDNNPass : public FusePassBase {
public:
MatmulTransposeReshapeMKLDNNPass();
virtual ~MatmulTransposeReshapeMKLDNNPass() {}
protected:
......
......@@ -28,6 +28,7 @@ void SetOp(ProgramDesc *prog, const std::string &type,
op->SetOutput("Out", {outputs[0]});
if (type == "transpose2") {
op->SetAttr("axis", std::vector<int>({0, 2, 1, 3}));
op->SetAttr("data_format", std::string("NCHW"));
op->SetOutput("XShape", {outputs[1]});
}
if (type == "reshape2") {
......@@ -38,6 +39,9 @@ void SetOp(ProgramDesc *prog, const std::string &type,
if (type == "matmul") {
op->SetInput("Y", {inputs[1]});
op->SetAttr("use_mkldnn", true);
op->SetAttr("alpha", 1.0f);
op->SetAttr("transpose_X", true);
op->SetAttr("transpose_Y", true);
}
}
......
......@@ -3,15 +3,6 @@ def {
inputs {
name: "X"
}
outputs {
name: "Out"
}
attrs {
name: "shape"
type: INTS
}
}
extra {
inputs {
name: "Shape"
}
......@@ -21,6 +12,15 @@ extra {
outputs {
name: "XShape"
}
outputs {
name: "Out"
}
attrs {
name: "shape"
type: INTS
}
}
extra {
attrs {
name: "use_quantizer"
type: BOOLEAN
......@@ -50,4 +50,3 @@ extra {
type: STRING
}
}
type: "transpose"
type: "transpose2"
def {
inputs {
name: "X"
......@@ -6,6 +6,9 @@ def {
outputs {
name: "Out"
}
outputs {
name: "XShape"
}
attrs {
name: "axis"
type: INTS
......@@ -16,9 +19,6 @@ def {
}
}
extra {
outputs {
name: "XShape"
}
attrs {
name: "use_mkldnn"
type: BOOLEAN
......@@ -52,4 +52,3 @@ extra {
type: STRING
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册