未验证 提交 747ba3f8 编写于 作者: J JingZhuangzhuang 提交者: GitHub

fix adaptive pool pass (#42019)

上级 d67abac6
......@@ -67,6 +67,7 @@ AdaptivePool2dConvertGlobalPass::AdaptivePool2dConvertGlobalPass() {
void AdaptivePool2dConvertGlobalPass::ApplyImpl(ir::Graph* graph) const {
std::string name_scope = "adaptive_pool2d_convert_global_pass";
FusePassBase::Init(name_scope, graph);
int num = 0;
for (const Node* n : graph->Nodes()) {
......@@ -77,13 +78,13 @@ void AdaptivePool2dConvertGlobalPass::ApplyImpl(ir::Graph* graph) const {
if (op->HasAttr("global_pooling")) {
bool global_pooling =
BOOST_GET_CONST(bool, op->GetAttr("global_pooling"));
if (global_pooling) return;
if (global_pooling) continue;
}
if (!op->HasAttr("pooling_type")) return;
if (!op->HasAttr("pooling_type")) continue;
std::string type =
BOOST_GET_CONST(std::string, op->GetAttr("pooling_type"));
// adaptive has no effect on max pooling
if (type == "max") return;
if (type == "max") continue;
bool adaptive = BOOST_GET_CONST(bool, op->GetAttr("adaptive"));
std::vector<int> ksize =
BOOST_GET_CONST(std::vector<int>, op->GetAttr("ksize"));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册