From 52a6ca0cf5277bfe200a4ad8f610feadb7e7435c Mon Sep 17 00:00:00 2001
From: littletomatodonkey <2120160898@bit.edu.cn>
Date: Fri, 28 Aug 2020 10:26:03 +0800
Subject: [PATCH] test=develop, improve pad assertion error (#26748)

---
 .../fluid/tests/unittests/test_pad3d_op.py    | 45 ++++++++++++++++++-
 python/paddle/nn/functional/common.py         | 14 +++++-
 2 files changed, 57 insertions(+), 2 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_pad3d_op.py b/python/paddle/fluid/tests/unittests/test_pad3d_op.py
index 68589e6d818..11719a9c4a9 100644
--- a/python/paddle/fluid/tests/unittests/test_pad3d_op.py
+++ b/python/paddle/fluid/tests/unittests/test_pad3d_op.py
@@ -166,7 +166,11 @@ class TestPadAPI(unittest.TestCase):
             value = 100
             input_data = np.random.rand(*input_shape).astype(np.float32)
             x = paddle.data(name="x", shape=input_shape)
-            result = F.pad(x=x, pad=pad, value=value, mode=mode)
+            result = F.pad(x=x,
+                           pad=pad,
+                           value=value,
+                           mode=mode,
+                           data_format="NCDHW")
             exe = Executor(place)
             fetches = exe.run(default_main_program(),
                               feed={"x": input_data},
@@ -666,5 +670,44 @@ class TestPad3dOpError(unittest.TestCase):
         self.assertRaises(Exception, test_reflect_3)
 
 
+class TestPadDataformatError(unittest.TestCase):
+    def test_errors(self):
+        def test_ncl():
+            paddle.disable_static(paddle.CPUPlace())
+            input_shape = (1, 2, 3, 4)
+            pad = paddle.to_tensor(np.array([2, 1, 2, 1]).astype('int32'))
+            data = np.arange(
+                np.prod(input_shape), dtype=np.float64).reshape(input_shape) + 1
+            my_pad = nn.ReplicationPad1d(padding=pad, data_format="NCL")
+            data = paddle.to_tensor(data)
+            result = my_pad(data)
+
+        def test_nchw():
+            paddle.disable_static(paddle.CPUPlace())
+            input_shape = (1, 2, 4)
+            pad = paddle.to_tensor(np.array([2, 1, 2, 1]).astype('int32'))
+            data = np.arange(
+                np.prod(input_shape), dtype=np.float64).reshape(input_shape) + 1
+            my_pad = nn.ReplicationPad1d(padding=pad, data_format="NCHW")
+            data = paddle.to_tensor(data)
+            result = my_pad(data)
+
+        def test_ncdhw():
+            paddle.disable_static(paddle.CPUPlace())
+            input_shape = (1, 2, 3, 4)
+            pad = paddle.to_tensor(np.array([2, 1, 2, 1]).astype('int32'))
+            data = np.arange(
+                np.prod(input_shape), dtype=np.float64).reshape(input_shape) + 1
+            my_pad = nn.ReplicationPad1d(padding=pad, data_format="NCDHW")
+            data = paddle.to_tensor(data)
+            result = my_pad(data)
+
+        self.assertRaises(AssertionError, test_ncl)
+
+        self.assertRaises(AssertionError, test_nchw)
+
+        self.assertRaises(AssertionError, test_ncdhw)
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py
index 8408e224d87..623af3277fb 100644
--- a/python/paddle/nn/functional/common.py
+++ b/python/paddle/nn/functional/common.py
@@ -1230,7 +1230,19 @@ def pad(x, pad, mode='constant', value=0, data_format="NCHW", name=None):
 
     x_dim = len(x.shape)
 
-    original_data_format = data_format
+    assert x_dim in [
+        3, 4, 5
+    ], "input tesor dimension must be in [3, 4, 5] but got {}".format(x_dim)
+
+    supported_format_map = {
+        3: ["NCL", "NLC"],
+        4: ["NCHW", "NHWC"],
+        5: ["NCDHW", "NDHWC"],
+    }
+    assert data_format in supported_format_map[x_dim], \
+    "input tensor dimension is {}, it's data format should be in {} but got {}".format(
+        x_dim, supported_format_map[x_dim], data_format)
+
     unsqueezed_dim = []
 
     if isinstance(pad, Variable):
-- 
GitLab