未验证 提交 eebb82a2 编写于 作者: K Kaipeng Deng 提交者: GitHub

Merge pull request #15097 from heavengate/adaptive_pool_ksize

adaptive_pool support pool_size as int. test=develop
...@@ -2585,12 +2585,7 @@ def adaptive_pool2d(input, ...@@ -2585,12 +2585,7 @@ def adaptive_pool2d(input,
raise ValueError( raise ValueError(
"invalid setting 'require_index' true when 'pool_type' is 'avg'.") "invalid setting 'require_index' true when 'pool_type' is 'avg'.")
def _is_list_or_tuple_(data): pool_size = utils.convert_to_list(pool_size, 2, 'pool_size')
return (isinstance(data, list) or isinstance(data, tuple))
if not _is_list_or_tuple_(pool_size) or len(pool_size) != 2:
raise ValueError(
"'pool_size' should be a list or tuple with length as 2.")
if pool_type == "max": if pool_type == "max":
l_type = 'max_pool2d_with_index' l_type = 'max_pool2d_with_index'
...@@ -2686,12 +2681,7 @@ def adaptive_pool3d(input, ...@@ -2686,12 +2681,7 @@ def adaptive_pool3d(input,
raise ValueError( raise ValueError(
"invalid setting 'require_index' true when 'pool_type' is 'avg'.") "invalid setting 'require_index' true when 'pool_type' is 'avg'.")
def _is_list_or_tuple_(data): pool_size = utils.convert_to_list(pool_size, 3, 'pool_size')
return (isinstance(data, list) or isinstance(data, tuple))
if not _is_list_or_tuple_(pool_size) or len(pool_size) != 3:
raise ValueError(
"'pool_size' should be a list or tuple with length as 3.")
if pool_type == "max": if pool_type == "max":
l_type = 'max_pool3d_with_index' l_type = 'max_pool3d_with_index'
......
...@@ -243,6 +243,10 @@ class TestBook(unittest.TestCase): ...@@ -243,6 +243,10 @@ class TestBook(unittest.TestCase):
pool, mask = layers.adaptive_pool2d(x, [3, 3], require_index=True) pool, mask = layers.adaptive_pool2d(x, [3, 3], require_index=True)
self.assertIsNotNone(pool) self.assertIsNotNone(pool)
self.assertIsNotNone(mask) self.assertIsNotNone(mask)
self.assertIsNotNone(layers.adaptive_pool2d(x, 3, pool_type='avg'))
pool, mask = layers.adaptive_pool2d(x, 3, require_index=True)
self.assertIsNotNone(pool)
self.assertIsNotNone(mask)
def test_adaptive_pool3d(self): def test_adaptive_pool3d(self):
program = Program() program = Program()
...@@ -255,6 +259,10 @@ class TestBook(unittest.TestCase): ...@@ -255,6 +259,10 @@ class TestBook(unittest.TestCase):
x, [3, 3, 3], require_index=True) x, [3, 3, 3], require_index=True)
self.assertIsNotNone(pool) self.assertIsNotNone(pool)
self.assertIsNotNone(mask) self.assertIsNotNone(mask)
self.assertIsNotNone(layers.adaptive_pool3d(x, 3, pool_type='avg'))
pool, mask = layers.adaptive_pool3d(x, 3, require_index=True)
self.assertIsNotNone(pool)
self.assertIsNotNone(mask)
def test_lstm_unit(self): def test_lstm_unit(self):
program = Program() program = Program()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册