未验证 提交 b7ddd7d7 编写于 作者: C cc 提交者: GitHub

skip fuse repeated fc when the fc with weight padding (#32648)

上级 8ccf549b
...@@ -54,6 +54,17 @@ static bool IsFCWithAct(Node* n, const std::string& act_type = "relu") { ...@@ -54,6 +54,17 @@ static bool IsFCWithAct(Node* n, const std::string& act_type = "relu") {
return false; return false;
} }
static bool IsFCWithPaddingWeights(Node* n) {
bool res = false;
if (n && n->IsOp() && n->Op() && n->Op()->Type() == "fc" &&
n->inputs.size() == 3U && n->outputs.size() == 1U) {
if (n->Op()->HasAttr("padding_weights")) {
res = BOOST_GET_CONST(bool, n->Op()->GetAttr("padding_weights"));
}
}
return res;
}
static bool IsParamOfFC(Node* n, const std::string& param_name) { static bool IsParamOfFC(Node* n, const std::string& param_name) {
if (IsInputOfFC(n) && n->inputs.empty() && if (IsInputOfFC(n) && n->inputs.empty() &&
(n->Name() == n->outputs[0]->Op()->Input(param_name)[0])) { (n->Name() == n->outputs[0]->Op()->Input(param_name)[0])) {
...@@ -255,7 +266,7 @@ void BuildRepeatedFCReluPattern(PDPattern* pattern, ...@@ -255,7 +266,7 @@ void BuildRepeatedFCReluPattern(PDPattern* pattern,
fc_ops[i] = pattern->NewNode( fc_ops[i] = pattern->NewNode(
[=](Node* x) { [=](Node* x) {
if (!IsFCWithAct(x, "relu")) { if (!IsFCWithAct(x, "relu") || IsFCWithPaddingWeights(x)) {
return false; return false;
} }
auto* fc_out_var = x->outputs[0]; auto* fc_out_var = x->outputs[0];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册