未验证 提交 2c4cc68f 编写于 作者: 王明冬 提交者: GitHub

add compat precondition for repeated_fc_relu_fuse_pass,test=develop. (#33742)

上级 98d25314
......@@ -23,6 +23,13 @@ namespace paddle {
namespace framework {
namespace ir {
AttrCompat& AttrCompat::IsStringEQ(const std::string& value) {
conditions_.emplace_back([value](const Attribute& attr) -> bool {
return value == BOOST_GET_CONST(std::string, attr);
});
return *this;
}
AttrCompat& AttrCompat::IsStringIn(const std::set<std::string>& candidates) {
conditions_.emplace_back([candidates](const Attribute& attr) -> bool {
std::string value = BOOST_GET_CONST(std::string, attr);
......
......@@ -37,6 +37,8 @@ class AttrCompat {
// @{ String-related methods
//! Assert the attribute is an string in the `candidates` domain.
AttrCompat& IsStringEQ(const std::string& value);
//! Assert the attribute is an string in the `candidates` domain.
AttrCompat& IsStringIn(const std::set<std::string>& candidates);
//! Assert the attribute is a string and match a custom judging function.
AttrCompat& IsStringMatch(
......
......@@ -31,6 +31,27 @@ namespace paddle {
namespace framework {
namespace ir {
RepeatedFCReluFusePass::RepeatedFCReluFusePass() {
AddOpCompat(OpCompat("fc"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("W")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("in_num_col_dims")
.IsNumEQ(1)
.End()
.AddAttr("activation_type")
.IsStringEQ("relu")
.End();
}
static bool IsInputOfFC(Node* n) {
if (n && n->IsVar() && VarLinksToOp(n, "fc")) {
return true;
......@@ -295,8 +316,9 @@ void BuildRepeatedFCReluPattern(PDPattern* pattern,
}
}
static int BuildFusion(Graph* graph, const std::string& name_scope,
int num_fc) {
int RepeatedFCReluFusePass::BuildFusion(Graph* graph,
const std::string& name_scope,
int num_fc) const {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
BuildRepeatedFCReluPattern(pattern, name_scope, num_fc);
......@@ -316,6 +338,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "repeated_fc_relu_fuse_pass failed in op compat.";
return;
}
LOG(INFO) << "handle Repeated FC Act fuse";
std::vector<Node*> weights_vars(num_fc);
std::vector<Node*> bias_vars(num_fc);
......
......@@ -31,12 +31,16 @@ class Graph;
class RepeatedFCReluFusePass : public FusePassBase {
public:
virtual ~RepeatedFCReluFusePass() {}
RepeatedFCReluFusePass();
protected:
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"repeated_fc_relu_fuse"};
private:
int BuildFusion(Graph* graph, const std::string& name_scope,
int num_fc) const;
};
} // namespace ir
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册