未验证 提交 d89a759b 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

fix replicate pad when input size is 0 (#36510)

* fix replicate pad when input size is 0
* add unit test
上级 7edcc4fb
......@@ -565,13 +565,11 @@ class Pad3dCPUKernel : public framework::OpKernel<T> {
" in reflect mode"
", but received depth(%d) and pad_right(%d).",
in_width, pads[1]));
}
if (mode == "circular") {
PADDLE_ENFORCE_NE(
in_depth * in_height * in_width, 0,
platform::errors::InvalidArgument(
"The input tensor size can not be 0 for circular padding mode."));
} else if (mode == "circular" || mode == "replicate") {
PADDLE_ENFORCE_NE(in_depth * in_height * in_width, 0,
platform::errors::InvalidArgument(
"The input tensor size can not be 0 for circular "
"or replicate padding mode."));
}
const int pad_left = pads[0];
......
......@@ -618,13 +618,11 @@ class Pad3dCUDAKernel : public framework::OpKernel<T> {
" in reflect mode"
", but received depth(%d) and pad_right(%d).",
in_width, pads[1]));
}
if (mode == "circular") {
PADDLE_ENFORCE_NE(
in_depth * in_height * in_width, 0,
platform::errors::InvalidArgument(
"The input tensor size can not be 0 for circular padding mode."));
} else if (mode == "circular" || mode == "replicate") {
PADDLE_ENFORCE_NE(in_depth * in_height * in_width, 0,
platform::errors::InvalidArgument(
"The input tensor size can not be 0 for circular "
"or replicate padding mode."));
}
const int pad_left = pads[0];
......
......@@ -732,6 +732,15 @@ class TestPad3dOpError(unittest.TestCase):
mode='circular',
data_format="NCDHW")
def test_replicate_1():
input_shape = (1, 2, 0, 4, 5)
data = np.random.rand(*input_shape).astype(np.float32)
x = paddle.to_tensor(data)
y = F.pad(x,
pad=[1, 1, 1, 1, 2, 3],
mode='replicate',
data_format="NCDHW")
paddle.disable_static()
for place in self.places:
self.assertRaises(ValueError, test_variable)
......@@ -739,6 +748,7 @@ class TestPad3dOpError(unittest.TestCase):
self.assertRaises(Exception, test_reflect_2)
self.assertRaises(Exception, test_reflect_3)
self.assertRaises(Exception, test_circular_1)
self.assertRaises(Exception, test_replicate_1)
paddle.enable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册