From 79ce2a6c31d74500a1b5fcadd9dc6c0d3debf4b6 Mon Sep 17 00:00:00 2001 From: cc <52520497+juncaipeng@users.noreply.github.com> Date: Fri, 30 Apr 2021 09:23:00 +0800 Subject: [PATCH] skip fuse repeated fc when the fc with weight padding (#32648) (#32680) --- .../framework/ir/repeated_fc_relu_fuse_pass.cc | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc index 479df876fb..bf59c14000 100644 --- a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc @@ -54,6 +54,17 @@ static bool IsFCWithAct(Node* n, const std::string& act_type = "relu") { return false; } +static bool IsFCWithPaddingWeights(Node* n) { + bool res = false; + if (n && n->IsOp() && n->Op() && n->Op()->Type() == "fc" && + n->inputs.size() == 3U && n->outputs.size() == 1U) { + if (n->Op()->HasAttr("padding_weights")) { + res = BOOST_GET_CONST(bool, n->Op()->GetAttr("padding_weights")); + } + } + return res; +} + static bool IsParamOfFC(Node* n, const std::string& param_name) { if (IsInputOfFC(n) && n->inputs.empty() && (n->Name() == n->outputs[0]->Op()->Input(param_name)[0])) { @@ -255,7 +266,7 @@ void BuildRepeatedFCReluPattern(PDPattern* pattern, fc_ops[i] = pattern->NewNode( [=](Node* x) { - if (!IsFCWithAct(x, "relu")) { + if (!IsFCWithAct(x, "relu") || IsFCWithPaddingWeights(x)) { return false; } auto* fc_out_var = x->outputs[0]; -- GitLab