From ccc83bb4e5f2051ff03322a70590848e6a7594b2 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Fri, 28 Dec 2018 11:31:21 +0800 Subject: [PATCH] adaptive_pool support pool_size as int. test=develop --- python/paddle/fluid/layers/nn.py | 14 ++------------ python/paddle/fluid/tests/unittests/test_layers.py | 8 ++++++++ 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index cc1fdbd28..236f1643e 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -2570,12 +2570,7 @@ def adaptive_pool2d(input, raise ValueError( "invalid setting 'require_index' true when 'pool_type' is 'avg'.") - def _is_list_or_tuple_(data): - return (isinstance(data, list) or isinstance(data, tuple)) - - if not _is_list_or_tuple_(pool_size) or len(pool_size) != 2: - raise ValueError( - "'pool_size' should be a list or tuple with length as 2.") + pool_size = utils.convert_to_list(pool_size, 2, 'pool_size') if pool_type == "max": l_type = 'max_pool2d_with_index' @@ -2671,12 +2666,7 @@ def adaptive_pool3d(input, raise ValueError( "invalid setting 'require_index' true when 'pool_type' is 'avg'.") - def _is_list_or_tuple_(data): - return (isinstance(data, list) or isinstance(data, tuple)) - - if not _is_list_or_tuple_(pool_size) or len(pool_size) != 3: - raise ValueError( - "'pool_size' should be a list or tuple with length as 3.") + pool_size = utils.convert_to_list(pool_size, 3, 'pool_size') if pool_type == "max": l_type = 'max_pool3d_with_index' diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index e180822c2..90f5d797a 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -243,6 +243,10 @@ class TestBook(unittest.TestCase): pool, mask = layers.adaptive_pool2d(x, [3, 3], require_index=True) self.assertIsNotNone(pool) self.assertIsNotNone(mask) + self.assertIsNotNone(layers.adaptive_pool2d(x, 3, pool_type='avg')) + pool, mask = layers.adaptive_pool2d(x, 3, require_index=True) + self.assertIsNotNone(pool) + self.assertIsNotNone(mask) def test_adaptive_pool3d(self): program = Program() @@ -255,6 +259,10 @@ class TestBook(unittest.TestCase): x, [3, 3, 3], require_index=True) self.assertIsNotNone(pool) self.assertIsNotNone(mask) + self.assertIsNotNone(layers.adaptive_pool3d(x, 3, pool_type='avg')) + pool, mask = layers.adaptive_pool3d(x, 3, require_index=True) + self.assertIsNotNone(pool) + self.assertIsNotNone(mask) def test_lstm_unit(self): program = Program() -- GitLab