Unverified commit 249b55c5, authored by MissPenguin, committed by GitHub

add pass enhance for map_matmul_to_mul_pass and flatten2_matmul_fuse_… (#33463)

Parent: 77a880c0
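This commit hardens MapMatmul2MulPass and Flatten2MatmulFusePass with op-compatibility checks: each pass constructor now declares OpCompat constraint tables for the ops it matches (matmul, flatten2) and for the mul op it emits, and each pattern handler calls IsCompat() to reject matched subgraphs, or generated mul descs, that fall outside those constraints. Below is a minimal sketch of that pattern, assuming a hypothetical ExamplePass; it uses only the OpCompat/IsCompat calls that appear in this diff and is not part of the commit itself.

#include "paddle/fluid/framework/ir/fuse_pass_base.h"

namespace paddle {
namespace framework {
namespace ir {

// Hypothetical pass, for illustration only; not part of this commit.
class ExamplePass : public FusePassBase {
 public:
  ExamplePass() {
    // Declare what an acceptable "matmul" looks like for this pass:
    // tensor inputs/outputs, and alpha effectively equal to 1.0.
    AddOpCompat(OpCompat("matmul"))
        .AddInput("X").IsTensor().End()
        .AddInput("Y").IsTensor().End()
        .AddOutput("Out").IsTensor().End()
        .AddAttr("alpha").IsNumGE(0.99f).IsNumLE(1.01f).End();
  }

 protected:
  void ApplyImpl(Graph* graph) const override {
    // Inside the pattern handler the tables are consulted twice:
    //   if (!IsCompat(subgraph, g)) return;  // matched ops violate a table
    //   ... build the replacement OpDesc desc ...
    //   if (!IsCompat(desc)) return;         // replacement op violates a table
  }
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle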
@@ -16,6 +16,7 @@
#include <cmath>
#include <string>
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
@@ -26,6 +27,103 @@ namespace ir {
class Node;
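// The constructor declares OpCompat tables: matmul is eligible for the mul
// mapping only when alpha is effectively 1.0 (within [0.99f, 1.01f]) and
// neither operand is transposed; the generated mul must satisfy
// x_num_col_dims >= 1 and y_num_col_dims == 1.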
MapMatmul2MulPass::MapMatmul2MulPass() {
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGE(0.99f)
.IsNumLE(1.01f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y")
.IsBoolEQ(false)
.End();
AddOpCompat(OpCompat("mul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumGE(1)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
}
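// Same scheme for the flatten2+matmul fusion: the matmul constraints above
// are repeated, flatten2 must have a non-negative axis, and the generated
// mul is validated against the same table.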
Flatten2MatmulFusePass::Flatten2MatmulFusePass() {
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGE(0.99f)
.IsNumLE(1.01f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y")
.IsBoolEQ(false)
.End();
AddOpCompat(OpCompat("flatten2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumGE(0)
.End();
AddOpCompat(OpCompat("mul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumGE(1)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
}
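// ApplyImpl drives a GraphPatternDetector over the graph; every match is
// passed to the handler lambda below, which consults the OpCompat tables
// before rewriting matmul into mul.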
void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
@@ -39,6 +137,11 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
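      // Reject this match up front if any op in the matched subgraph
      // violates the OpCompat constraints declared in the constructor.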
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
VLOG(4) << "map matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern);
@@ -82,6 +185,11 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
IR_NODE_LINK_TO(mul_node, matmul_out);
GraphSafeRemoveNodes(graph, {matmul_op});
++found_count;
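      // Validate the desc of the newly created mul node against the "mul"
      // OpCompat table; on failure, warn and abandon this match. Note that
      // the original matmul node has already been removed by this point.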
if (!IsCompat(desc)) {
LOG(WARNING) << "MapMatmul2MulPass in out mul op compat failed.";
return;
}
}
};
@@ -244,6 +352,11 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
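      // As in MapMatmul2MulPass: skip flatten2+matmul matches that violate
      // the OpCompat tables before any rewriting happens.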
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
VLOG(4) << "fuse flatten2+matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(flatten2_in_x, flatten2_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(flatten2_op, flatten2_op, fuse_pattern);
@@ -301,6 +414,11 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
IR_NODE_LINK_TO(mul_node, matmul_out);
GraphSafeRemoveNodes(graph, {flatten2_op, matmul_in_x, matmul_op});
++found_count;
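      // Same late validation of the generated mul desc as in
      // MapMatmul2MulPass::ApplyImpl.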
if (!IsCompat(desc)) {
LOG(WARNING) << "Flatten2MatmulFusePass in out mul op compat failed.";
return;
}
}
};
@@ -39,6 +39,7 @@ class Graph;
class MapMatmul2MulPass : public FusePassBase {
public:
MapMatmul2MulPass();
virtual ~MapMatmul2MulPass() {}
protected:
@@ -103,6 +104,7 @@ class Reshape2MatmulFusePass : public FusePassBase {
class Flatten2MatmulFusePass : public FusePassBase {
public:
Flatten2MatmulFusePass();
virtual ~Flatten2MatmulFusePass() {}
protected: