diff --git a/paddle/fluid/framework/ir/op_compat_sensible_pass.cc b/paddle/fluid/framework/ir/op_compat_sensible_pass.cc index e29525cb8cd04b245a516d7f3e6970ba570c9c37..c0f17af3160ccd24ce83c97448edb7c7bc0958e2 100644 --- a/paddle/fluid/framework/ir/op_compat_sensible_pass.cc +++ b/paddle/fluid/framework/ir/op_compat_sensible_pass.cc @@ -23,6 +23,13 @@ namespace paddle { namespace framework { namespace ir { +AttrCompat& AttrCompat::IsStringEQ(const std::string& value) { + conditions_.emplace_back([value](const Attribute& attr) -> bool { + return value == BOOST_GET_CONST(std::string, attr); + }); + return *this; +} + AttrCompat& AttrCompat::IsStringIn(const std::set& candidates) { conditions_.emplace_back([candidates](const Attribute& attr) -> bool { std::string value = BOOST_GET_CONST(std::string, attr); diff --git a/paddle/fluid/framework/ir/op_compat_sensible_pass.h b/paddle/fluid/framework/ir/op_compat_sensible_pass.h index 1fb7339a24b6bc2abf64de0e3ce7adced10e374d..cfec1f123e238e249f7b76004b916491b347f3bd 100644 --- a/paddle/fluid/framework/ir/op_compat_sensible_pass.h +++ b/paddle/fluid/framework/ir/op_compat_sensible_pass.h @@ -37,6 +37,8 @@ class AttrCompat { // @{ String-related methods //! Assert the attribute is an string in the `candidates` domain. + AttrCompat& IsStringEQ(const std::string& value); + //! Assert the attribute is an string in the `candidates` domain. AttrCompat& IsStringIn(const std::set& candidates); //! Assert the attribute is a string and match a custom judging function. AttrCompat& IsStringMatch( diff --git a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc index 4c87b63625c1f69c09588c5bb8483ab03616f153..a03a6f5b2c72c6e7d33c92e11915c15578f54b07 100644 --- a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc @@ -31,6 +31,27 @@ namespace paddle { namespace framework { namespace ir { +RepeatedFCReluFusePass::RepeatedFCReluFusePass() { + AddOpCompat(OpCompat("fc")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("W") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("in_num_col_dims") + .IsNumEQ(1) + .End() + .AddAttr("activation_type") + .IsStringEQ("relu") + .End(); +} static bool IsInputOfFC(Node* n) { if (n && n->IsVar() && VarLinksToOp(n, "fc")) { return true; @@ -295,8 +316,9 @@ void BuildRepeatedFCReluPattern(PDPattern* pattern, } } -static int BuildFusion(Graph* graph, const std::string& name_scope, - int num_fc) { +int RepeatedFCReluFusePass::BuildFusion(Graph* graph, + const std::string& name_scope, + int num_fc) const { GraphPatternDetector gpd; auto* pattern = gpd.mutable_pattern(); BuildRepeatedFCReluPattern(pattern, name_scope, num_fc); @@ -316,6 +338,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, int fusion_count{0}; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "repeated_fc_relu_fuse_pass failed in op compat."; + return; + } LOG(INFO) << "handle Repeated FC Act fuse"; std::vector weights_vars(num_fc); std::vector bias_vars(num_fc); diff --git a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h index 0be217cc748a248f4e5bf8d98922cb8ebdbd3e3c..b2933d26e07ab7a981649fd84c275ce6ddecfce8 100644 --- a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h +++ b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h @@ -31,12 +31,16 @@ class Graph; class RepeatedFCReluFusePass : public FusePassBase { public: - virtual ~RepeatedFCReluFusePass() {} + RepeatedFCReluFusePass(); protected: void ApplyImpl(ir::Graph* graph) const override; const std::string name_scope_{"repeated_fc_relu_fuse"}; + + private: + int BuildFusion(Graph* graph, const std::string& name_scope, + int num_fc) const; }; } // namespace ir