diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index e798ec5fc1f60e617f5c5c424c5552b994046e46..7cb4702fedcbac9a3dd8fc7bb941735a1cfbe435 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -771,14 +771,19 @@ class Pool2D(layers.Layer): ceil_mode (bool, optional): Whether to use the ceil function to calculate output height and width. False is the default. If it is set to False, the floor function will be used. Default: False. exclusive (bool, optional): Whether to exclude padding points in average pooling mode. Default: True. + data_format (string): The data format of the input and output data. An optional string from: `"NCHW"`, `"NHWC"`. + The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of: + ``[batch_size, input_channels, input_height, input_width]``. When it is `"NHWC"`, the data is + stored in the order of: ``[batch_size, input_height, input_width, input_channels]`` Returns: None Raises: - ValueError: If 'pool_type' is not "max" nor "avg" - ValueError: If 'global_pooling' is False and 'pool_size' is -1 - ValueError: If 'use_cudnn' is not a bool value. + ValueError: If ``pool_type`` is not "max" nor "avg". + ValueError: If ``global_pooling`` is False and ``pool_size`` is -1. + ValueError: If ``use_cudnn`` is not a bool value. + ValueError: If ``data_format`` is not "NCHW" nor "NHWC". Examples: @@ -806,7 +811,10 @@ class Pool2D(layers.Layer): global_pooling=False, use_cudnn=True, ceil_mode=False, - exclusive=True): + exclusive=True, + data_format="NCHW"): + data_format = data_format.upper() # supprt NHWC, nhwc, etc. + pool_type = pool_type.lower() # supprt max, Max, etc. if pool_type not in ["max", "avg"]: raise ValueError( "Unknown pool_type: '%s'. It can only be 'max' or 'avg'.", @@ -820,6 +828,11 @@ class Pool2D(layers.Layer): if not isinstance(use_cudnn, bool): raise ValueError("use_cudnn should be True or False") + if data_format not in ["NCHW", "NHWC"]: + raise ValueError( + "Attr(data_format) should be 'NCHW' or 'NHWC'. Received " + "Attr(data_format): %s." % str(data_format)) + super(Pool2D, self).__init__() self._pool_type = pool_type @@ -831,6 +844,7 @@ class Pool2D(layers.Layer): self._use_cudnn = use_cudnn self._ceil_mode = ceil_mode self._exclusive = exclusive + self._data_format = data_format self._l_type = 'pool2d' def forward(self, input): @@ -839,7 +853,8 @@ class Pool2D(layers.Layer): 'global_pooling', self._global_pooling, 'strides', self._pool_stride, 'paddings', self._pool_padding, 'use_cudnn', self._use_cudnn, 'ceil_mode', self._ceil_mode, - 'use_mkldnn', False, 'exclusive', self._exclusive) + 'use_mkldnn', False, 'exclusive', self._exclusive, + 'data_format', self._data_format) return core.ops.pool2d(input, *attrs) check_variable_and_dtype( @@ -856,6 +871,7 @@ class Pool2D(layers.Layer): "ceil_mode": self._ceil_mode, "use_mkldnn": False, "exclusive": self._exclusive, + "data_format": self._data_format, } inputs = {"X": [input]} diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index b333183257e32153722d6750d86e37f85c6916c2..977f9e721fd8285149a62403b060c47d4c320124 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -1902,7 +1902,7 @@ def pool2d(input, None by default. exclusive (bool): Whether to exclude padding points in average pooling mode, default is `true`. - data_format (string): The data format of the input and output data. An optional string from: `"NCHW"`, `"NDHW"`. + data_format (string): The data format of the input and output data. An optional string from: `"NCHW"`, `"NHWC"`. The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of: `[batch_size, input_channels, input_height, input_width]`. diff --git a/python/paddle/fluid/tests/unittests/test_pool2d_op.py b/python/paddle/fluid/tests/unittests/test_pool2d_op.py index e3b79fe9651aa20d5796085f0c0bfbba2ed978fd..a9fdcd55f74cd53824016765fe82a03190f23f89 100644 --- a/python/paddle/fluid/tests/unittests/test_pool2d_op.py +++ b/python/paddle/fluid/tests/unittests/test_pool2d_op.py @@ -1295,6 +1295,78 @@ class TestDygraphPool2DAPIError(unittest.TestCase): name='x1', shape=[3, 32, 32, 5], dtype="int32") self.assertRaises(TypeError, pool2d, data2) + def test_data_format_error(self): + with program_guard(Program(), Program()): + # the data_format must be 'NCHW' or 'NHWC' + data1 = np.random.random((3, 32, 32, 5)).astype('float32') + self.assertRaises( + ValueError, + fluid.dygraph.Pool2D, + pool_size=2, + pool_type='max', + pool_stride=1, + global_pooling=False, + data_format='NWHC') + + +class TestDygraphPool2DAPI(unittest.TestCase): + def test_nhwc(self): + with fluid.dygraph.guard(): + data = np.random.random((3, 32, 32, 5)).astype('float32') + x = fluid.dygraph.to_variable(data) + pool2d = fluid.dygraph.Pool2D( + pool_size=2, + pool_type='max', + pool_stride=1, + pool_padding=[0, 0], + global_pooling=False, + data_format='NHWC') + out1 = pool2d(x) + out2 = pool2D_forward_naive( + data, [2, 2], [1, 1], + paddings=[0, 0], + pool_type='max', + data_format='NHWC') + self.assertTrue(np.allclose(out1.numpy(), out2)) + + def test_lower_case(self): + with fluid.dygraph.guard(): + data = np.random.random((3, 32, 32, 5)).astype('float32') + x = fluid.dygraph.to_variable(data) + pool2d = fluid.dygraph.Pool2D( + pool_size=2, + pool_type='max', + pool_stride=1, + pool_padding=[0, 0], + global_pooling=False, + data_format='nhwc') + out1 = pool2d(x) + out2 = pool2D_forward_naive( + data, [2, 2], [1, 1], + paddings=[0, 0], + pool_type='max', + data_format='NHWC') + self.assertTrue(np.allclose(out1.numpy(), out2)) + + def test_upper_case(self): + with fluid.dygraph.guard(): + data = np.random.random((3, 32, 32, 5)).astype('float32') + x = fluid.dygraph.to_variable(data) + pool2d = fluid.dygraph.Pool2D( + pool_size=2, + pool_type='MAX', + pool_stride=1, + pool_padding=[0, 0], + global_pooling=False, + data_format='nhwc') + out1 = pool2d(x) + out2 = pool2D_forward_naive( + data, [2, 2], [1, 1], + paddings=[0, 0], + pool_type='max', + data_format='NHWC') + self.assertTrue(np.allclose(out1.numpy(), out2)) + if __name__ == '__main__': unittest.main()