diff --git a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc index 479df876fbe007119c55261dd149bd515b0cd117..bf59c140005167e3be342b4039d2b13e5bddf1c6 100644 --- a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc @@ -54,6 +54,17 @@ static bool IsFCWithAct(Node* n, const std::string& act_type = "relu") { return false; } +static bool IsFCWithPaddingWeights(Node* n) { + bool res = false; + if (n && n->IsOp() && n->Op() && n->Op()->Type() == "fc" && + n->inputs.size() == 3U && n->outputs.size() == 1U) { + if (n->Op()->HasAttr("padding_weights")) { + res = BOOST_GET_CONST(bool, n->Op()->GetAttr("padding_weights")); + } + } + return res; +} + static bool IsParamOfFC(Node* n, const std::string& param_name) { if (IsInputOfFC(n) && n->inputs.empty() && (n->Name() == n->outputs[0]->Op()->Input(param_name)[0])) { @@ -255,7 +266,7 @@ void BuildRepeatedFCReluPattern(PDPattern* pattern, fc_ops[i] = pattern->NewNode( [=](Node* x) { - if (!IsFCWithAct(x, "relu")) { + if (!IsFCWithAct(x, "relu") || IsFCWithPaddingWeights(x)) { return false; } auto* fc_out_var = x->outputs[0];