From 52a6ca0cf5277bfe200a4ad8f610feadb7e7435c Mon Sep 17 00:00:00 2001 From: littletomatodonkey <2120160898@bit.edu.cn> Date: Fri, 28 Aug 2020 10:26:03 +0800 Subject: [PATCH] test=develop, improve pad assertion error (#26748) --- .../fluid/tests/unittests/test_pad3d_op.py | 45 ++++++++++++++++++- python/paddle/nn/functional/common.py | 14 +++++- 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_pad3d_op.py b/python/paddle/fluid/tests/unittests/test_pad3d_op.py index 68589e6d818..11719a9c4a9 100644 --- a/python/paddle/fluid/tests/unittests/test_pad3d_op.py +++ b/python/paddle/fluid/tests/unittests/test_pad3d_op.py @@ -166,7 +166,11 @@ class TestPadAPI(unittest.TestCase): value = 100 input_data = np.random.rand(*input_shape).astype(np.float32) x = paddle.data(name="x", shape=input_shape) - result = F.pad(x=x, pad=pad, value=value, mode=mode) + result = F.pad(x=x, + pad=pad, + value=value, + mode=mode, + data_format="NCDHW") exe = Executor(place) fetches = exe.run(default_main_program(), feed={"x": input_data}, @@ -666,5 +670,44 @@ class TestPad3dOpError(unittest.TestCase): self.assertRaises(Exception, test_reflect_3) +class TestPadDataformatError(unittest.TestCase): + def test_errors(self): + def test_ncl(): + paddle.disable_static(paddle.CPUPlace()) + input_shape = (1, 2, 3, 4) + pad = paddle.to_tensor(np.array([2, 1, 2, 1]).astype('int32')) + data = np.arange( + np.prod(input_shape), dtype=np.float64).reshape(input_shape) + 1 + my_pad = nn.ReplicationPad1d(padding=pad, data_format="NCL") + data = paddle.to_tensor(data) + result = my_pad(data) + + def test_nchw(): + paddle.disable_static(paddle.CPUPlace()) + input_shape = (1, 2, 4) + pad = paddle.to_tensor(np.array([2, 1, 2, 1]).astype('int32')) + data = np.arange( + np.prod(input_shape), dtype=np.float64).reshape(input_shape) + 1 + my_pad = nn.ReplicationPad1d(padding=pad, data_format="NCHW") + data = paddle.to_tensor(data) + result = my_pad(data) + + def test_ncdhw(): + paddle.disable_static(paddle.CPUPlace()) + input_shape = (1, 2, 3, 4) + pad = paddle.to_tensor(np.array([2, 1, 2, 1]).astype('int32')) + data = np.arange( + np.prod(input_shape), dtype=np.float64).reshape(input_shape) + 1 + my_pad = nn.ReplicationPad1d(padding=pad, data_format="NCDHW") + data = paddle.to_tensor(data) + result = my_pad(data) + + self.assertRaises(AssertionError, test_ncl) + + self.assertRaises(AssertionError, test_nchw) + + self.assertRaises(AssertionError, test_ncdhw) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 8408e224d87..623af3277fb 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -1230,7 +1230,19 @@ def pad(x, pad, mode='constant', value=0, data_format="NCHW", name=None): x_dim = len(x.shape) - original_data_format = data_format + assert x_dim in [ + 3, 4, 5 + ], "input tesor dimension must be in [3, 4, 5] but got {}".format(x_dim) + + supported_format_map = { + 3: ["NCL", "NLC"], + 4: ["NCHW", "NHWC"], + 5: ["NCDHW", "NDHWC"], + } + assert data_format in supported_format_map[x_dim], \ + "input tensor dimension is {}, it's data format should be in {} but got {}".format( + x_dim, supported_format_map[x_dim], data_format) + unsqueezed_dim = [] if isinstance(pad, Variable): -- GitLab