未验证 提交 2c4cc68f 编写于 作者: 王明冬 提交者: GitHub

add compat precondition for repeated_fc_relu_fuse_pass,test=develop. (#33742)

上级 98d25314
...@@ -23,6 +23,13 @@ namespace paddle { ...@@ -23,6 +23,13 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
AttrCompat& AttrCompat::IsStringEQ(const std::string& value) {
conditions_.emplace_back([value](const Attribute& attr) -> bool {
return value == BOOST_GET_CONST(std::string, attr);
});
return *this;
}
AttrCompat& AttrCompat::IsStringIn(const std::set<std::string>& candidates) { AttrCompat& AttrCompat::IsStringIn(const std::set<std::string>& candidates) {
conditions_.emplace_back([candidates](const Attribute& attr) -> bool { conditions_.emplace_back([candidates](const Attribute& attr) -> bool {
std::string value = BOOST_GET_CONST(std::string, attr); std::string value = BOOST_GET_CONST(std::string, attr);
......
...@@ -37,6 +37,8 @@ class AttrCompat { ...@@ -37,6 +37,8 @@ class AttrCompat {
// @{ String-related methods // @{ String-related methods
//! Assert the attribute is an string in the `candidates` domain. //! Assert the attribute is an string in the `candidates` domain.
AttrCompat& IsStringEQ(const std::string& value);
//! Assert the attribute is an string in the `candidates` domain.
AttrCompat& IsStringIn(const std::set<std::string>& candidates); AttrCompat& IsStringIn(const std::set<std::string>& candidates);
//! Assert the attribute is a string and match a custom judging function. //! Assert the attribute is a string and match a custom judging function.
AttrCompat& IsStringMatch( AttrCompat& IsStringMatch(
......
...@@ -31,6 +31,27 @@ namespace paddle { ...@@ -31,6 +31,27 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
RepeatedFCReluFusePass::RepeatedFCReluFusePass() {
AddOpCompat(OpCompat("fc"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("W")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("in_num_col_dims")
.IsNumEQ(1)
.End()
.AddAttr("activation_type")
.IsStringEQ("relu")
.End();
}
static bool IsInputOfFC(Node* n) { static bool IsInputOfFC(Node* n) {
if (n && n->IsVar() && VarLinksToOp(n, "fc")) { if (n && n->IsVar() && VarLinksToOp(n, "fc")) {
return true; return true;
...@@ -295,8 +316,9 @@ void BuildRepeatedFCReluPattern(PDPattern* pattern, ...@@ -295,8 +316,9 @@ void BuildRepeatedFCReluPattern(PDPattern* pattern,
} }
} }
static int BuildFusion(Graph* graph, const std::string& name_scope, int RepeatedFCReluFusePass::BuildFusion(Graph* graph,
int num_fc) { const std::string& name_scope,
int num_fc) const {
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern(); auto* pattern = gpd.mutable_pattern();
BuildRepeatedFCReluPattern(pattern, name_scope, num_fc); BuildRepeatedFCReluPattern(pattern, name_scope, num_fc);
...@@ -316,6 +338,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -316,6 +338,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
int fusion_count{0}; int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "repeated_fc_relu_fuse_pass failed in op compat.";
return;
}
LOG(INFO) << "handle Repeated FC Act fuse"; LOG(INFO) << "handle Repeated FC Act fuse";
std::vector<Node*> weights_vars(num_fc); std::vector<Node*> weights_vars(num_fc);
std::vector<Node*> bias_vars(num_fc); std::vector<Node*> bias_vars(num_fc);
......
...@@ -31,12 +31,16 @@ class Graph; ...@@ -31,12 +31,16 @@ class Graph;
class RepeatedFCReluFusePass : public FusePassBase { class RepeatedFCReluFusePass : public FusePassBase {
public: public:
virtual ~RepeatedFCReluFusePass() {} RepeatedFCReluFusePass();
protected: protected:
void ApplyImpl(ir::Graph* graph) const override; void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"repeated_fc_relu_fuse"}; const std::string name_scope_{"repeated_fc_relu_fuse"};
private:
int BuildFusion(Graph* graph, const std::string& name_scope,
int num_fc) const;
}; };
} // namespace ir } // namespace ir
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册