未验证 提交 8f77f8bc 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

fix pad3d infer shape (#41753)

* fix pad3d infer shape
上级 03533b0c
...@@ -1284,24 +1284,33 @@ void Pad3dInferMeta(const MetaTensor& x, ...@@ -1284,24 +1284,33 @@ void Pad3dInferMeta(const MetaTensor& x,
"5, but received %d. ", "5, but received %d. ",
x_dim.size())); x_dim.size()));
std::vector<int64_t> out_dims(x_dim.size()); std::vector<int64_t> out_dims(x_dim.size(), -1);
out_dims[0] = x_dim[0]; out_dims[0] = x_dim[0];
auto& paddings = paddings_int_array.GetData();
if (data_format == "NCDHW") {
out_dims[1] = x_dim[1];
} else {
out_dims[4] = x_dim[4];
}
if (paddings_int_array.FromTensor()) { if (paddings_int_array.FromTensor()) {
if (config.is_runtime) { if (config.is_runtime) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
paddings_int_array.GetData().size(), paddings.size(),
6, 6,
errors::InvalidArgument("Shape of Input(Paddings) should be equal to " errors::InvalidArgument("Shape of Input(Paddings) should be equal to "
"[6], but received [%d].", "[6], but received [%d].",
paddings_int_array.GetData().size())); paddings.size()));
if (data_format == "NCDHW") {
out_dims[2] = x_dim[2] + paddings[4] + paddings[5];
out_dims[3] = x_dim[3] + paddings[2] + paddings[3];
out_dims[4] = x_dim[4] + paddings[0] + paddings[1];
} else {
out_dims[1] = x_dim[1] + paddings[4] + paddings[5];
out_dims[2] = x_dim[2] + paddings[2] + paddings[3];
out_dims[3] = x_dim[3] + paddings[0] + paddings[1];
}
} }
out_dims[1] = x_dim[1];
out_dims[2] = x_dim[2];
out_dims[3] = x_dim[3];
out_dims[4] = x_dim[4];
} else { } else {
auto paddings = paddings_int_array.GetData();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
paddings.size(), paddings.size(),
6, 6,
...@@ -1309,7 +1318,6 @@ void Pad3dInferMeta(const MetaTensor& x, ...@@ -1309,7 +1318,6 @@ void Pad3dInferMeta(const MetaTensor& x,
"Size of paddings should be equal to 6, but received %d.", "Size of paddings should be equal to 6, but received %d.",
static_cast<int>(paddings.size()))); static_cast<int>(paddings.size())));
if (data_format == "NCDHW") { if (data_format == "NCDHW") {
out_dims[1] = x_dim[1]; // channel
out_dims[2] = ((!config.is_runtime) && (x_dim[2] < 0)) out_dims[2] = ((!config.is_runtime) && (x_dim[2] < 0))
? x_dim[2] ? x_dim[2]
: (x_dim[2] + paddings[4] + paddings[5]); // depth : (x_dim[2] + paddings[4] + paddings[5]); // depth
...@@ -1322,8 +1330,6 @@ void Pad3dInferMeta(const MetaTensor& x, ...@@ -1322,8 +1330,6 @@ void Pad3dInferMeta(const MetaTensor& x,
? x_dim[4] ? x_dim[4]
: (x_dim[4] + paddings[0] + paddings[1]); // width : (x_dim[4] + paddings[0] + paddings[1]); // width
} else { // NDHWC } else { // NDHWC
out_dims[4] = x_dim[4]; // channel
out_dims[1] = ((!config.is_runtime) && (x_dim[1] < 0)) out_dims[1] = ((!config.is_runtime) && (x_dim[1] < 0))
? x_dim[1] ? x_dim[1]
: (x_dim[1] + paddings[4] + paddings[5]); // depth : (x_dim[1] + paddings[4] + paddings[5]); // depth
......
...@@ -27,7 +27,6 @@ class TestPad3dOp(OpTest): ...@@ -27,7 +27,6 @@ class TestPad3dOp(OpTest):
def setUp(self): def setUp(self):
paddle.enable_static() paddle.enable_static()
self.value = 0.0 self.value = 0.0
self.variable_paddings = False
self.initTestCase() self.initTestCase()
self.op_type = "pad3d" self.op_type = "pad3d"
self.python_api = paddle.nn.functional.pad self.python_api = paddle.nn.functional.pad
...@@ -84,6 +83,7 @@ class TestPad3dOp(OpTest): ...@@ -84,6 +83,7 @@ class TestPad3dOp(OpTest):
self.mode = "constant" self.mode = "constant"
self.data_format = "NCDHW" self.data_format = "NCDHW"
self.pad_value = 0.0 self.pad_value = 0.0
self.variable_paddings = False
class TestCase1(TestPad3dOp): class TestCase1(TestPad3dOp):
...@@ -93,6 +93,7 @@ class TestCase1(TestPad3dOp): ...@@ -93,6 +93,7 @@ class TestCase1(TestPad3dOp):
self.mode = "constant" self.mode = "constant"
self.data_format = "NCDHW" self.data_format = "NCDHW"
self.value = 1.0 self.value = 1.0
self.variable_paddings = False
class TestCase2(TestPad3dOp): class TestCase2(TestPad3dOp):
...@@ -102,6 +103,7 @@ class TestCase2(TestPad3dOp): ...@@ -102,6 +103,7 @@ class TestCase2(TestPad3dOp):
self.mode = "constant" self.mode = "constant"
self.data_format = "NDHWC" self.data_format = "NDHWC"
self.value = 1.0 self.value = 1.0
self.variable_paddings = False
class TestCase3(TestPad3dOp): class TestCase3(TestPad3dOp):
...@@ -110,6 +112,7 @@ class TestCase3(TestPad3dOp): ...@@ -110,6 +112,7 @@ class TestCase3(TestPad3dOp):
self.paddings = [0, 1, 1, 0, 2, 3] self.paddings = [0, 1, 1, 0, 2, 3]
self.mode = "reflect" self.mode = "reflect"
self.data_format = "NCDHW" self.data_format = "NCDHW"
self.variable_paddings = False
class TestCase4(TestPad3dOp): class TestCase4(TestPad3dOp):
...@@ -118,6 +121,7 @@ class TestCase4(TestPad3dOp): ...@@ -118,6 +121,7 @@ class TestCase4(TestPad3dOp):
self.paddings = [0, 1, 2, 1, 2, 3] self.paddings = [0, 1, 2, 1, 2, 3]
self.mode = "reflect" self.mode = "reflect"
self.data_format = "NDHWC" self.data_format = "NDHWC"
self.variable_paddings = False
class TestCase5(TestPad3dOp): class TestCase5(TestPad3dOp):
...@@ -126,6 +130,7 @@ class TestCase5(TestPad3dOp): ...@@ -126,6 +130,7 @@ class TestCase5(TestPad3dOp):
self.paddings = [0, 1, 2, 3, 2, 1] self.paddings = [0, 1, 2, 3, 2, 1]
self.mode = "replicate" self.mode = "replicate"
self.data_format = "NCDHW" self.data_format = "NCDHW"
self.variable_paddings = False
class TestCase6(TestPad3dOp): class TestCase6(TestPad3dOp):
...@@ -134,6 +139,7 @@ class TestCase6(TestPad3dOp): ...@@ -134,6 +139,7 @@ class TestCase6(TestPad3dOp):
self.paddings = [5, 4, 2, 1, 2, 3] self.paddings = [5, 4, 2, 1, 2, 3]
self.mode = "replicate" self.mode = "replicate"
self.data_format = "NDHWC" self.data_format = "NDHWC"
self.variable_paddings = False
class TestCase7(TestPad3dOp): class TestCase7(TestPad3dOp):
...@@ -142,6 +148,7 @@ class TestCase7(TestPad3dOp): ...@@ -142,6 +148,7 @@ class TestCase7(TestPad3dOp):
self.paddings = [0, 1, 2, 3, 2, 1] self.paddings = [0, 1, 2, 3, 2, 1]
self.mode = "circular" self.mode = "circular"
self.data_format = "NCDHW" self.data_format = "NCDHW"
self.variable_paddings = False
class TestCase8(TestPad3dOp): class TestCase8(TestPad3dOp):
...@@ -150,6 +157,27 @@ class TestCase8(TestPad3dOp): ...@@ -150,6 +157,27 @@ class TestCase8(TestPad3dOp):
self.paddings = [0, 1, 2, 1, 2, 3] self.paddings = [0, 1, 2, 1, 2, 3]
self.mode = "circular" self.mode = "circular"
self.data_format = "NDHWC" self.data_format = "NDHWC"
self.variable_paddings = False
class TestCase9(TestPad3dOp):
def initTestCase(self):
self.shape = (2, 3, 4, 5, 6)
self.paddings = [0, 1, 2, 3, 4, 5]
self.mode = "constant"
self.data_format = "NCDHW"
self.value = 1.0
self.variable_paddings = True
class TestCase10(TestPad3dOp):
def initTestCase(self):
self.shape = (2, 3, 4, 5, 6)
self.paddings = [0, 1, 2, 3, 4, 5]
self.mode = "constant"
self.data_format = "NDHWC"
self.value = 1.0
self.variable_paddings = True
class TestPadAPI(unittest.TestCase): class TestPadAPI(unittest.TestCase):
...@@ -681,6 +709,30 @@ class TestPad3dAPI(unittest.TestCase): ...@@ -681,6 +709,30 @@ class TestPad3dAPI(unittest.TestCase):
input_data, pad, "circular", data_format="NCDHW") input_data, pad, "circular", data_format="NCDHW")
self.assertTrue(np.allclose(output.numpy(), np_out)) self.assertTrue(np.allclose(output.numpy(), np_out))
def test_pad_tensor(self):
paddle.disable_static()
for place in self.places:
input_shape = (3, 4, 5, 6, 7)
pad = [1, 2, 2, 1, 1, 0]
pad_tensor = paddle.to_tensor(pad)
input_data = np.random.rand(*input_shape).astype(np.float32)
pad_reflection_ncdhw = nn.Pad3D(
padding=pad_tensor, mode="reflect", data_format="NCDHW")
pad_reflection_ndhwc = nn.Pad3D(
padding=pad_tensor, mode="reflect", data_format="NDHWC")
data = paddle.to_tensor(input_data)
output = pad_reflection_ncdhw(data)
np_out = self._get_numpy_out(
input_data, pad, "reflect", data_format="NCDHW")
self.assertTrue(np.allclose(output.numpy(), np_out))
output = pad_reflection_ndhwc(data)
np_out = self._get_numpy_out(
input_data, pad, "reflect", data_format="NDHWC")
self.assertTrue(np.allclose(output.numpy(), np_out))
class TestPad3dOpError(unittest.TestCase): class TestPad3dOpError(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册