未验证 提交 747ba3f8 编写于 作者: J JingZhuangzhuang 提交者: GitHub

fix adaptive pool pass (#42019)

上级 d67abac6
...@@ -67,6 +67,7 @@ AdaptivePool2dConvertGlobalPass::AdaptivePool2dConvertGlobalPass() { ...@@ -67,6 +67,7 @@ AdaptivePool2dConvertGlobalPass::AdaptivePool2dConvertGlobalPass() {
void AdaptivePool2dConvertGlobalPass::ApplyImpl(ir::Graph* graph) const { void AdaptivePool2dConvertGlobalPass::ApplyImpl(ir::Graph* graph) const {
std::string name_scope = "adaptive_pool2d_convert_global_pass"; std::string name_scope = "adaptive_pool2d_convert_global_pass";
FusePassBase::Init(name_scope, graph); FusePassBase::Init(name_scope, graph);
int num = 0; int num = 0;
for (const Node* n : graph->Nodes()) { for (const Node* n : graph->Nodes()) {
...@@ -77,13 +78,13 @@ void AdaptivePool2dConvertGlobalPass::ApplyImpl(ir::Graph* graph) const { ...@@ -77,13 +78,13 @@ void AdaptivePool2dConvertGlobalPass::ApplyImpl(ir::Graph* graph) const {
if (op->HasAttr("global_pooling")) { if (op->HasAttr("global_pooling")) {
bool global_pooling = bool global_pooling =
BOOST_GET_CONST(bool, op->GetAttr("global_pooling")); BOOST_GET_CONST(bool, op->GetAttr("global_pooling"));
if (global_pooling) return; if (global_pooling) continue;
} }
if (!op->HasAttr("pooling_type")) return; if (!op->HasAttr("pooling_type")) continue;
std::string type = std::string type =
BOOST_GET_CONST(std::string, op->GetAttr("pooling_type")); BOOST_GET_CONST(std::string, op->GetAttr("pooling_type"));
// adaptive has no effect on max pooling // adaptive has no effect on max pooling
if (type == "max") return; if (type == "max") continue;
bool adaptive = BOOST_GET_CONST(bool, op->GetAttr("adaptive")); bool adaptive = BOOST_GET_CONST(bool, op->GetAttr("adaptive"));
std::vector<int> ksize = std::vector<int> ksize =
BOOST_GET_CONST(std::vector<int>, op->GetAttr("ksize")); BOOST_GET_CONST(std::vector<int>, op->GetAttr("ksize"));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册