From 2c4cc68f822525f3a733dd7ec3f2198b953a09b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E5=86=AC?= <78149749+winter-wang@users.noreply.github.com> Date: Fri, 25 Jun 2021 10:24:04 +0800 Subject: [PATCH] add compat precondition for repeated_fc_relu_fuse_pass,test=develop. (#33742) --- .../framework/ir/op_compat_sensible_pass.cc | 7 +++++ .../framework/ir/op_compat_sensible_pass.h | 2 ++ .../ir/repeated_fc_relu_fuse_pass.cc | 30 +++++++++++++++++-- .../framework/ir/repeated_fc_relu_fuse_pass.h | 6 +++- 4 files changed, 42 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/ir/op_compat_sensible_pass.cc b/paddle/fluid/framework/ir/op_compat_sensible_pass.cc index e29525cb8cd..c0f17af3160 100644 --- a/paddle/fluid/framework/ir/op_compat_sensible_pass.cc +++ b/paddle/fluid/framework/ir/op_compat_sensible_pass.cc @@ -23,6 +23,13 @@ namespace paddle { namespace framework { namespace ir { +AttrCompat& AttrCompat::IsStringEQ(const std::string& value) { + conditions_.emplace_back([value](const Attribute& attr) -> bool { + return value == BOOST_GET_CONST(std::string, attr); + }); + return *this; +} + AttrCompat& AttrCompat::IsStringIn(const std::set& candidates) { conditions_.emplace_back([candidates](const Attribute& attr) -> bool { std::string value = BOOST_GET_CONST(std::string, attr); diff --git a/paddle/fluid/framework/ir/op_compat_sensible_pass.h b/paddle/fluid/framework/ir/op_compat_sensible_pass.h index 1fb7339a24b..cfec1f123e2 100644 --- a/paddle/fluid/framework/ir/op_compat_sensible_pass.h +++ b/paddle/fluid/framework/ir/op_compat_sensible_pass.h @@ -37,6 +37,8 @@ class AttrCompat { // @{ String-related methods //! Assert the attribute is an string in the `candidates` domain. + AttrCompat& IsStringEQ(const std::string& value); + //! Assert the attribute is an string in the `candidates` domain. AttrCompat& IsStringIn(const std::set& candidates); //! Assert the attribute is a string and match a custom judging function. AttrCompat& IsStringMatch( diff --git a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc index 4c87b63625c..a03a6f5b2c7 100644 --- a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc @@ -31,6 +31,27 @@ namespace paddle { namespace framework { namespace ir { +RepeatedFCReluFusePass::RepeatedFCReluFusePass() { + AddOpCompat(OpCompat("fc")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("W") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("in_num_col_dims") + .IsNumEQ(1) + .End() + .AddAttr("activation_type") + .IsStringEQ("relu") + .End(); +} static bool IsInputOfFC(Node* n) { if (n && n->IsVar() && VarLinksToOp(n, "fc")) { return true; @@ -295,8 +316,9 @@ void BuildRepeatedFCReluPattern(PDPattern* pattern, } } -static int BuildFusion(Graph* graph, const std::string& name_scope, - int num_fc) { +int RepeatedFCReluFusePass::BuildFusion(Graph* graph, + const std::string& name_scope, + int num_fc) const { GraphPatternDetector gpd; auto* pattern = gpd.mutable_pattern(); BuildRepeatedFCReluPattern(pattern, name_scope, num_fc); @@ -316,6 +338,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, int fusion_count{0}; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "repeated_fc_relu_fuse_pass failed in op compat."; + return; + } LOG(INFO) << "handle Repeated FC Act fuse"; std::vector weights_vars(num_fc); std::vector bias_vars(num_fc); diff --git a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h index 0be217cc748..b2933d26e07 100644 --- a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h +++ b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h @@ -31,12 +31,16 @@ class Graph; class RepeatedFCReluFusePass : public FusePassBase { public: - virtual ~RepeatedFCReluFusePass() {} + RepeatedFCReluFusePass(); protected: void ApplyImpl(ir::Graph* graph) const override; const std::string name_scope_{"repeated_fc_relu_fuse"}; + + private: + int BuildFusion(Graph* graph, const std::string& name_scope, + int num_fc) const; }; } // namespace ir -- GitLab