diff --git a/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc index d944da5bc4863048ca2bcbec11f3888191056e78..95fe979a3350f9ed1d9def94bcf5bd7775a7766b 100644 --- a/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc +++ b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc @@ -298,7 +298,8 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern, return last_out_var; } -static int BuildFusion(Graph* graph, const std::string& name_scope) { +static int BuildFusion(Graph* graph, const std::string& name_scope, + const SquaredMatSubFusePass* pass) { GraphPatternDetector gpd; auto* pattern = gpd.mutable_pattern(); @@ -320,6 +321,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) { auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { LOG(INFO) << "handle sqaure mat sub fuse"; + if (!pass->IsAcceptable(subgraph, g)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + } + auto& fused_pattern = gpd.pattern(); auto* matx = retrieve_node(name_scope + "/x", subgraph, fused_pattern); @@ -368,14 +374,109 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) { GraphSafeRemoveNodes(graph, marked_nodes); ++fusion_count; }; - gpd(graph, handler); return fusion_count; } +SquaredMatSubFusePass::SquaredMatSubFusePass() { + AddOpCompat(OpCompat("square")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); + + AddOpCompat(OpCompat("matmul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("alpha") + .IsNumGE(0.99f) + .IsNumLE(1.01f) + .End() + .AddAttr("transpose_X") + .IsBoolEQ(false) + .End() + .AddAttr("transpose_Y") + .IsBoolEQ(false) + .End(); + + AddOpCompat(OpCompat("matmul_v2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("trans_x") + .IsBoolEQ(false) + .End() + .AddAttr("trans_y") + .IsBoolEQ(false) + .End(); + + AddOpCompat(OpCompat("elementwise_sub")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsNumEQ(-1) + .End(); + + AddOpCompat(OpCompat("elementwise_mul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsNumEQ(-1) + .End(); + + AddOpCompat(OpCompat("fill_constant")) + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("dtype") + .IsNumGE(0) + .IsNumLE(25) + .End() + .AddAttr("shape") + .End() + // type:float,there is no restriction + .AddAttr("value") + .End(); +} + +// to use IsCompat +bool SquaredMatSubFusePass::IsAcceptable( + const GraphPatternDetector::subgraph_t& subgraph, Graph* g) const { + return IsCompat(subgraph, g); +} + void SquaredMatSubFusePass::ApplyImpl(ir::Graph* graph) const { FusePassBase::Init(name_scope_, graph); - int fusion_count = BuildFusion(graph, name_scope_); + int fusion_count = BuildFusion(graph, name_scope_, this); AddStatis(fusion_count); } diff --git a/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h index 90def957df4bf0907a306798fbb1e9ba53c37919..fcc5b309157f082b1ccfaa4011f1ee78bd22f7ef 100644 --- a/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h +++ b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h @@ -31,11 +31,13 @@ class Graph; class SquaredMatSubFusePass : public FusePassBase { public: + SquaredMatSubFusePass(); + bool IsAcceptable(const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) const; virtual ~SquaredMatSubFusePass() {} protected: void ApplyImpl(ir::Graph* graph) const override; - const std::string name_scope_{"squared_mat_sub_fuse"}; }; diff --git a/paddle/fluid/operators/compat/elementwise_mul.pbtxt b/paddle/fluid/operators/compat/elementwise_mul.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..3bc2186ba30e9ea3399b40ced85d3873f2361440 --- /dev/null +++ b/paddle/fluid/operators/compat/elementwise_mul.pbtxt @@ -0,0 +1,70 @@ +type: "elementwise_mul" +def { + inputs { + name: "X" + } + inputs { + name: "Y" + } + outputs { + name: "Out" + } + attrs { + name: "axis" + type: INT + } +} +extra { + attrs { + name: "use_mkldnn" + type: BOOLEAN + } + attrs { + name: "x_data_format" + type: STRING + } + attrs { + name: "y_data_format" + type: STRING + } + attrs { + name: "use_quantizer" + type: BOOLEAN + } + attrs { + name: "mkldnn_data_type" + type: STRING + } + attrs { + name: "Scale_x" + type: FLOAT + } + attrs { + name: "Scale_y" + type: FLOAT + } + attrs { + name: "Scale_out" + type: FLOAT + } + attrs { + name: "op_role" + type: INT + } + attrs { + name: "op_role_var" + type: STRINGS + } + attrs { + name: "op_namescope" + type: STRING + } + attrs { + name: "op_callstack" + type: STRINGS + } + attrs { + name: "op_device" + type: STRING + } +} diff --git a/paddle/fluid/operators/compat/fill_constant.pbtxt b/paddle/fluid/operators/compat/fill_constant.pbtxt index b525da04a0d88b621bb8fe11ea4ecf5929921822..308348fd7e30deff244264fb7980c390b621e1ea 100644 --- a/paddle/fluid/operators/compat/fill_constant.pbtxt +++ b/paddle/fluid/operators/compat/fill_constant.pbtxt @@ -24,12 +24,13 @@ def { name: "value" type: FLOAT } - attrs { + +} +extra { + attrs { name: "str_value" type: STRING } -} -extra { attrs { name: "force_cpu" type: BOOLEAN diff --git a/paddle/fluid/operators/compat/square.pbtxt b/paddle/fluid/operators/compat/square.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..1a4f0640bec79a1e1a75026b90113cdef7650b5f --- /dev/null +++ b/paddle/fluid/operators/compat/square.pbtxt @@ -0,0 +1,44 @@ +type: "square" +def { + inputs { + name: "X" + } + outputs { + name: "Out" + } +} + +extra { + attrs { + name: "is_test" + type: BOOLEAN + } + attrs { + name: "use_mkldnn" + type: BOOLEAN + } + attrs { + name: "use_cudnn" + type: BOOLEAN + } + attrs { + name: "op_role" + type: INT + } + attrs { + name: "op_role_var" + type: STRINGS + } + attrs { + name: "op_namescope" + type: STRING + } + attrs { + name: "op_callstack" + type: STRINGS + } + attrs { + name: "op_device" + type: STRING + } +}