From c4a417f5a74cf602f2af75d4a5c7a96a60e655c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E5=86=AC?= <78149749+winter-wang@users.noreply.github.com> Date: Thu, 10 Jun 2021 10:24:17 +0800 Subject: [PATCH] fix the bug in repeated_fc_relu_fuse_pass.test=develop (#33386) (#33431) --- .../fluid/framework/ir/repeated_fc_relu_fuse_pass.cc | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc index bf59c140005..4c87b63625c 100644 --- a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc @@ -66,9 +66,13 @@ static bool IsFCWithPaddingWeights(Node* n) { } static bool IsParamOfFC(Node* n, const std::string& param_name) { - if (IsInputOfFC(n) && n->inputs.empty() && - (n->Name() == n->outputs[0]->Op()->Input(param_name)[0])) { - return true; + if (IsInputOfFC(n) && n->inputs.empty()) { + for (auto* out : n->outputs) { + if (out->Op()->Type() == "fc" && + n->Name() == out->Op()->Input(param_name)[0]) { + return true; + } + } } return false; } -- GitLab