From e15009e744edf295358137315bbd627b8ac445ef Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E5=86=AC?= <78149749+winter-wang@users.noreply.github.com>
Date: Tue, 29 Jun 2021 16:20:29 +0800
Subject: [PATCH] add compat precondition for
 reshape_transpose_matmul_mkldnn_fuse_pass, (#33820)

test=develop.
---
 ...shape_transpose_matmul_mkldnn_fuse_pass.cc | 58 +++++++++++++++++++
 ...eshape_transpose_matmul_mkldnn_fuse_pass.h |  5 +-
 2 files changed, 59 insertions(+), 4 deletions(-)

diff --git a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.cc
index b4c53ec5f9..26692849d9 100644
--- a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.cc
@@ -23,6 +23,59 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
+ReshapeTransposeMatmulMkldnnFusePass::ReshapeTransposeMatmulMkldnnFusePass() {
+  AddOpCompat(OpCompat("reshape2"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      // The reshape2 op for this pass should not have "Shape" and "ShapeTensor"
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddOutput("XShape")
+      .IsOptional()
+      .IsTensor()
+      .End()
+      .AddAttr("shape")
+      .IsType<std::vector<int>>()
+      .End();
+
+  AddOpCompat(OpCompat("transpose2"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddOutput("XShape")
+      .IsOptional()
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsType<std::vector<int>>()
+      .End();
+
+  AddOpCompat(OpCompat("matmul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("alpha")
+      .IsType<float>()
+      .End()
+      .AddAttr("transpose_X")
+      .IsType<bool>()
+      .End()
+      .AddAttr("transpose_Y")
+      .IsType<bool>()
+      .End();
+}
+
 void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
     Graph *graph, bool with_reshape_xshape, bool with_transpose_xshape) const {
   GraphPatternDetector gpd;
@@ -34,6 +87,11 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
   int found_reshape_transpose_matmul_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                      Graph *g) {
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING) << "Op compatible check in "
+                      "reshape_transpose_matmul_mkldnn_fuse_pass failed.";
+      return;
+    }
     VLOG(4) << "handle ReshapeTransposeMatmulMkldnn fuse";
     GET_IR_NODE_FROM_SUBGRAPH(reshape_in, reshape_in, rtm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(reshape_op, reshape_op, rtm_pattern);
diff --git a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h
index 7a53b3c498..4637d0659a 100644
--- a/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h
+++ b/paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h
@@ -17,8 +17,6 @@
 #include <string>
 
 #include "paddle/fluid/framework/ir/fuse_pass_base.h"
-#include "paddle/fluid/framework/ir/graph.h"
-#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 
 namespace paddle {
 namespace framework {
@@ -26,11 +24,10 @@ namespace ir {
 /*
  * Fuse Reshape->Transpose->MatMul when MatMul uses mkldnn.
  */
-class Graph;
 
 class ReshapeTransposeMatmulMkldnnFusePass : public FusePassBase {
  public:
-  virtual ~ReshapeTransposeMatmulMkldnnFusePass() {}
+  ReshapeTransposeMatmulMkldnnFusePass();
 
  protected:
   void ApplyImpl(ir::Graph* graph) const override;
--
GitLab