diff --git a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc index bf59c140005167e3be342b4039d2b13e5bddf1c6..4c87b63625c1f69c09588c5bb8483ab03616f153 100644 --- a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc @@ -66,9 +66,13 @@ static bool IsFCWithPaddingWeights(Node* n) { } static bool IsParamOfFC(Node* n, const std::string& param_name) { - if (IsInputOfFC(n) && n->inputs.empty() && - (n->Name() == n->outputs[0]->Op()->Input(param_name)[0])) { - return true; + if (IsInputOfFC(n) && n->inputs.empty()) { + for (auto* out : n->outputs) { + if (out->Op()->Type() == "fc" && + n->Name() == out->Op()->Input(param_name)[0]) { + return true; + } + } } return false; }