From 55aea350418899362d90f0a5f14a57a9651f2bf2 Mon Sep 17 00:00:00 2001 From: Double_V Date: Mon, 28 Jun 2021 19:53:15 +0800 Subject: [PATCH] add fusepass Reshape2MatmulFusePass AdaptivePool2dConvertGlobalPass (#33555) * add transpose transpose opdef, test=develop * add line, test=develop * fix wrong name, test=develop * add pass, test=develop * fix bug, test=develop * fix bug, test=develop * delete limite about alpha, test=develop * add mul to reshape2MatmulFusePass, test=develop * add limit about alpha, test=develop * fix bug,test=develop * set adaptive as false and global_pooling as True, test=develop * set x_num_col_dims as 1, test=develop * fix reshape, add attr limit, test=develop * fix conflict,test=develop * fix comment, test=develop * fix comment,test=develop * fix comment,test=develop * ,test=develop * add IsType, test=develop * add IsType, test=develop --- .../ir/adaptive_pool2d_convert_global_pass.cc | 40 +++++++++++ .../ir/adaptive_pool2d_convert_global_pass.h | 1 + .../framework/ir/map_matmul_to_mul_pass.cc | 70 +++++++++++++++++++ .../framework/ir/map_matmul_to_mul_pass.h | 1 + 4 files changed, 112 insertions(+) diff --git a/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.cc b/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.cc index 62d79f987a6..0e2bb3eaad5 100644 --- a/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.cc +++ b/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.cc @@ -24,6 +24,46 @@ namespace paddle { namespace framework { namespace ir { +AdaptivePool2dConvertGlobalPass::AdaptivePool2dConvertGlobalPass() { + AddOpCompat(OpCompat("pool2d")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("pooling_type") + .IsStringIn({"max", "avg"}) + .End() + .AddAttr("ksize") + .IsType>() + .End() + .AddAttr("global_pooling") + .IsBoolEQ(true) + .End() + .AddAttr("strides") + .IsType>() + .End() + .AddAttr("paddings") + .IsType>() + .End() + .AddAttr("exclusive") + .IsType() + .End() + .AddAttr("adaptive") + .IsBoolEQ(false) + .End() + .AddAttr("ceil_mode") + .IsType() + .End() + .AddAttr("data_format") + .IsStringIn({"NHWC", "NCHW"}) + .End() + .AddAttr("padding_algorithm") + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .End(); +} + void AdaptivePool2dConvertGlobalPass::ApplyImpl(ir::Graph* graph) const { std::string name_scope = "adaptive_pool2d_convert_global_pass"; FusePassBase::Init(name_scope, graph); diff --git a/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.h b/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.h index f16f030d518..4a1405004e2 100644 --- a/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.h +++ b/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.h @@ -31,6 +31,7 @@ class Graph; */ class AdaptivePool2dConvertGlobalPass : public FusePassBase { public: + AdaptivePool2dConvertGlobalPass(); virtual ~AdaptivePool2dConvertGlobalPass() {} protected: diff --git a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc index 20761f2f1ea..72e6742f8f3 100755 --- a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc +++ b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc @@ -267,6 +267,68 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { AddStatis(found_count); } +Reshape2MatmulFusePass::Reshape2MatmulFusePass() { + AddOpCompat(OpCompat("reshape2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Shape") + .IsTensor() + .IsOptional() + .End() + .AddInput("ShapeTensor") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsTensor() + .End() + .AddAttr("shape") // ints + .IsType>() + .End(); + + AddOpCompat(OpCompat("matmul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("alpha") + .IsNumGT(0.99999f) + .IsNumLT(1.00001f) + .End() + .AddAttr("transpose_X") + .IsBoolEQ("False") + .End() + .AddAttr("transpose_Y") + .IsBoolEQ("False") + .End(); + + AddOpCompat(OpCompat("mul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("x_num_col_dims") + .IsNumEQ(1) + .End() + .AddAttr("y_num_col_dims") + .IsNumEQ(1) + .End(); +} + void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); @@ -280,6 +342,10 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { int found_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + } VLOG(4) << "fuse reshape2+matmul to mul"; GET_IR_NODE_FROM_SUBGRAPH(reshape2_in_x, reshape2_in_x, fuse_pattern); GET_IR_NODE_FROM_SUBGRAPH(reshape2_op, reshape2_op, fuse_pattern); @@ -326,6 +392,10 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); } + if (!IsCompat(desc)) { + LOG(WARNING) << "reshape2 matmul pass in out mul op compat failed."; + return; + } auto mul_node = g->CreateOpNode(&desc); IR_NODE_LINK_TO(reshape2_in_x, mul_node); IR_NODE_LINK_TO(matmul_in_y, mul_node); diff --git a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h index 27828f9c438..5dc5caae21e 100644 --- a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h +++ b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h @@ -96,6 +96,7 @@ class Squeeze2MatmulFusePass : public FusePassBase { class Reshape2MatmulFusePass : public FusePassBase { public: + Reshape2MatmulFusePass(); virtual ~Reshape2MatmulFusePass() {} protected: -- GitLab