未验证 提交 249b55c5 编写于 作者: M MissPenguin 提交者: GitHub

add pass enhance for map_matmul_to_mul_pass and flatten2_matmul_fuse_… (#33463)

上级 77a880c0
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <cmath> #include <cmath>
#include <string> #include <string>
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -26,6 +27,103 @@ namespace ir { ...@@ -26,6 +27,103 @@ namespace ir {
class Node; class Node;
MapMatmul2MulPass::MapMatmul2MulPass() {
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGE(0.99f)
.IsNumLE(1.01f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y")
.IsBoolEQ(false)
.End();
AddOpCompat(OpCompat("mul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumGE(1)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
}
Flatten2MatmulFusePass::Flatten2MatmulFusePass() {
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGE(0.99f)
.IsNumLE(1.01f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y")
.IsBoolEQ(false)
.End();
AddOpCompat(OpCompat("flatten2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumGE(0)
.End();
AddOpCompat(OpCompat("mul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumGE(1)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
}
void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const { void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
...@@ -39,6 +137,11 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const { ...@@ -39,6 +137,11 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0; int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
VLOG(4) << "map matmul to mul"; VLOG(4) << "map matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_pattern); GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern); GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern);
...@@ -82,6 +185,11 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const { ...@@ -82,6 +185,11 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
IR_NODE_LINK_TO(mul_node, matmul_out); IR_NODE_LINK_TO(mul_node, matmul_out);
GraphSafeRemoveNodes(graph, {matmul_op}); GraphSafeRemoveNodes(graph, {matmul_op});
++found_count; ++found_count;
if (!IsCompat(desc)) {
LOG(WARNING) << "MapMatmul2MulPass in out mul op compat failed.";
return;
}
} }
}; };
...@@ -244,6 +352,11 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -244,6 +352,11 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0; int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
VLOG(4) << "fuse flatten2+matmul to mul"; VLOG(4) << "fuse flatten2+matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(flatten2_in_x, flatten2_in_x, fuse_pattern); GET_IR_NODE_FROM_SUBGRAPH(flatten2_in_x, flatten2_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(flatten2_op, flatten2_op, fuse_pattern); GET_IR_NODE_FROM_SUBGRAPH(flatten2_op, flatten2_op, fuse_pattern);
...@@ -301,6 +414,11 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -301,6 +414,11 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
IR_NODE_LINK_TO(mul_node, matmul_out); IR_NODE_LINK_TO(mul_node, matmul_out);
GraphSafeRemoveNodes(graph, {flatten2_op, matmul_in_x, matmul_op}); GraphSafeRemoveNodes(graph, {flatten2_op, matmul_in_x, matmul_op});
++found_count; ++found_count;
if (!IsCompat(desc)) {
LOG(WARNING) << "Flatten2MatmulFusePass in out mul op compat failed.";
return;
}
} }
}; };
......
...@@ -39,6 +39,7 @@ class Graph; ...@@ -39,6 +39,7 @@ class Graph;
class MapMatmul2MulPass : public FusePassBase { class MapMatmul2MulPass : public FusePassBase {
public: public:
MapMatmul2MulPass();
virtual ~MapMatmul2MulPass() {} virtual ~MapMatmul2MulPass() {}
protected: protected:
...@@ -103,6 +104,7 @@ class Reshape2MatmulFusePass : public FusePassBase { ...@@ -103,6 +104,7 @@ class Reshape2MatmulFusePass : public FusePassBase {
class Flatten2MatmulFusePass : public FusePassBase { class Flatten2MatmulFusePass : public FusePassBase {
public: public:
Flatten2MatmulFusePass();
virtual ~Flatten2MatmulFusePass() {} virtual ~Flatten2MatmulFusePass() {}
protected: protected:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册