From 8f77f8bc373ce7bdd42f1253be3083915b798b13 Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Tue, 19 Apr 2022 14:37:03 +0800 Subject: [PATCH] fix pad3d infer shape (#41753) * fix pad3d infer shape --- paddle/phi/infermeta/unary.cc | 30 ++++++----- .../fluid/tests/unittests/test_pad3d_op.py | 54 ++++++++++++++++++- 2 files changed, 71 insertions(+), 13 deletions(-) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 7b50a37ac14..e3e1211e3ec 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1284,24 +1284,33 @@ void Pad3dInferMeta(const MetaTensor& x, "5, but received %d. ", x_dim.size())); - std::vector out_dims(x_dim.size()); + std::vector out_dims(x_dim.size(), -1); out_dims[0] = x_dim[0]; + auto& paddings = paddings_int_array.GetData(); + if (data_format == "NCDHW") { + out_dims[1] = x_dim[1]; + } else { + out_dims[4] = x_dim[4]; + } if (paddings_int_array.FromTensor()) { if (config.is_runtime) { PADDLE_ENFORCE_EQ( - paddings_int_array.GetData().size(), + paddings.size(), 6, errors::InvalidArgument("Shape of Input(Paddings) should be equal to " "[6], but received [%d].", - paddings_int_array.GetData().size())); + paddings.size())); + if (data_format == "NCDHW") { + out_dims[2] = x_dim[2] + paddings[4] + paddings[5]; + out_dims[3] = x_dim[3] + paddings[2] + paddings[3]; + out_dims[4] = x_dim[4] + paddings[0] + paddings[1]; + } else { + out_dims[1] = x_dim[1] + paddings[4] + paddings[5]; + out_dims[2] = x_dim[2] + paddings[2] + paddings[3]; + out_dims[3] = x_dim[3] + paddings[0] + paddings[1]; + } } - out_dims[1] = x_dim[1]; - out_dims[2] = x_dim[2]; - out_dims[3] = x_dim[3]; - out_dims[4] = x_dim[4]; } else { - auto paddings = paddings_int_array.GetData(); - PADDLE_ENFORCE_EQ( paddings.size(), 6, @@ -1309,7 +1318,6 @@ void Pad3dInferMeta(const MetaTensor& x, "Size of paddings should be equal to 6, but received %d.", static_cast(paddings.size()))); if (data_format == "NCDHW") { - out_dims[1] = x_dim[1]; // channel out_dims[2] = ((!config.is_runtime) && (x_dim[2] < 0)) ? x_dim[2] : (x_dim[2] + paddings[4] + paddings[5]); // depth @@ -1322,8 +1330,6 @@ void Pad3dInferMeta(const MetaTensor& x, ? x_dim[4] : (x_dim[4] + paddings[0] + paddings[1]); // width } else { // NDHWC - out_dims[4] = x_dim[4]; // channel - out_dims[1] = ((!config.is_runtime) && (x_dim[1] < 0)) ? x_dim[1] : (x_dim[1] + paddings[4] + paddings[5]); // depth diff --git a/python/paddle/fluid/tests/unittests/test_pad3d_op.py b/python/paddle/fluid/tests/unittests/test_pad3d_op.py index 12f6f7b5721..eabff5f0021 100644 --- a/python/paddle/fluid/tests/unittests/test_pad3d_op.py +++ b/python/paddle/fluid/tests/unittests/test_pad3d_op.py @@ -27,7 +27,6 @@ class TestPad3dOp(OpTest): def setUp(self): paddle.enable_static() self.value = 0.0 - self.variable_paddings = False self.initTestCase() self.op_type = "pad3d" self.python_api = paddle.nn.functional.pad @@ -84,6 +83,7 @@ class TestPad3dOp(OpTest): self.mode = "constant" self.data_format = "NCDHW" self.pad_value = 0.0 + self.variable_paddings = False class TestCase1(TestPad3dOp): @@ -93,6 +93,7 @@ class TestCase1(TestPad3dOp): self.mode = "constant" self.data_format = "NCDHW" self.value = 1.0 + self.variable_paddings = False class TestCase2(TestPad3dOp): @@ -102,6 +103,7 @@ class TestCase2(TestPad3dOp): self.mode = "constant" self.data_format = "NDHWC" self.value = 1.0 + self.variable_paddings = False class TestCase3(TestPad3dOp): @@ -110,6 +112,7 @@ class TestCase3(TestPad3dOp): self.paddings = [0, 1, 1, 0, 2, 3] self.mode = "reflect" self.data_format = "NCDHW" + self.variable_paddings = False class TestCase4(TestPad3dOp): @@ -118,6 +121,7 @@ class TestCase4(TestPad3dOp): self.paddings = [0, 1, 2, 1, 2, 3] self.mode = "reflect" self.data_format = "NDHWC" + self.variable_paddings = False class TestCase5(TestPad3dOp): @@ -126,6 +130,7 @@ class TestCase5(TestPad3dOp): self.paddings = [0, 1, 2, 3, 2, 1] self.mode = "replicate" self.data_format = "NCDHW" + self.variable_paddings = False class TestCase6(TestPad3dOp): @@ -134,6 +139,7 @@ class TestCase6(TestPad3dOp): self.paddings = [5, 4, 2, 1, 2, 3] self.mode = "replicate" self.data_format = "NDHWC" + self.variable_paddings = False class TestCase7(TestPad3dOp): @@ -142,6 +148,7 @@ class TestCase7(TestPad3dOp): self.paddings = [0, 1, 2, 3, 2, 1] self.mode = "circular" self.data_format = "NCDHW" + self.variable_paddings = False class TestCase8(TestPad3dOp): @@ -150,6 +157,27 @@ class TestCase8(TestPad3dOp): self.paddings = [0, 1, 2, 1, 2, 3] self.mode = "circular" self.data_format = "NDHWC" + self.variable_paddings = False + + +class TestCase9(TestPad3dOp): + def initTestCase(self): + self.shape = (2, 3, 4, 5, 6) + self.paddings = [0, 1, 2, 3, 4, 5] + self.mode = "constant" + self.data_format = "NCDHW" + self.value = 1.0 + self.variable_paddings = True + + +class TestCase10(TestPad3dOp): + def initTestCase(self): + self.shape = (2, 3, 4, 5, 6) + self.paddings = [0, 1, 2, 3, 4, 5] + self.mode = "constant" + self.data_format = "NDHWC" + self.value = 1.0 + self.variable_paddings = True class TestPadAPI(unittest.TestCase): @@ -681,6 +709,30 @@ class TestPad3dAPI(unittest.TestCase): input_data, pad, "circular", data_format="NCDHW") self.assertTrue(np.allclose(output.numpy(), np_out)) + def test_pad_tensor(self): + paddle.disable_static() + for place in self.places: + input_shape = (3, 4, 5, 6, 7) + pad = [1, 2, 2, 1, 1, 0] + pad_tensor = paddle.to_tensor(pad) + input_data = np.random.rand(*input_shape).astype(np.float32) + + pad_reflection_ncdhw = nn.Pad3D( + padding=pad_tensor, mode="reflect", data_format="NCDHW") + pad_reflection_ndhwc = nn.Pad3D( + padding=pad_tensor, mode="reflect", data_format="NDHWC") + data = paddle.to_tensor(input_data) + + output = pad_reflection_ncdhw(data) + np_out = self._get_numpy_out( + input_data, pad, "reflect", data_format="NCDHW") + self.assertTrue(np.allclose(output.numpy(), np_out)) + + output = pad_reflection_ndhwc(data) + np_out = self._get_numpy_out( + input_data, pad, "reflect", data_format="NDHWC") + self.assertTrue(np.allclose(output.numpy(), np_out)) + class TestPad3dOpError(unittest.TestCase): def setUp(self): -- GitLab