未验证 提交 ad50fa71 编写于 作者: L littletomatodonkey 提交者: GitHub

add int pad support for Pad1D/2D/3D (#31209)

* add int pad support for Pad1D/2D/3D

* fix type

* fix format
上级 d18c5e47
...@@ -467,12 +467,15 @@ class TestPad1dAPI(unittest.TestCase): ...@@ -467,12 +467,15 @@ class TestPad1dAPI(unittest.TestCase):
for place in self.places: for place in self.places:
input_shape = (3, 4, 5) input_shape = (3, 4, 5)
pad = [1, 2] pad = [1, 2]
pad_int = 1
value = 100 value = 100
input_data = np.random.rand(*input_shape).astype(np.float32) input_data = np.random.rand(*input_shape).astype(np.float32)
pad_reflection = nn.Pad1D(padding=pad, mode="reflect") pad_reflection = nn.Pad1D(padding=pad, mode="reflect")
pad_replication = nn.Pad1D(padding=pad, mode="replicate") pad_replication = nn.Pad1D(padding=pad, mode="replicate")
pad_constant = nn.Pad1D(padding=pad, mode="constant", value=value) pad_constant = nn.Pad1D(padding=pad, mode="constant", value=value)
pad_constant_int = nn.Pad1D(
padding=pad_int, mode="constant", value=value)
pad_circular = nn.Pad1D(padding=pad, mode="circular") pad_circular = nn.Pad1D(padding=pad, mode="circular")
data = paddle.to_tensor(input_data) data = paddle.to_tensor(input_data)
...@@ -492,6 +495,14 @@ class TestPad1dAPI(unittest.TestCase): ...@@ -492,6 +495,14 @@ class TestPad1dAPI(unittest.TestCase):
input_data, pad, "constant", value=value, data_format="NCL") input_data, pad, "constant", value=value, data_format="NCL")
self.assertTrue(np.allclose(output.numpy(), np_out)) self.assertTrue(np.allclose(output.numpy(), np_out))
output = pad_constant_int(data)
np_out = self._get_numpy_out(
input_data, [pad_int] * 2,
"constant",
value=value,
data_format="NCL")
self.assertTrue(np.allclose(output.numpy(), np_out))
output = pad_circular(data) output = pad_circular(data)
np_out = self._get_numpy_out( np_out = self._get_numpy_out(
input_data, pad, "circular", value=value, data_format="NCL") input_data, pad, "circular", value=value, data_format="NCL")
...@@ -541,12 +552,15 @@ class TestPad2dAPI(unittest.TestCase): ...@@ -541,12 +552,15 @@ class TestPad2dAPI(unittest.TestCase):
for place in self.places: for place in self.places:
input_shape = (3, 4, 5, 6) input_shape = (3, 4, 5, 6)
pad = [1, 2, 2, 1] pad = [1, 2, 2, 1]
pad_int = 1
value = 100 value = 100
input_data = np.random.rand(*input_shape).astype(np.float32) input_data = np.random.rand(*input_shape).astype(np.float32)
pad_reflection = nn.Pad2D(padding=pad, mode="reflect") pad_reflection = nn.Pad2D(padding=pad, mode="reflect")
pad_replication = nn.Pad2D(padding=pad, mode="replicate") pad_replication = nn.Pad2D(padding=pad, mode="replicate")
pad_constant = nn.Pad2D(padding=pad, mode="constant", value=value) pad_constant = nn.Pad2D(padding=pad, mode="constant", value=value)
pad_constant_int = nn.Pad2D(
padding=pad_int, mode="constant", value=value)
pad_circular = nn.Pad2D(padding=pad, mode="circular") pad_circular = nn.Pad2D(padding=pad, mode="circular")
data = paddle.to_tensor(input_data) data = paddle.to_tensor(input_data)
...@@ -566,6 +580,14 @@ class TestPad2dAPI(unittest.TestCase): ...@@ -566,6 +580,14 @@ class TestPad2dAPI(unittest.TestCase):
input_data, pad, "constant", value=value, data_format="NCHW") input_data, pad, "constant", value=value, data_format="NCHW")
self.assertTrue(np.allclose(output.numpy(), np_out)) self.assertTrue(np.allclose(output.numpy(), np_out))
output = pad_constant_int(data)
np_out = self._get_numpy_out(
input_data, [pad_int] * 4,
"constant",
value=value,
data_format="NCHW")
self.assertTrue(np.allclose(output.numpy(), np_out))
output = pad_circular(data) output = pad_circular(data)
np_out = self._get_numpy_out( np_out = self._get_numpy_out(
input_data, pad, "circular", data_format="NCHW") input_data, pad, "circular", data_format="NCHW")
...@@ -617,12 +639,15 @@ class TestPad3dAPI(unittest.TestCase): ...@@ -617,12 +639,15 @@ class TestPad3dAPI(unittest.TestCase):
for place in self.places: for place in self.places:
input_shape = (3, 4, 5, 6, 7) input_shape = (3, 4, 5, 6, 7)
pad = [1, 2, 2, 1, 1, 0] pad = [1, 2, 2, 1, 1, 0]
pad_int = 1
value = 100 value = 100
input_data = np.random.rand(*input_shape).astype(np.float32) input_data = np.random.rand(*input_shape).astype(np.float32)
pad_reflection = nn.Pad3D(padding=pad, mode="reflect") pad_reflection = nn.Pad3D(padding=pad, mode="reflect")
pad_replication = nn.Pad3D(padding=pad, mode="replicate") pad_replication = nn.Pad3D(padding=pad, mode="replicate")
pad_constant = nn.Pad3D(padding=pad, mode="constant", value=value) pad_constant = nn.Pad3D(padding=pad, mode="constant", value=value)
pad_constant_int = nn.Pad3D(
padding=pad_int, mode="constant", value=value)
pad_circular = nn.Pad3D(padding=pad, mode="circular") pad_circular = nn.Pad3D(padding=pad, mode="circular")
data = paddle.to_tensor(input_data) data = paddle.to_tensor(input_data)
...@@ -642,6 +667,14 @@ class TestPad3dAPI(unittest.TestCase): ...@@ -642,6 +667,14 @@ class TestPad3dAPI(unittest.TestCase):
input_data, pad, "constant", value=value, data_format="NCDHW") input_data, pad, "constant", value=value, data_format="NCDHW")
self.assertTrue(np.allclose(output.numpy(), np_out)) self.assertTrue(np.allclose(output.numpy(), np_out))
output = pad_constant_int(data)
np_out = self._get_numpy_out(
input_data, [pad_int] * 6,
"constant",
value=value,
data_format="NCDHW")
self.assertTrue(np.allclose(output.numpy(), np_out))
output = pad_circular(data) output = pad_circular(data)
np_out = self._get_numpy_out( np_out = self._get_numpy_out(
input_data, pad, "circular", data_format="NCDHW") input_data, pad, "circular", data_format="NCDHW")
......
...@@ -38,6 +38,13 @@ __all__ = [ ...@@ -38,6 +38,13 @@ __all__ = [
] ]
def _npairs(x, n):
if isinstance(x, (paddle.Tensor, list)):
return x
x = [x] * (n * 2)
return x
class Linear(layers.Layer): class Linear(layers.Layer):
r""" r"""
...@@ -915,7 +922,8 @@ class Pad1D(layers.Layer): ...@@ -915,7 +922,8 @@ class Pad1D(layers.Layer):
If mode is 'reflect', pad[0] and pad[1] must be no greater than width-1. If mode is 'reflect', pad[0] and pad[1] must be no greater than width-1.
Parameters: Parameters:
padding (Tensor | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions padding (Tensor | List[int] | int): The padding size with data type int. If is int, use the
same padding in both dimensions. Else [len(padding)/2] dimensions
of input will be padded. The pad has the form (pad_left, pad_right). of input will be padded. The pad has the form (pad_left, pad_right).
mode (str): Four modes: 'constant' (default), 'reflect', 'replicate', 'circular'. mode (str): Four modes: 'constant' (default), 'reflect', 'replicate', 'circular'.
When in 'constant' mode, this op uses a constant value to pad the input tensor. When in 'constant' mode, this op uses a constant value to pad the input tensor.
...@@ -968,7 +976,7 @@ class Pad1D(layers.Layer): ...@@ -968,7 +976,7 @@ class Pad1D(layers.Layer):
data_format="NCL", data_format="NCL",
name=None): name=None):
super(Pad1D, self).__init__() super(Pad1D, self).__init__()
self._pad = padding self._pad = _npairs(padding, 1)
self._mode = mode self._mode = mode
self._value = value self._value = value
self._data_format = data_format self._data_format = data_format
...@@ -996,8 +1004,9 @@ class Pad2D(layers.Layer): ...@@ -996,8 +1004,9 @@ class Pad2D(layers.Layer):
than width-1. The height dimension has the same condition. than width-1. The height dimension has the same condition.
Parameters: Parameters:
padding (Tensor | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions padding (Tensor | List[int] | int): The padding size with data type int. If is int, use the
of input will be padded. The pad has the form (pad_left, pad_right, pad_top, pad_bottom). same padding in all dimensions. Else [len(padding)/2] dimensions of input will be padded.
The pad has the form (pad_left, pad_right, pad_top, pad_bottom).
mode (str): Four modes: 'constant' (default), 'reflect', 'replicate', 'circular'. mode (str): Four modes: 'constant' (default), 'reflect', 'replicate', 'circular'.
When in 'constant' mode, this op uses a constant value to pad the input tensor. When in 'constant' mode, this op uses a constant value to pad the input tensor.
When in 'reflect' mode, uses reflection of the input boundaries to pad the input tensor. When in 'reflect' mode, uses reflection of the input boundaries to pad the input tensor.
...@@ -1051,7 +1060,7 @@ class Pad2D(layers.Layer): ...@@ -1051,7 +1060,7 @@ class Pad2D(layers.Layer):
data_format="NCHW", data_format="NCHW",
name=None): name=None):
super(Pad2D, self).__init__() super(Pad2D, self).__init__()
self._pad = padding self._pad = _npairs(padding, 2)
self._mode = mode self._mode = mode
self._value = value self._value = value
self._data_format = data_format self._data_format = data_format
...@@ -1079,7 +1088,8 @@ class Pad3D(layers.Layer): ...@@ -1079,7 +1088,8 @@ class Pad3D(layers.Layer):
than width-1. The height and depth dimension has the same condition. than width-1. The height and depth dimension has the same condition.
Parameters: Parameters:
padding (Tensor | List[int32]): The padding size with data type int32. [len(padding)/2] dimensions padding (Tensor | List[int] | int): The padding size with data type int. If is int, use the
same padding in all dimensions. Else [len(padding)/2] dimensions
of input will be padded. The pad has the form (pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back). of input will be padded. The pad has the form (pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back).
mode (str): Four modes: 'constant' (default), 'reflect', 'replicate', 'circular'. mode (str): Four modes: 'constant' (default), 'reflect', 'replicate', 'circular'.
When in 'constant' mode, this op uses a constant value to pad the input tensor. When in 'constant' mode, this op uses a constant value to pad the input tensor.
...@@ -1134,7 +1144,7 @@ class Pad3D(layers.Layer): ...@@ -1134,7 +1144,7 @@ class Pad3D(layers.Layer):
data_format="NCDHW", data_format="NCDHW",
name=None): name=None):
super(Pad3D, self).__init__() super(Pad3D, self).__init__()
self._pad = padding self._pad = _npairs(padding, 3)
self._mode = mode self._mode = mode
self._value = value self._value = value
self._data_format = data_format self._data_format = data_format
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册