未验证 提交 e15009e7 编写于 作者: 王明冬 提交者: GitHub

add compat precondition for reshape_transpose_matmul_mkldnn_fuse_pass, (#33820)

test=develop.
上级 0bccd782
...@@ -23,6 +23,59 @@ namespace paddle { ...@@ -23,6 +23,59 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
ReshapeTransposeMatmulMkldnnFusePass::ReshapeTransposeMatmulMkldnnFusePass() {
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
// The reshape2 op for this pass should not have "Shape" and "ShapeTensor"
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("shape")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsType<float>()
.End()
.AddAttr("transpose_X")
.IsType<bool>()
.End()
.AddAttr("transpose_Y")
.IsType<bool>()
.End();
}
void ReshapeTransposeMatmulMkldnnFusePass::Fuse( void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
Graph *graph, bool with_reshape_xshape, bool with_transpose_xshape) const { Graph *graph, bool with_reshape_xshape, bool with_transpose_xshape) const {
GraphPatternDetector gpd; GraphPatternDetector gpd;
...@@ -34,6 +87,11 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse( ...@@ -34,6 +87,11 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
int found_reshape_transpose_matmul_count = 0; int found_reshape_transpose_matmul_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) { Graph *g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Op compatible check in "
"reshape_transpose_matmul_mkldnn_fuse_pass failed.";
return;
}
VLOG(4) << "handle ReshapeTransposeMatmulMkldnn fuse"; VLOG(4) << "handle ReshapeTransposeMatmulMkldnn fuse";
GET_IR_NODE_FROM_SUBGRAPH(reshape_in, reshape_in, rtm_pattern); GET_IR_NODE_FROM_SUBGRAPH(reshape_in, reshape_in, rtm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape_op, reshape_op, rtm_pattern); GET_IR_NODE_FROM_SUBGRAPH(reshape_op, reshape_op, rtm_pattern);
......
...@@ -17,8 +17,6 @@ ...@@ -17,8 +17,6 @@
#include <string> #include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -26,11 +24,10 @@ namespace ir { ...@@ -26,11 +24,10 @@ namespace ir {
/* /*
* Fuse Reshape->Transpose->MatMul when MatMul uses mkldnn. * Fuse Reshape->Transpose->MatMul when MatMul uses mkldnn.
*/ */
class Graph;
class ReshapeTransposeMatmulMkldnnFusePass : public FusePassBase { class ReshapeTransposeMatmulMkldnnFusePass : public FusePassBase {
public: public:
virtual ~ReshapeTransposeMatmulMkldnnFusePass() {} ReshapeTransposeMatmulMkldnnFusePass();
protected: protected:
void ApplyImpl(ir::Graph* graph) const override; void ApplyImpl(ir::Graph* graph) const override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册