From 747ba3f8273a3aa635ae341a89b61a7382f37854 Mon Sep 17 00:00:00 2001 From: JingZhuangzhuang <75348594+JZZ-NOTE@users.noreply.github.com> Date: Wed, 20 Apr 2022 20:17:21 +0800 Subject: [PATCH] fix adaptive pool pass (#42019) --- .../framework/ir/adaptive_pool2d_convert_global_pass.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.cc b/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.cc index 7846016d7e7..2625cb48174 100644 --- a/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.cc +++ b/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.cc @@ -67,6 +67,7 @@ AdaptivePool2dConvertGlobalPass::AdaptivePool2dConvertGlobalPass() { void AdaptivePool2dConvertGlobalPass::ApplyImpl(ir::Graph* graph) const { std::string name_scope = "adaptive_pool2d_convert_global_pass"; + FusePassBase::Init(name_scope, graph); int num = 0; for (const Node* n : graph->Nodes()) { @@ -77,13 +78,13 @@ void AdaptivePool2dConvertGlobalPass::ApplyImpl(ir::Graph* graph) const { if (op->HasAttr("global_pooling")) { bool global_pooling = BOOST_GET_CONST(bool, op->GetAttr("global_pooling")); - if (global_pooling) return; + if (global_pooling) continue; } - if (!op->HasAttr("pooling_type")) return; + if (!op->HasAttr("pooling_type")) continue; std::string type = BOOST_GET_CONST(std::string, op->GetAttr("pooling_type")); // adaptive has no effect on max pooling - if (type == "max") return; + if (type == "max") continue; bool adaptive = BOOST_GET_CONST(bool, op->GetAttr("adaptive")); std::vector ksize = BOOST_GET_CONST(std::vector, op->GetAttr("ksize")); -- GitLab