diff --git a/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.cc b/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.cc index 7846016d7e7b290f5b2bc3b2f35242df230dcc83..2625cb48174b88e01e6d09efa5e869642c958ffb 100644 --- a/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.cc +++ b/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.cc @@ -67,6 +67,7 @@ AdaptivePool2dConvertGlobalPass::AdaptivePool2dConvertGlobalPass() { void AdaptivePool2dConvertGlobalPass::ApplyImpl(ir::Graph* graph) const { std::string name_scope = "adaptive_pool2d_convert_global_pass"; + FusePassBase::Init(name_scope, graph); int num = 0; for (const Node* n : graph->Nodes()) { @@ -77,13 +78,13 @@ void AdaptivePool2dConvertGlobalPass::ApplyImpl(ir::Graph* graph) const { if (op->HasAttr("global_pooling")) { bool global_pooling = BOOST_GET_CONST(bool, op->GetAttr("global_pooling")); - if (global_pooling) return; + if (global_pooling) continue; } - if (!op->HasAttr("pooling_type")) return; + if (!op->HasAttr("pooling_type")) continue; std::string type = BOOST_GET_CONST(std::string, op->GetAttr("pooling_type")); // adaptive has no effect on max pooling - if (type == "max") return; + if (type == "max") continue; bool adaptive = BOOST_GET_CONST(bool, op->GetAttr("adaptive")); std::vector ksize = BOOST_GET_CONST(std::vector, op->GetAttr("ksize"));