diff --git a/python/paddle/fluid/tests/unittests/test_imperative_layers.py b/python/paddle/fluid/tests/unittests/test_imperative_layers.py index 3561405ae090bd99ec03d25d6abe0f01f8b13e7a..8bb87198d8069de954619434599568e38126d093 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_layers.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_layers.py @@ -146,6 +146,10 @@ class TestLayerPrint(unittest.TestCase): 'Pad2D(padding=[1, 0, 1, 2], mode=constant, value=0.0, data_format=NCHW)' ) + module = nn.ZeroPad2D(padding=[1, 0, 1, 2]) + self.assertEqual( + str(module), 'ZeroPad2D(padding=[1, 0, 1, 2], data_format=NCHW)') + module = nn.Pad3D(padding=[1, 0, 1, 2, 0, 0], mode='constant') self.assertEqual( str(module), diff --git a/python/paddle/fluid/tests/unittests/test_zeropad2d.py b/python/paddle/fluid/tests/unittests/test_zeropad2d.py new file mode 100644 index 0000000000000000000000000000000000000000..2849caf17c62d8f0625b39f04e32b84b319c0504 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_zeropad2d.py @@ -0,0 +1,135 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from paddle import to_tensor +from paddle.nn.functional import zeropad2d +from paddle.nn import ZeroPad2D + + +class TestZeroPad2dAPIError(unittest.TestCase): + """ + test paddle.zeropad2d error. + """ + + def setUp(self): + """ + unsupport dtypes + """ + self.shape = [4, 3, 224, 224] + self.unsupport_dtypes = ['bool', 'int8'] + + def test_unsupport_dtypes(self): + """ + test unsupport dtypes. + """ + for dtype in self.unsupport_dtypes: + pad = 2 + x = np.random.randint(-255, 255, size=self.shape) + x_tensor = to_tensor(x).astype(dtype) + self.assertRaises(TypeError, zeropad2d, x=x_tensor, padding=pad) + + +class TestZeroPad2dAPI(unittest.TestCase): + """ + test paddle.zeropad2d + """ + + def setUp(self): + """ + support dtypes + """ + self.shape = [4, 3, 224, 224] + self.support_dtypes = ['float32', 'float64', 'int32', 'int64'] + + def test_support_dtypes(self): + """ + test support types + """ + for dtype in self.support_dtypes: + pad = 2 + x = np.random.randint(-255, 255, size=self.shape).astype(dtype) + expect_res = np.pad(x, [[0, 0], [0, 0], [pad, pad], [pad, pad]]) + + x_tensor = to_tensor(x).astype(dtype) + ret_res = zeropad2d(x_tensor, [pad, pad, pad, pad]).numpy() + self.assertTrue(np.allclose(expect_res, ret_res)) + + def test_support_pad2(self): + """ + test the type of 'pad' is list. + """ + pad = [1, 2, 3, 4] + x = np.random.randint(-255, 255, size=self.shape) + expect_res = np.pad( + x, [[0, 0], [0, 0], [pad[2], pad[3]], [pad[0], pad[1]]]) + + x_tensor = to_tensor(x) + ret_res = zeropad2d(x_tensor, pad).numpy() + self.assertTrue(np.allclose(expect_res, ret_res)) + + def test_support_pad3(self): + """ + test the type of 'pad' is tuple. + """ + pad = (1, 2, 3, 4) + x = np.random.randint(-255, 255, size=self.shape) + expect_res = np.pad( + x, [[0, 0], [0, 0], [pad[2], pad[3]], [pad[0], pad[1]]]) + + x_tensor = to_tensor(x) + ret_res = zeropad2d(x_tensor, pad).numpy() + self.assertTrue(np.allclose(expect_res, ret_res)) + + def test_support_pad4(self): + """ + test the type of 'pad' is paddle.Tensor. + """ + pad = [1, 2, 3, 4] + x = np.random.randint(-255, 255, size=self.shape) + expect_res = np.pad( + x, [[0, 0], [0, 0], [pad[2], pad[3]], [pad[0], pad[1]]]) + + x_tensor = to_tensor(x) + pad_tensor = to_tensor(pad, dtype='int32') + ret_res = zeropad2d(x_tensor, pad_tensor).numpy() + self.assertTrue(np.allclose(expect_res, ret_res)) + + +class TestZeroPad2DLayer(unittest.TestCase): + """ + test nn.ZeroPad2D + """ + + def setUp(self): + self.shape = [4, 3, 224, 224] + self.pad = [2, 2, 4, 1] + self.padLayer = ZeroPad2D(padding=self.pad) + self.x = np.random.randint(-255, 255, size=self.shape) + self.expect_res = np.pad(self.x, + [[0, 0], [0, 0], [self.pad[2], self.pad[3]], + [self.pad[0], self.pad[1]]]) + + def test_layer(self): + self.assertTrue( + np.allclose( + zeropad2d(to_tensor(self.x), self.pad).numpy(), + self.padLayer(to_tensor(self.x)))) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 064052c07695de8de8a5e97a357d03c85644934e..1abe74e9783dc4a6eb1767f547a75a78815e68c0 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -52,6 +52,7 @@ from .layer.activation import LogSoftmax # noqa: F401 from .layer.activation import Maxout # noqa: F401 from .layer.common import Pad1D # noqa: F401 from .layer.common import Pad2D # noqa: F401 +from .layer.common import ZeroPad2D # noqa: F401 from .layer.common import Pad3D # noqa: F401 from .layer.common import CosineSimilarity # noqa: F401 from .layer.common import Embedding # noqa: F401 @@ -293,5 +294,6 @@ __all__ = [ #noqa 'PixelShuffle', 'ELU', 'ReLU6', - 'LayerDict' + 'LayerDict', + 'ZeroPad2D' ] diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 1af53e0826be878d8734fa602160881225eff13d..3e3bd6397c072a196db78b15c32545b4ece41ebd 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -52,6 +52,7 @@ from .common import dropout3d # noqa: F401 from .common import alpha_dropout # noqa: F401 from .common import label_smooth # noqa: F401 from .common import pad # noqa: F401 +from .common import zeropad2d # noqa: F401 from .common import cosine_similarity # noqa: F401 from .common import unfold # noqa: F401 from .common import interpolate # noqa: F401 @@ -162,6 +163,7 @@ __all__ = [ #noqa 'label_smooth', 'linear', 'pad', + 'zeropad2d', 'unfold', 'interpolate', 'upsample', diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index ef08982a6ff11257aa8022c5b10dc98365513386..217f8cd4125518a5a5f85cb6138f16e3046fff0d 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -1365,6 +1365,46 @@ def pad(x, pad, mode='constant', value=0, data_format="NCHW", name=None): return out +def zeropad2d(x, padding, data_format="NCHW", name=None): + """ + Pads the input tensor boundaries with zero according to 'pad'. + + Args: + x(Tensor): The input tensor with data type float16/float32/float64/int32/int64. + padding(int | Tensor | List[int] | Tuple[int]): The padding size with data type int. + The input dimension should be 4 and pad has the form (pad_left, pad_right, + pad_top, pad_bottom). + data_format(str): An string from: "NHWC", "NCHW". Specify the data format of + the input data. Default: "NCHW". + name(str, optional): The default value is None. Normally there is no need for user + to set this property. + + Returns:Tensor,padded with 0 according to pad and data type is same as input. + + Examples: + .. code-block:: python + + import paddle + import numpy as np + import paddle.nn.functional as F + + x_shape = (1, 1, 2, 3) + x = paddle.arange(np.prod(x_shape), dtype="float32").reshape(x_shape) + 1 + y = F.zeropad2d(x, [1, 2, 1, 1]) + # [[[[0. 0. 0. 0. 0. 0.] + # [0. 1. 2. 3. 0. 0.] + # [0. 4. 5. 6. 0. 0.] + # [0. 0. 0. 0. 0. 0.]]]] + """ + + return pad(x, + pad=padding, + mode='constant', + value=0, + data_format=data_format, + name=name) + + def cosine_similarity(x1, x2, axis=1, eps=1e-8): """ Compute cosine similarity between x1 and x2 along axis. diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index eb7535b16c6e1e7e19e3a1260d506e6996379aa3..a65f9912d593915cae64f20392f8aeb46e15d16a 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -29,6 +29,7 @@ from .activation import LogSoftmax # noqa: F401 from .common import Bilinear # noqa: F401 from .common import Pad1D # noqa: F401 from .common import Pad2D # noqa: F401 +from .common import ZeroPad2D # noqa: F401 from .common import Pad3D # noqa: F401 from .common import CosineSimilarity # noqa: F401 from .common import Embedding # noqa: F401 diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py index d5557bd9ea4e71bd02a2eea850f9d5b61625a9ce..1069a24be21f883ef9a232593479c123a321e4b4 100644 --- a/python/paddle/nn/layer/common.py +++ b/python/paddle/nn/layer/common.py @@ -1108,6 +1108,72 @@ class Pad2D(Layer): self._pad, self._mode, self._value, self._data_format, name_str) +class ZeroPad2D(Layer): + """ + This interface is used to construct a callable object of the ``ZeroPad2D`` class. + Pads the input tensor boundaries with zero. + + Parameters: + padding (Tensor | List[int] | int): The padding size with data type int. If is int, use the + same padding in all dimensions. Else [len(padding)/2] dimensions of input will be padded. + The pad has the form (pad_left, pad_right, pad_top, pad_bottom). + data_format (str): An string from: "NCHW", "NHWC". Specify the data format of the input data. + Default is "NCHW" + name (str, optional) : The default value is None. Normally there is no need for + user to set this property. For more information, please refer to :ref:`api_guide_Name`. + + Shape: + - x(Tensor): The input tensor of zeropad2d operator, which is a 4-D tensor. + The data type can be float32, float64. + - output(Tensor): The output tensor of zeropad2d operator, which is a 4-D tensor. + The data type is same as input x. + + Examples: + Examples are as follows. + + .. code-block:: python + + import paddle + import paddle.nn as nn + import numpy as np + + input_shape = (1, 1, 2, 3) + pad = [1, 0, 1, 2] + data = paddle.arange(np.prod(input_shape), dtype="float32").reshape(input_shape) + 1 + + my_pad = nn.ZeroPad2D(padding=pad) + result = my_pad(data) + + print(result) + # [[[[0. 0. 0. 0.] + # [0. 1. 2. 3.] + # [0. 4. 5. 6.] + # [0. 0. 0. 0.] + # [0. 0. 0. 0.]]]] + """ + + def __init__(self, padding, data_format="NCHW", name=None): + super(ZeroPad2D, self).__init__() + self._pad = _npairs(padding, 2) + self._mode = 'constant' + self._value = 0. + self._data_format = data_format + self._name = name + + def forward(self, x): + return F.pad(x, + pad=self._pad, + mode=self._mode, + value=self._value, + data_format=self._data_format, + name=self._name) + + def extra_repr(self): + name_str = ', name={}'.format(self._name) if self._name else '' + return 'padding={}, data_format={}{}'.format( + self._pad, self._data_format, name_str) + + class Pad3D(Layer): """ This interface is used to construct a callable object of the ``Pad3D`` class.