未验证 提交 34466911 编写于 作者: F feng_shuai 提交者: GitHub

squeeze2_matmul_fuse_pass init (#33805)

上级 8c4c0725
......@@ -124,6 +124,60 @@ Flatten2MatmulFusePass::Flatten2MatmulFusePass() {
.End();
}
Squeeze2MatmulFusePass::Squeeze2MatmulFusePass() {
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGE(0.99f)
.IsNumLE(1.01f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y")
.IsBoolEQ(false)
.End();
AddOpCompat(OpCompat("Squeeze2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("axes")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("mul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumEQ(1)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
}
void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
......@@ -211,6 +265,10 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "fuse squeeze2+matmul to mul";
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
GET_IR_NODE_FROM_SUBGRAPH(squeeze2_in_x, squeeze2_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(squeeze2_op, squeeze2_op, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern);
......@@ -260,6 +318,10 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
IR_NODE_LINK_TO(mul_node, matmul_out);
GraphSafeRemoveNodes(graph, {squeeze2_op, matmul_in_x, matmul_op});
++found_count;
if (!IsCompat(desc)) {
LOG(WARNING) << "Squeeze2MatmulFusePass in out mul op compat failed.";
return;
}
}
};
......
......@@ -67,6 +67,7 @@ class MapMatmul2MulPass : public FusePassBase {
class Squeeze2MatmulFusePass : public FusePassBase {
public:
Squeeze2MatmulFusePass();
virtual ~Squeeze2MatmulFusePass() {}
protected:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册