diff --git a/python/paddle/fluid/tests/unittests/test_pool1d_api.py b/python/paddle/fluid/tests/unittests/test_pool1d_api.py index 1c05b96f1fc61234028e940f6403ae08a0186027..25216175d59935535a352b02afc3c8f371cedd63 100644 --- a/python/paddle/fluid/tests/unittests/test_pool1d_api.py +++ b/python/paddle/fluid/tests/unittests/test_pool1d_api.py @@ -143,6 +143,27 @@ class TestPool1d_API(unittest.TestCase): result = avg_pool1d_dg(input) self.assertTrue(np.allclose(result.numpy(), result_np)) + def check_avg_dygraph_padding_results(self, place): + with fluid.dygraph.guard(place): + input_np = np.random.random([2, 3, 32]).astype("float32") + input = fluid.dygraph.to_variable(input_np) + result = F.avg_pool1d( + input, + kernel_size=2, + stride=2, + padding=[1], + count_include_pad=True) + + result_np = avg_pool1D_forward_naive( + input_np, ksize=[2], strides=[2], paddings=[1], exclusive=False) + + self.assertTrue(np.allclose(result.numpy(), result_np)) + + avg_pool1d_dg = paddle.nn.AvgPool1d( + kernel_size=2, stride=None, padding=1, count_include_pad=True) + result = avg_pool1d_dg(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + def check_max_static_results(self, place): with fluid.program_guard(fluid.Program(), fluid.Program()): input = fluid.data(name="input", shape=[2, 3, 32], dtype="float32") diff --git a/python/paddle/fluid/tests/unittests/test_pool2d_api.py b/python/paddle/fluid/tests/unittests/test_pool2d_api.py index 93a2be6de342efc4e8284e7c352137d0a3a1bcb9..91faf78418b0d3a92a3cb6a167b6024b1beb3898 100644 --- a/python/paddle/fluid/tests/unittests/test_pool2d_api.py +++ b/python/paddle/fluid/tests/unittests/test_pool2d_api.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from test_pool2d_op import adaptive_start_index, adaptive_end_index, pool2D_forward_naive +from test_pool2d_op import adaptive_start_index, adaptive_end_index, pool2D_forward_naive, avg_pool2D_forward_naive, max_pool2D_forward_naive import unittest from op_test import OpTest import numpy as np @@ -68,6 +68,47 @@ class TestPool2d_API(unittest.TestCase): result = avg_pool2d_dg(input) self.assertTrue(np.allclose(result.numpy(), result_np)) + def check_avg_dygraph_padding_results(self, place): + with fluid.dygraph.guard(place): + input_np = np.random.random([2, 3, 32, 32]).astype("float32") + input = fluid.dygraph.to_variable(input_np) + result = avg_pool2d( + input, kernel_size=2, stride=2, padding=1, ceil_mode=False) + + result_np = avg_pool2D_forward_naive( + input_np, + ksize=[2, 2], + strides=[2, 2], + paddings=[1, 1], + ceil_mode=False, + exclusive=False) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + avg_pool2d_dg = paddle.nn.layer.AvgPool2d( + kernel_size=2, stride=2, padding=1, ceil_mode=False) + result = avg_pool2d_dg(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + def check_avg_dygraph_ceilmode_results(self, place): + with fluid.dygraph.guard(place): + input_np = np.random.random([2, 3, 32, 32]).astype("float32") + input = fluid.dygraph.to_variable(input_np) + result = avg_pool2d( + input, kernel_size=2, stride=2, padding=0, ceil_mode=True) + + result_np = avg_pool2D_forward_naive( + input_np, + ksize=[2, 2], + strides=[2, 2], + paddings=[0, 0], + ceil_mode=True) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + avg_pool2d_dg = paddle.nn.layer.AvgPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True) + result = avg_pool2d_dg(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + def check_max_static_results(self, place): with fluid.program_guard(fluid.Program(), fluid.Program()): input = fluid.data( @@ -108,6 +149,70 @@ class TestPool2d_API(unittest.TestCase): result = max_pool2d_dg(input) self.assertTrue(np.allclose(result.numpy(), result_np)) + def check_max_dygraph_nhwc_results(self, place): + with fluid.dygraph.guard(place): + input_np = np.random.random([2, 3, 32, 32]).astype("float32") + input = fluid.dygraph.to_variable( + np.transpose(input_np, [0, 2, 3, 1])) + result = max_pool2d( + input, + kernel_size=2, + stride=2, + padding=0, + return_indices=False, + data_format="NHWC") + + result_np = pool2D_forward_naive( + input_np, + ksize=[2, 2], + strides=[2, 2], + paddings=[0, 0], + pool_type='max') + self.assertTrue( + np.allclose( + np.transpose(result.numpy(), [0, 3, 1, 2]), result_np)) + + def check_max_dygraph_padding_results(self, place): + with fluid.dygraph.guard(place): + input_np = np.random.random([2, 3, 32, 32]).astype("float32") + input = fluid.dygraph.to_variable(input_np) + result = max_pool2d( + input, kernel_size=2, stride=2, padding=1, ceil_mode=False) + + result_np = max_pool2D_forward_naive( + input_np, + ksize=[2, 2], + strides=[2, 2], + paddings=[1, 1], + ceil_mode=False, + exclusive=False) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + max_pool2d_dg = paddle.nn.layer.MaxPool2d( + kernel_size=2, stride=2, padding=1, ceil_mode=False) + result = max_pool2d_dg(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + def check_max_dygraph_ceilmode_results(self, place): + with fluid.dygraph.guard(place): + input_np = np.random.random([2, 3, 32, 32]).astype("float32") + input = fluid.dygraph.to_variable(input_np) + result = max_pool2d( + input, kernel_size=2, stride=2, padding=0, ceil_mode=True) + + result_np = max_pool2D_forward_naive( + input_np, + ksize=[2, 2], + strides=[2, 2], + paddings=[0, 0], + ceil_mode=True) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + max_pool2d_dg = paddle.nn.layer.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True) + result = max_pool2d_dg(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + def check_max_dygraph_stride_is_none(self, place): with fluid.dygraph.guard(place): input_np = np.random.random([2, 3, 32, 32]).astype("float32") @@ -215,6 +320,9 @@ class TestPool2d_API(unittest.TestCase): self.check_avg_dygraph_stride_is_none(place) self.check_max_dygraph_padding(place) self.check_avg_divisor(place) + self.check_max_dygraph_padding_results(place) + self.check_max_dygraph_ceilmode_results(place) + self.check_max_dygraph_nhwc_results(place) class TestPool2dError_API(unittest.TestCase): @@ -370,6 +478,22 @@ class TestPool2dError_API(unittest.TestCase): self.assertRaises(ValueError, run8) + def run9(): + with fluid.dygraph.guard(): + input_np = np.random.uniform(-1, 1, + [2, 3, 32, 32]).astype(np.float32) + input_pd = fluid.dygraph.to_variable(input_np) + res_pd = max_pool2d( + input_pd, + kernel_size=2, + stride=2, + padding=0, + ceil_mode=False, + data_format='NHWC', + return_indices=True) + + self.assertRaises(ValueError, run9) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_pool3d_api.py b/python/paddle/fluid/tests/unittests/test_pool3d_api.py index cc078e9aae7aafe55e937b80270dd012fd64ff70..a77f1cdd57d7bade92e2a4f914dc3d91624d4845 100644 --- a/python/paddle/fluid/tests/unittests/test_pool3d_api.py +++ b/python/paddle/fluid/tests/unittests/test_pool3d_api.py @@ -22,7 +22,7 @@ import paddle.fluid.core as core from op_test import OpTest import paddle.fluid as fluid from paddle.nn.functional import avg_pool3d, max_pool3d -from test_pool3d_op import adaptive_start_index, adaptive_end_index, pool3D_forward_naive +from test_pool3d_op import adaptive_start_index, adaptive_end_index, pool3D_forward_naive, avg_pool3D_forward_naive, max_pool3D_forward_naive class TestPool3d_API(unittest.TestCase): @@ -73,6 +73,58 @@ class TestPool3d_API(unittest.TestCase): result = avg_pool3d_dg(input) self.assertTrue(np.allclose(result.numpy(), result_np)) + def check_avg_dygraph_padding_results(self, place): + with fluid.dygraph.guard(place): + input_np = np.random.random([2, 3, 32, 32, 32]).astype("float32") + input = fluid.dygraph.to_variable(input_np) + result = avg_pool3d( + input, + kernel_size=2, + stride=2, + padding=1, + ceil_mode=False, + count_include_pad=True) + + result_np = avg_pool3D_forward_naive( + input_np, + ksize=[2, 2, 2], + strides=[2, 2, 2], + paddings=[1, 1, 1], + ceil_mode=False, + exclusive=False) + + self.assertTrue(np.allclose(result.numpy(), result_np)) + + avg_pool3d_dg = paddle.nn.layer.AvgPool3d( + kernel_size=2, + stride=None, + padding=1, + ceil_mode=False, + count_include_pad=True) + result = avg_pool3d_dg(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + def check_avg_dygraph_ceilmode_results(self, place): + with fluid.dygraph.guard(place): + input_np = np.random.random([2, 3, 32, 32, 32]).astype("float32") + input = fluid.dygraph.to_variable(input_np) + result = avg_pool3d( + input, kernel_size=2, stride=2, padding=0, ceil_mode=True) + + result_np = avg_pool3D_forward_naive( + input_np, + ksize=[2, 2, 2], + strides=[2, 2, 2], + paddings=[0, 0, 0], + ceil_mode=True) + + self.assertTrue(np.allclose(result.numpy(), result_np)) + + avg_pool3d_dg = paddle.nn.layer.AvgPool3d( + kernel_size=2, stride=None, padding=0, ceil_mode=True) + result = avg_pool3d_dg(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + def check_max_static_results(self, place): with fluid.program_guard(fluid.Program(), fluid.Program()): input = fluid.data( @@ -112,6 +164,74 @@ class TestPool3d_API(unittest.TestCase): result = max_pool3d_dg(input) self.assertTrue(np.allclose(result.numpy(), result_np)) + def check_max_dygraph_ndhwc_results(self, place): + print("run ndchw max pool3d") + with fluid.dygraph.guard(place): + input_np = np.random.random([2, 3, 32, 32, 32]).astype("float32") + input = fluid.dygraph.to_variable( + np.transpose(input_np, [0, 2, 3, 4, 1])) + result = max_pool3d( + input, + kernel_size=2, + stride=2, + padding=0, + data_format="NDHWC", + return_indices=False) + + result_np = pool3D_forward_naive( + input_np, + ksize=[2, 2, 2], + strides=[2, 2, 2], + paddings=[0, 0, 0], + pool_type='max') + + self.assertTrue( + np.allclose( + np.transpose(result.numpy(), [0, 4, 1, 2, 3]), result_np)) + + def check_max_dygraph_ceilmode_results(self, place): + print("run ceil mode max pool3d") + with fluid.dygraph.guard(place): + input_np = np.random.random([2, 3, 32, 32, 32]).astype("float32") + input = fluid.dygraph.to_variable(input_np) + result = max_pool3d( + input, kernel_size=2, stride=2, padding=0, ceil_mode=True) + + result_np = max_pool3D_forward_naive( + input_np, + ksize=[2, 2, 2], + strides=[2, 2, 2], + paddings=[0, 0, 0], + ceil_mode=True) + + self.assertTrue(np.allclose(result.numpy(), result_np)) + + max_pool3d_dg = paddle.nn.layer.MaxPool3d( + kernel_size=2, stride=None, padding=0, ceil_mode=True) + result = max_pool3d_dg(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + def check_max_dygraph_padding_results(self, place): + with fluid.dygraph.guard(place): + input_np = np.random.random([2, 3, 32, 32, 32]).astype("float32") + input = fluid.dygraph.to_variable(input_np) + result = max_pool3d( + input, kernel_size=2, stride=2, padding=1, ceil_mode=False) + + result_np = max_pool3D_forward_naive( + input_np, + ksize=[2, 2, 2], + strides=[2, 2, 2], + paddings=[1, 1, 1], + ceil_mode=False) + + self.assertTrue(np.allclose(result.numpy(), result_np)) + + max_pool3d_dg = paddle.nn.layer.MaxPool3d( + kernel_size=2, stride=None, padding=1, ceil_mode=False) + result = max_pool3d_dg(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + def check_max_dygraph_stride_is_none(self, place): with fluid.dygraph.guard(place): input_np = np.random.random([2, 3, 32, 32, 32]).astype("float32") @@ -205,6 +325,8 @@ class TestPool3d_API(unittest.TestCase): self.check_max_dygraph_stride_is_none(place) self.check_max_dygraph_padding(place) self.check_avg_divisor(place) + self.check_max_dygraph_ndhwc_results(place) + self.check_max_dygraph_ceilmode_results(place) class TestPool3dError_API(unittest.TestCase): @@ -336,6 +458,21 @@ class TestPool3dError_API(unittest.TestCase): self.assertRaises(ValueError, run9) + def run10(): + with fluid.dygraph.guard(): + input_np = np.random.uniform( + -1, 1, [2, 3, 32, 32, 32]).astype(np.float32) + input_pd = fluid.dygraph.to_variable(input_np) + res_pd = max_pool3d( + input_pd, + kernel_size=2, + stride=2, + padding=0, + data_format='NDHWC', + return_indices=True) + + self.assertRaises(ValueError, run10) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py index c8790a75901fd5d9a38862158246e3756dc575c4..8d22748c4d8765b35775ffe65bbaa0aaa69a35cf 100755 --- a/python/paddle/nn/functional/pooling.py +++ b/python/paddle/nn/functional/pooling.py @@ -168,7 +168,7 @@ def avg_pool1d(x, count_include_pad=True, ceil_mode=False, name=None): - """ + """ This API implements average pooling 1d operation, See more details in :ref:`api_nn_pooling_AvgPool1d` . @@ -280,7 +280,7 @@ def avg_pool2d(x, """ This API implements average pooling 2d operation. See more details in :ref:`api_nn_pooling_AvgPool2d` . - + Args: x (Tensor): The input tensor of pooling operator which is a 4-D tensor with shape [N, C, H, W]. The format of input tensor is `"NCHW"` or @@ -640,7 +640,7 @@ def max_pool2d(x, 5. A list or tuple of pairs of integers. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...]. Note that, the batch dimension and channel dimension should be [0,0] or (0,0). The default value is 0. ceil_mode (bool): when True, will use `ceil` instead of `floor` to compute the output shape - return_indices (bool): Whether to return the max indices along with the outputs. + return_indices (bool): Whether to return the max indices along with the outputs. Default False, only support `"NCHW"` data format data_format (string): The data format of the input and output data. An optional string from: `"NCHW"`, `"NHWC"`. The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of: `[batch_size, input_channels, input_height, input_width]`. @@ -690,15 +690,30 @@ def max_pool2d(x, padding, padding_algorithm = _update_padding_nd( padding, num_dims=2, channel_last=channel_last, ceil_mode=ceil_mode) + if data_format == "NHWC" and return_indices: + raise ValueError( + "When setting return_indices to true, data_format must be set to NCHW in API:max_pool2d" + ) + if in_dygraph_mode(): - output = core.ops.max_pool2d_with_index( - x, 'ksize', kernel_size, 'global_pooling', False, 'strides', stride, - 'paddings', padding, 'padding_algorithm', padding_algorithm, - 'use_cudnn', True, 'ceil_mode', ceil_mode, 'use_mkldnn', False, - 'exclusive', True, 'data_format', data_format) - return output if return_indices else output[0] + if data_format == "NCHW": + output = core.ops.max_pool2d_with_index( + x, 'ksize', kernel_size, 'global_pooling', False, 'strides', + stride, 'paddings', padding, 'padding_algorithm', + padding_algorithm, 'use_cudnn', True, 'ceil_mode', ceil_mode, + 'use_mkldnn', False, 'exclusive', True, 'data_format', + data_format) + return output if return_indices else output[0] + elif data_format == "NHWC" and not return_indices: + output = core.ops.pool2d( + x, 'pooling_type', 'max', 'ksize', kernel_size, + 'global_pooling', False, 'padding_algorithm', padding_algorithm, + 'strides', stride, 'paddings', padding, 'use_cudnn', True, + 'ceil_mode', ceil_mode, 'use_mkldnn', False, 'exclusive', True, + 'data_format', data_format) + return output - op_type = 'max_pool2d_with_index' + op_type = 'max_pool2d_with_index' if data_format == "NCHW" else "max_pool2d" helper = LayerHelper(op_type, **locals()) dtype = helper.input_dtype() pool_out = helper.create_variable_for_type_inference(dtype) @@ -739,7 +754,7 @@ def max_pool3d(x, See more details in :ref:`api_nn_pooling_MaxPool3d` . Args: x (Tensor): The input tensor of pooling operator, which is a 5-D tensor with - shape [N, C, D, H, W]. The format of input tensor is `"NCDHW"` or `"NDHWC"`, where N represents batch size, C represents the number of channels, D, H and W represent the depth, height and width of the feature respectively. + shape [N, C, D, H, W]. The format of input tensor is `"NCDHW"` or `"NDHWC"`, where N represents batch size, C represents the number of channels, D, H and W represent the depth, height and width of the feature respectively. kernel_size (int|list|tuple): The pool kernel size. If the kernel size is a tuple or list, it must contain three integers, (kernel_size_Depth, kernel_size_Height, kernel_size_Width). @@ -755,7 +770,7 @@ def max_pool3d(x, 5. A list or tuple of pairs of integers. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...]. Note that, the batch dimension and channel dimension should be [0,0] or (0,0). The default value is 0. ceil_mode (bool): ${ceil_mode_comment} - return_indices (bool): Whether to return the max indices along with the outputs. + return_indices (bool): Whether to return the max indices along with the outputs. Default False. Only support "NDCHW" data_format. data_format (string): The data format of the input and output data. An optional string from: `"NCDHW"`, `"NDHWC"`. The default is `"NCDHW"`. When it is `"NCDHW"`, the data is stored in the order of: `[batch_size, input_channels, input_depth, input_height, input_width]`. @@ -801,15 +816,30 @@ def max_pool3d(x, padding, padding_algorithm = _update_padding_nd( padding, 3, channel_last=channel_last, ceil_mode=ceil_mode) + if data_format == "NDHWC" and return_indices: + raise ValueError( + "When setting return_indices to true, data_format must be set to NCDHW in API:max_pool3d" + ) + if in_dygraph_mode(): - output = core.ops.max_pool3d_with_index( - x, 'pooling_type', 'max', 'ksize', kernel_size, 'strides', stride, - 'paddings', padding, 'global_pooling', False, 'padding_algorithm', - padding_algorithm, 'use_cudnn', True, 'ceil_mode', ceil_mode, - 'use_mkldnn', False, 'exclusive', True, 'data_format', data_format) - return output if return_indices else output[0] + if data_format == "NCDHW": + output = core.ops.max_pool3d_with_index( + x, 'pooling_type', 'max', 'ksize', kernel_size, 'strides', + stride, 'paddings', padding, 'global_pooling', False, + 'padding_algorithm', padding_algorithm, 'use_cudnn', True, + 'ceil_mode', ceil_mode, 'use_mkldnn', False, 'exclusive', True, + 'data_format', data_format) + return output if return_indices else output[0] + elif data_format == "NDHWC" and not return_indices: + output = core.ops.pool3d( + x, 'pooling_type', 'max', 'ksize', kernel_size, + 'global_pooling', False, 'padding_algorithm', padding_algorithm, + 'strides', stride, 'paddings', padding, 'use_cudnn', True, + 'ceil_mode', ceil_mode, 'use_mkldnn', False, 'exclusive', True, + 'data_format', data_format) + return output - op_type = "max_pool3d_with_index" + op_type = "max_pool3d_with_index" if data_format == "NCDHW" else "max_pool3d" helper = LayerHelper(op_type, **locals()) dtype = helper.input_dtype() pool_out = helper.create_variable_for_type_inference(dtype) @@ -841,7 +871,7 @@ def adaptive_avg_pool1d(x, output_size, name=None): """ This API implements adaptive average pooling 1d operation. See more details in :ref:`api_nn_pooling_AdaptiveAvgPool1d` . - + Args: x (Tensor): The input tensor of pooling operator, which is a 3-D tensor with shape [N, C, L]. The format of input tensor is NCL, diff --git a/python/paddle/nn/layer/pooling.py b/python/paddle/nn/layer/pooling.py index 6f6b567849732ff889db4507708758cd8eeab2a8..b31d7cb31968899ec3398cda04e664e9d6cc887d 100755 --- a/python/paddle/nn/layer/pooling.py +++ b/python/paddle/nn/layer/pooling.py @@ -850,7 +850,7 @@ class AdaptiveMaxPool1d(layers.Layer): lend &= ceil((i + 1) * L_{in} / L_{out}) - Output(i) &= max(Input[lstart:lend])} + Output(i) &= max(Input[lstart:lend]) Args: output_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list, @@ -932,7 +932,7 @@ class AdaptiveMaxPool2d(layers.Layer): Shape: x (Tensor): The input tensor of adaptive max pool2d operator, which is a 4-D tensor. The data type can be float32, float64. output (Tensor): The output tensor of adaptive max pool2d operator, which is a 4-D tensor. The data type is same as input x. - + Returns: A callable object of AdaptiveMaxPool2d. Examples: @@ -1032,7 +1032,7 @@ class AdaptiveMaxPool3d(layers.Layer): pool, indices = paddle.nn.AdaptiveMaxPool3d(output_size=3, return_indices=True) out = pool(x) # out shape: [2, 3, 4, 4, 4], indices shape: [2, 3, 4, 4, 4] - + """ def __init__(self, output_size, return_indices=False, name=None):