diff --git a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.cc
index b4c53ec5f91ccb855d176f84cd12378d2ec66e26..26692849d977b5bc0e3dabbd35b7f8fa53832978 100644
--- a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.cc
@@ -23,6 +23,59 @@
 namespace paddle {
 namespace framework {
 namespace ir {
+ReshapeTransposeMatmulMkldnnFusePass::ReshapeTransposeMatmulMkldnnFusePass() {
+  AddOpCompat(OpCompat("reshape2"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      // The reshape2 op for this pass should not have "Shape" and "ShapeTensor"
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddOutput("XShape")
+      .IsOptional()
+      .IsTensor()
+      .End()
+      .AddAttr("shape")
+      .IsType<std::vector<int>>()
+      .End();
+
+  AddOpCompat(OpCompat("transpose2"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddOutput("XShape")
+      .IsOptional()
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsType<std::vector<int>>()
+      .End();
+
+  AddOpCompat(OpCompat("matmul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("alpha")
+      .IsType<float>()
+      .End()
+      .AddAttr("transpose_X")
+      .IsType<bool>()
+      .End()
+      .AddAttr("transpose_Y")
+      .IsType<bool>()
+      .End();
+}
+
 void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
     Graph *graph, bool with_reshape_xshape, bool with_transpose_xshape) const {
   GraphPatternDetector gpd;
@@ -34,6 +87,11 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
   int found_reshape_transpose_matmul_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                      Graph *g) {
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING) << "Op compatible check in "
+                      "reshape_transpose_matmul_mkldnn_fuse_pass failed.";
+      return;
+    }
     VLOG(4) << "handle ReshapeTransposeMatmulMkldnn fuse";
     GET_IR_NODE_FROM_SUBGRAPH(reshape_in, reshape_in, rtm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(reshape_op, reshape_op, rtm_pattern);
diff --git a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h
index 7a53b3c498413e43eea7b2e4697791d36fed1149..4637d0659af8c562440c280efb158f0fcde93f24 100644
--- a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h
+++ b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h
@@ -17,8 +17,6 @@
 #include <string>

 #include "paddle/fluid/framework/ir/fuse_pass_base.h"
-#include "paddle/fluid/framework/ir/graph.h"
-#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

 namespace paddle {
 namespace framework {
@@ -26,11 +24,10 @@
 namespace ir {
 /*
  * Fuse Reshape->Transpose->MatMul when MatMul uses mkldnn.
  */
-class Graph;
 class ReshapeTransposeMatmulMkldnnFusePass : public FusePassBase {
  public:
-  virtual ~ReshapeTransposeMatmulMkldnnFusePass() {}
+  ReshapeTransposeMatmulMkldnnFusePass();

 protected:
  void ApplyImpl(ir::Graph* graph) const override;