未验证 提交 eebb82a2 编写于 作者: K Kaipeng Deng 提交者: GitHub

Merge pull request #15097 from heavengate/adaptive_pool_ksize

adaptive_pool support pool_size as int. test=develop
......@@ -2585,12 +2585,7 @@ def adaptive_pool2d(input,
raise ValueError(
"invalid setting 'require_index' true when 'pool_type' is 'avg'.")
def _is_list_or_tuple_(data):
return (isinstance(data, list) or isinstance(data, tuple))
if not _is_list_or_tuple_(pool_size) or len(pool_size) != 2:
raise ValueError(
"'pool_size' should be a list or tuple with length as 2.")
pool_size = utils.convert_to_list(pool_size, 2, 'pool_size')
if pool_type == "max":
l_type = 'max_pool2d_with_index'
......@@ -2686,12 +2681,7 @@ def adaptive_pool3d(input,
raise ValueError(
"invalid setting 'require_index' true when 'pool_type' is 'avg'.")
def _is_list_or_tuple_(data):
return (isinstance(data, list) or isinstance(data, tuple))
if not _is_list_or_tuple_(pool_size) or len(pool_size) != 3:
raise ValueError(
"'pool_size' should be a list or tuple with length as 3.")
pool_size = utils.convert_to_list(pool_size, 3, 'pool_size')
if pool_type == "max":
l_type = 'max_pool3d_with_index'
......
......@@ -243,6 +243,10 @@ class TestBook(unittest.TestCase):
pool, mask = layers.adaptive_pool2d(x, [3, 3], require_index=True)
self.assertIsNotNone(pool)
self.assertIsNotNone(mask)
self.assertIsNotNone(layers.adaptive_pool2d(x, 3, pool_type='avg'))
pool, mask = layers.adaptive_pool2d(x, 3, require_index=True)
self.assertIsNotNone(pool)
self.assertIsNotNone(mask)
def test_adaptive_pool3d(self):
program = Program()
......@@ -255,6 +259,10 @@ class TestBook(unittest.TestCase):
x, [3, 3, 3], require_index=True)
self.assertIsNotNone(pool)
self.assertIsNotNone(mask)
self.assertIsNotNone(layers.adaptive_pool3d(x, 3, pool_type='avg'))
pool, mask = layers.adaptive_pool3d(x, 3, require_index=True)
self.assertIsNotNone(pool)
self.assertIsNotNone(mask)
def test_lstm_unit(self):
program = Program()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册