From b7ddd7d7a18dc270a84f7bb64f3c3e1a79b676ce Mon Sep 17 00:00:00 2001 From: cc <52520497+juncaipeng@users.noreply.github.com> Date: Thu, 29 Apr 2021 11:26:26 +0800 Subject: [PATCH] skip fuse repeated fc when the fc with weight padding (#32648) --- .../framework/ir/repeated_fc_relu_fuse_pass.cc | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc index 479df876fbe..bf59c140005 100644 --- a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc @@ -54,6 +54,17 @@ static bool IsFCWithAct(Node* n, const std::string& act_type = "relu") { return false; } +static bool IsFCWithPaddingWeights(Node* n) { + bool res = false; + if (n && n->IsOp() && n->Op() && n->Op()->Type() == "fc" && + n->inputs.size() == 3U && n->outputs.size() == 1U) { + if (n->Op()->HasAttr("padding_weights")) { + res = BOOST_GET_CONST(bool, n->Op()->GetAttr("padding_weights")); + } + } + return res; +} + static bool IsParamOfFC(Node* n, const std::string& param_name) { if (IsInputOfFC(n) && n->inputs.empty() && (n->Name() == n->outputs[0]->Op()->Input(param_name)[0])) { @@ -255,7 +266,7 @@ void BuildRepeatedFCReluPattern(PDPattern* pattern, fc_ops[i] = pattern->NewNode( [=](Node* x) { - if (!IsFCWithAct(x, "relu")) { + if (!IsFCWithAct(x, "relu") || IsFCWithPaddingWeights(x)) { return false; } auto* fc_out_var = x->outputs[0]; -- GitLab