未验证 提交 9fd1aad6 编写于 作者: L Leo Chen 提交者: GitHub

Support NHWC in Pool2D, test=develop (#24240)

* support NHWC in Pool2D, test=develop

* add unittest, test=develop

* fix unittest, test=develop

* fix typo, test=develop

* follow comments, test=develop

* refine comments, test=develop
上级 37ae661c
......@@ -771,14 +771,19 @@ class Pool2D(layers.Layer):
ceil_mode (bool, optional): Whether to use the ceil function to calculate output height and width.
False is the default. If it is set to False, the floor function will be used. Default: False.
exclusive (bool, optional): Whether to exclude padding points in average pooling mode. Default: True.
data_format (string): The data format of the input and output data. An optional string from: `"NCHW"`, `"NHWC"`.
The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
``[batch_size, input_channels, input_height, input_width]``. When it is `"NHWC"`, the data is
stored in the order of: ``[batch_size, input_height, input_width, input_channels]``
Returns:
None
Raises:
ValueError: If 'pool_type' is not "max" nor "avg"
ValueError: If 'global_pooling' is False and 'pool_size' is -1
ValueError: If 'use_cudnn' is not a bool value.
ValueError: If ``pool_type`` is not "max" nor "avg".
ValueError: If ``global_pooling`` is False and ``pool_size`` is -1.
ValueError: If ``use_cudnn`` is not a bool value.
ValueError: If ``data_format`` is not "NCHW" nor "NHWC".
Examples:
......@@ -806,7 +811,10 @@ class Pool2D(layers.Layer):
global_pooling=False,
use_cudnn=True,
ceil_mode=False,
exclusive=True):
exclusive=True,
data_format="NCHW"):
data_format = data_format.upper() # supprt NHWC, nhwc, etc.
pool_type = pool_type.lower() # supprt max, Max, etc.
if pool_type not in ["max", "avg"]:
raise ValueError(
"Unknown pool_type: '%s'. It can only be 'max' or 'avg'.",
......@@ -820,6 +828,11 @@ class Pool2D(layers.Layer):
if not isinstance(use_cudnn, bool):
raise ValueError("use_cudnn should be True or False")
if data_format not in ["NCHW", "NHWC"]:
raise ValueError(
"Attr(data_format) should be 'NCHW' or 'NHWC'. Received "
"Attr(data_format): %s." % str(data_format))
super(Pool2D, self).__init__()
self._pool_type = pool_type
......@@ -831,6 +844,7 @@ class Pool2D(layers.Layer):
self._use_cudnn = use_cudnn
self._ceil_mode = ceil_mode
self._exclusive = exclusive
self._data_format = data_format
self._l_type = 'pool2d'
def forward(self, input):
......@@ -839,7 +853,8 @@ class Pool2D(layers.Layer):
'global_pooling', self._global_pooling, 'strides',
self._pool_stride, 'paddings', self._pool_padding,
'use_cudnn', self._use_cudnn, 'ceil_mode', self._ceil_mode,
'use_mkldnn', False, 'exclusive', self._exclusive)
'use_mkldnn', False, 'exclusive', self._exclusive,
'data_format', self._data_format)
return core.ops.pool2d(input, *attrs)
check_variable_and_dtype(
......@@ -856,6 +871,7 @@ class Pool2D(layers.Layer):
"ceil_mode": self._ceil_mode,
"use_mkldnn": False,
"exclusive": self._exclusive,
"data_format": self._data_format,
}
inputs = {"X": [input]}
......
......@@ -1902,7 +1902,7 @@ def pool2d(input,
None by default.
exclusive (bool): Whether to exclude padding points in average pooling
mode, default is `true`.
data_format (string): The data format of the input and output data. An optional string from: `"NCHW"`, `"NDHW"`.
data_format (string): The data format of the input and output data. An optional string from: `"NCHW"`, `"NHWC"`.
The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
`[batch_size, input_channels, input_height, input_width]`.
......
......@@ -1295,6 +1295,78 @@ class TestDygraphPool2DAPIError(unittest.TestCase):
name='x1', shape=[3, 32, 32, 5], dtype="int32")
self.assertRaises(TypeError, pool2d, data2)
def test_data_format_error(self):
with program_guard(Program(), Program()):
# the data_format must be 'NCHW' or 'NHWC'
data1 = np.random.random((3, 32, 32, 5)).astype('float32')
self.assertRaises(
ValueError,
fluid.dygraph.Pool2D,
pool_size=2,
pool_type='max',
pool_stride=1,
global_pooling=False,
data_format='NWHC')
class TestDygraphPool2DAPI(unittest.TestCase):
def test_nhwc(self):
with fluid.dygraph.guard():
data = np.random.random((3, 32, 32, 5)).astype('float32')
x = fluid.dygraph.to_variable(data)
pool2d = fluid.dygraph.Pool2D(
pool_size=2,
pool_type='max',
pool_stride=1,
pool_padding=[0, 0],
global_pooling=False,
data_format='NHWC')
out1 = pool2d(x)
out2 = pool2D_forward_naive(
data, [2, 2], [1, 1],
paddings=[0, 0],
pool_type='max',
data_format='NHWC')
self.assertTrue(np.allclose(out1.numpy(), out2))
def test_lower_case(self):
with fluid.dygraph.guard():
data = np.random.random((3, 32, 32, 5)).astype('float32')
x = fluid.dygraph.to_variable(data)
pool2d = fluid.dygraph.Pool2D(
pool_size=2,
pool_type='max',
pool_stride=1,
pool_padding=[0, 0],
global_pooling=False,
data_format='nhwc')
out1 = pool2d(x)
out2 = pool2D_forward_naive(
data, [2, 2], [1, 1],
paddings=[0, 0],
pool_type='max',
data_format='NHWC')
self.assertTrue(np.allclose(out1.numpy(), out2))
def test_upper_case(self):
with fluid.dygraph.guard():
data = np.random.random((3, 32, 32, 5)).astype('float32')
x = fluid.dygraph.to_variable(data)
pool2d = fluid.dygraph.Pool2D(
pool_size=2,
pool_type='MAX',
pool_stride=1,
pool_padding=[0, 0],
global_pooling=False,
data_format='nhwc')
out1 = pool2d(x)
out2 = pool2D_forward_naive(
data, [2, 2], [1, 1],
paddings=[0, 0],
pool_type='max',
data_format='NHWC')
self.assertTrue(np.allclose(out1.numpy(), out2))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册