diff --git a/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.cc b/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.cc
index 62d79f987a6702e4240b44e49af4ff047173505f..0e2bb3eaad536fd9e3556f640b76e591bbf2f988 100644
--- a/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.cc
+++ b/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.cc
@@ -24,6 +24,46 @@
 namespace paddle {
 namespace framework {
 namespace ir {
 
+// Register the op-compat description of pool2d so IsCompat() can validate
+// matched subgraphs before this pass rewrites them.
+AdaptivePool2dConvertGlobalPass::AdaptivePool2dConvertGlobalPass() {
+  AddOpCompat(OpCompat("pool2d"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("pooling_type")
+      .IsStringIn({"max", "avg"})
+      .End()
+      .AddAttr("ksize")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("global_pooling")
+      .IsBoolEQ(true)
+      .End()
+      .AddAttr("strides")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("paddings")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("exclusive")
+      .IsType<bool>()
+      .End()
+      .AddAttr("adaptive")
+      .IsBoolEQ(false)
+      .End()
+      .AddAttr("ceil_mode")
+      .IsType<bool>()
+      .End()
+      .AddAttr("data_format")
+      .IsStringIn({"NHWC", "NCHW"})
+      .End()
+      .AddAttr("padding_algorithm")
+      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
+      .End();
+}
+
 void AdaptivePool2dConvertGlobalPass::ApplyImpl(ir::Graph* graph) const {
   std::string name_scope = "adaptive_pool2d_convert_global_pass";
   FusePassBase::Init(name_scope, graph);
diff --git a/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.h b/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.h
index f16f030d518d02a43e9d0462ccab83f313a1dc34..4a1405004e247dff69635f7ebd766ae030da82e5 100644
--- a/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.h
+++ b/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.h
@@ -31,6 +31,7 @@ class Graph;
  */
 class AdaptivePool2dConvertGlobalPass : public FusePassBase {
  public:
+  AdaptivePool2dConvertGlobalPass();
   virtual ~AdaptivePool2dConvertGlobalPass() {}
 
  protected:
diff --git 
a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc
index 20761f2f1eacba2a50aea028fef8c0f992ebf8d3..72e6742f8f34c1745637f1d0b492edadf3048bc4 100755
--- a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc
+++ b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc
@@ -267,6 +267,68 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
   AddStatis(found_count);
 }
 
+// Op-compat descriptions for the three ops this fuse touches: the matched
+// reshape2 and matmul inputs, and the mul op the pass emits in their place.
+Reshape2MatmulFusePass::Reshape2MatmulFusePass() {
+  AddOpCompat(OpCompat("reshape2"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Shape")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddInput("ShapeTensor")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddOutput("XShape")
+      .IsTensor()
+      .End()
+      .AddAttr("shape")  // ints
+      .IsType<std::vector<int>>()
+      .End();
+
+  AddOpCompat(OpCompat("matmul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("alpha")
+      .IsNumGT(0.99999f)
+      .IsNumLT(1.00001f)
+      .End()
+      .AddAttr("transpose_X")
+      .IsBoolEQ(false)
+      .End()
+      .AddAttr("transpose_Y")
+      .IsBoolEQ(false)
+      .End();
+
+  AddOpCompat(OpCompat("mul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("x_num_col_dims")
+      .IsNumEQ(1)
+      .End()
+      .AddAttr("y_num_col_dims")
+      .IsNumEQ(1)
+      .End();
+}
+
 void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
@@ -280,6 +342,10 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
   int found_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING) << "Pass in op compat failed.";
+      return;
+    }
     VLOG(4) << "fuse reshape2+matmul to mul";
     GET_IR_NODE_FROM_SUBGRAPH(reshape2_in_x, reshape2_in_x, fuse_pattern);
 
GET_IR_NODE_FROM_SUBGRAPH(reshape2_op, reshape2_op, fuse_pattern); @@ -326,6 +392,10 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); } + if (!IsCompat(desc)) { + LOG(WARNING) << "reshape2 matmul pass in out mul op compat failed."; + return; + } auto mul_node = g->CreateOpNode(&desc); IR_NODE_LINK_TO(reshape2_in_x, mul_node); IR_NODE_LINK_TO(matmul_in_y, mul_node); diff --git a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h index 27828f9c43829ca66f1a513b8984c9366755babc..5dc5caae21ea9895fb30647fa3228317c43fe7ba 100644 --- a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h +++ b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h @@ -96,6 +96,7 @@ class Squeeze2MatmulFusePass : public FusePassBase { class Reshape2MatmulFusePass : public FusePassBase { public: + Reshape2MatmulFusePass(); virtual ~Reshape2MatmulFusePass() {} protected: