未验证 提交 c1c5c1fc 编写于 作者: W wenbin 提交者: GitHub

adaptive pool2d pass fix (#39600)

* first commit

* teller fix

* bug fix

* enable for pool2d only

* fix global_pooling issue

* pooling_type

* fix test
上级 db43b541
...@@ -72,7 +72,18 @@ void AdaptivePool2dConvertGlobalPass::ApplyImpl(ir::Graph* graph) const { ...@@ -72,7 +72,18 @@ void AdaptivePool2dConvertGlobalPass::ApplyImpl(ir::Graph* graph) const {
for (const Node* n : graph->Nodes()) { for (const Node* n : graph->Nodes()) {
if (n->IsOp()) { if (n->IsOp()) {
auto* op = n->Op(); auto* op = n->Op();
if (op->HasAttr("adaptive") && op->HasAttr("ksize")) { if (op->Type() == "pool2d" && op->HasAttr("adaptive") &&
op->HasAttr("ksize")) {
if (op->HasAttr("global_pooling")) {
bool global_pooling =
BOOST_GET_CONST(bool, op->GetAttr("global_pooling"));
if (global_pooling) return;
}
if (!op->HasAttr("pooling_type")) return;
std::string type =
BOOST_GET_CONST(std::string, op->GetAttr("pooling_type"));
// adaptive has no effect on max pooling
if (type == "max") return;
bool adaptive = BOOST_GET_CONST(bool, op->GetAttr("adaptive")); bool adaptive = BOOST_GET_CONST(bool, op->GetAttr("adaptive"));
std::vector<int> ksize = std::vector<int> ksize =
BOOST_GET_CONST(std::vector<int>, op->GetAttr("ksize")); BOOST_GET_CONST(std::vector<int>, op->GetAttr("ksize"));
......
...@@ -29,6 +29,8 @@ TEST(AdaptivePool2dConvertGlobalPass, basic) { ...@@ -29,6 +29,8 @@ TEST(AdaptivePool2dConvertGlobalPass, basic) {
AttributeMap attrs; AttributeMap attrs;
attrs["adaptive"] = true; attrs["adaptive"] = true;
attrs["ksize"] = std::vector<int>{1, 1}; attrs["ksize"] = std::vector<int>{1, 1};
attrs["pooling_type"] =
std::string("avg"); // adaptive has no effect on max pooling
layers.pool2d(x, false, &attrs); layers.pool2d(x, false, &attrs);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program())); std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
......
...@@ -225,6 +225,13 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -225,6 +225,13 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
<< desc.Output("Out").size(); << desc.Output("Out").size();
return false; return false;
} }
if (desc.HasAttr("data_format")) {
std::string data_format =
BOOST_GET_CONST(std::string, desc.GetAttr("data_format"));
if (data_format == "NHWC" || data_format == "NDHWC") {
return false;
}
}
if (!desc.HasAttr("pooling_type")) { if (!desc.HasAttr("pooling_type")) {
return false; return false;
} else { } else {
......
...@@ -42,10 +42,14 @@ class TestAdaptivePool2dConvertGlobalPass(PassAutoScanTest): ...@@ -42,10 +42,14 @@ class TestAdaptivePool2dConvertGlobalPass(PassAutoScanTest):
st.integers( st.integers(
min_value=1, max_value=4), min_size=2, max_size=2)) min_value=1, max_value=4), min_size=2, max_size=2))
paddings = [0, 0] # only 0 0 is right paddings = draw(
st.lists(
st.integers(
min_value=1, max_value=4), min_size=2, max_size=2))
ceil_mode = draw(st.booleans()) ceil_mode = draw(st.booleans())
exclusive = draw(st.booleans()) exclusive = draw(st.booleans())
global_pooling = False #only false is right global_pooling = draw(st.booleans())
padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VAILD"])) padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VAILD"]))
pool_op = OpConfig( pool_op = OpConfig(
...@@ -83,29 +87,6 @@ class TestAdaptivePool2dConvertGlobalPass(PassAutoScanTest): ...@@ -83,29 +87,6 @@ class TestAdaptivePool2dConvertGlobalPass(PassAutoScanTest):
use_calib_mode=False) use_calib_mode=False)
yield config, ['pool2d'], (1e-5, 1e-5) yield config, ['pool2d'], (1e-5, 1e-5)
def add_ignore_pass_case(self):
# Here we put some skip rules to avoid known bugs
def teller1(program_config, predictor_config):
if program_config.ops[0].attrs["pooling_type"] == "max":
x_shape = list(program_config.inputs["input_data"].shape)
if x_shape[-1] != 1 or x_shape[-2] != 1:
return True
return False
def teller2(program_config, predictor_config):
if program_config.ops[0].attrs["padding_algorithm"] == "SAME":
return True
return False
self.add_ignore_check_case(
teller1,
IgnoreReasons.PASS_ACCURACY_ERROR,
"max pooling has diff if H or W is not equals to 1", )
self.add_ignore_check_case(
teller2,
IgnoreReasons.PASS_ACCURACY_ERROR,
"output has wrong result if padding_algorithm equals to SAME", )
def test(self): def test(self):
self.run_and_statis( self.run_and_statis(
quant=False, quant=False,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册