From 249b55c55607b9075ad319edaa05db2a3a4dcf0e Mon Sep 17 00:00:00 2001
From: MissPenguin
Date: Fri, 25 Jun 2021 19:15:51 +0800
Subject: [PATCH] =?UTF-8?q?add=20pass=20enhance=20for=20map=5Fmatmul=5Fto?=
 =?UTF-8?q?=5Fmul=5Fpass=20and=20flatten2=5Fmatmul=5Ffuse=5F=E2=80=A6=20(#?=
 =?UTF-8?q?33463)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../framework/ir/map_matmul_to_mul_pass.cc    | 118 ++++++++++++++++++
 .../framework/ir/map_matmul_to_mul_pass.h     |   2 +
 2 files changed, 120 insertions(+)
 mode change 100644 => 100755 paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc

diff --git a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc
old mode 100644
new mode 100755
index c36123f65f6..20761f2f1ea
--- a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc
+++ b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc
@@ -16,6 +16,7 @@
 #include
 #include
 
+#include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/platform/enforce.h"
 
@@ -26,6 +27,103 @@ namespace ir {
 
 class Node;
 
+MapMatmul2MulPass::MapMatmul2MulPass() {
+  AddOpCompat(OpCompat("matmul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("alpha")
+      .IsNumGE(0.99f)
+      .IsNumLE(1.01f)
+      .End()
+      .AddAttr("transpose_X")
+      .IsBoolEQ(false)
+      .End()
+      .AddAttr("transpose_Y")
+      .IsBoolEQ(false)
+      .End();
+
+  AddOpCompat(OpCompat("mul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("x_num_col_dims")
+      .IsNumGE(1)
+      .End()
+      .AddAttr("y_num_col_dims")
+      .IsNumEQ(1)
+      .End();
+}
+
+Flatten2MatmulFusePass::Flatten2MatmulFusePass() {
+  AddOpCompat(OpCompat("matmul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("alpha")
+      .IsNumGE(0.99f)
+      .IsNumLE(1.01f)
+      .End()
+      .AddAttr("transpose_X")
+      .IsBoolEQ(false)
+      .End()
+      .AddAttr("transpose_Y")
+      .IsBoolEQ(false)
+      .End();
+
+  AddOpCompat(OpCompat("flatten2"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddOutput("XShape")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsNumGE(0)
+      .End();
+
+  AddOpCompat(OpCompat("mul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("x_num_col_dims")
+      .IsNumGE(1)
+      .End()
+      .AddAttr("y_num_col_dims")
+      .IsNumEQ(1)
+      .End();
+}
+
 void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
@@ -39,6 +137,11 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
   int found_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING) << "Pass in op compat failed.";
+      return;
+    }
+
     VLOG(4) << "map matmul to mul";
     GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern);
@@ -82,6 +185,11 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
       IR_NODE_LINK_TO(mul_node, matmul_out);
       GraphSafeRemoveNodes(graph, {matmul_op});
       ++found_count;
+
+      if (!IsCompat(desc)) {
+        LOG(WARNING) << "MapMatmul2MulPass in out mul op compat failed.";
+        return;
+      }
     }
   };
 
@@ -244,6 +352,11 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
   int found_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING) << "Pass in op compat failed.";
+      return;
+    }
+
     VLOG(4) << "fuse flatten2+matmul to mul";
     GET_IR_NODE_FROM_SUBGRAPH(flatten2_in_x, flatten2_in_x, fuse_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(flatten2_op, flatten2_op, fuse_pattern);
@@ -301,6 +414,11 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
       IR_NODE_LINK_TO(mul_node, matmul_out);
       GraphSafeRemoveNodes(graph, {flatten2_op, matmul_in_x, matmul_op});
       ++found_count;
+
+      if (!IsCompat(desc)) {
+        LOG(WARNING) << "Flatten2MatmulFusePass in out mul op compat failed.";
+        return;
+      }
     }
   };
 
diff --git a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h
index 85067a6f642..27828f9c438 100644
--- a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h
+++ b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h
@@ -39,6 +39,7 @@ class Graph;
 
 class MapMatmul2MulPass : public FusePassBase {
  public:
+  MapMatmul2MulPass();
   virtual ~MapMatmul2MulPass() {}
 
  protected:
@@ -103,6 +104,7 @@ class Reshape2MatmulFusePass : public FusePassBase {
 
 class Flatten2MatmulFusePass : public FusePassBase {
  public:
+  Flatten2MatmulFusePass();
   virtual ~Flatten2MatmulFusePass() {}
 
  protected:
-- 
GitLab