未验证 提交 82f0b5ea 编写于 作者: L littletomatodonkey 提交者: GitHub

adapt pad const (#28585)

* adapt pad const

* fix comment and rm fluid import

* rm stdout

* fix note
上级 bf143652
......@@ -251,7 +251,9 @@ class TestPadAPI(unittest.TestCase):
mode,
value=0,
data_format="NCDHW"):
if data_format == "NCDHW":
if mode == "constant" and len(pad) == len(input_data.shape) * 2:
pad = np.reshape(pad, (-1, 2)).tolist()
elif data_format == "NCDHW":
pad = [
(0, 0),
(0, 0),
......@@ -316,6 +318,7 @@ class TestPadAPI(unittest.TestCase):
paddle.disable_static()
input_shape = (1, 2, 3, 4, 5)
pad = [1, 2, 1, 1, 3, 4]
pad_3 = [1, 2, 1, 1, 3, 4, 5, 6, 7, 8]
mode = "constant"
value = 100
input_data = np.random.rand(*input_shape).astype(np.float32)
......@@ -323,6 +326,8 @@ class TestPadAPI(unittest.TestCase):
input_data, pad, mode, value, data_format="NCDHW")
np_out2 = self._get_numpy_out(
input_data, pad, mode, value, data_format="NDHWC")
np_out3 = self._get_numpy_out(
input_data, pad_3, mode, value, data_format="NCDHW")
tensor_data = paddle.to_tensor(input_data)
y1 = F.pad(tensor_data,
......@@ -335,14 +340,21 @@ class TestPadAPI(unittest.TestCase):
mode=mode,
value=value,
data_format="NDHWC")
y3 = F.pad(tensor_data,
pad=pad_3,
mode=mode,
value=value,
data_format="NCDHW")
self.assertTrue(np.allclose(y1.numpy(), np_out1))
self.assertTrue(np.allclose(y2.numpy(), np_out2))
self.assertTrue(np.allclose(y3.numpy(), np_out3))
def test_dygraph_2(self):
paddle.disable_static()
input_shape = (2, 3, 4, 5)
pad = [1, 1, 3, 4]
pad_3 = [1, 2, 1, 1, 3, 4, 5, 6]
mode = "constant"
value = 100
input_data = np.random.rand(*input_shape).astype(np.float32)
......@@ -350,6 +362,8 @@ class TestPadAPI(unittest.TestCase):
input_data, pad, mode, value, data_format="NCHW")
np_out2 = self._get_numpy_out(
input_data, pad, mode, value, data_format="NHWC")
np_out3 = self._get_numpy_out(
input_data, pad_3, mode, value, data_format="NCHW")
tensor_data = paddle.to_tensor(input_data)
tensor_pad = paddle.to_tensor(pad, dtype="int32")
......@@ -364,14 +378,21 @@ class TestPadAPI(unittest.TestCase):
mode=mode,
value=value,
data_format="NHWC")
y3 = F.pad(tensor_data,
pad=pad_3,
mode=mode,
value=value,
data_format="NCHW")
self.assertTrue(np.allclose(y1.numpy(), np_out1))
self.assertTrue(np.allclose(y2.numpy(), np_out2))
self.assertTrue(np.allclose(y3.numpy(), np_out3))
def test_dygraph_3(self):
paddle.disable_static()
input_shape = (3, 4, 5)
pad = [3, 4]
pad_3 = [3, 4, 5, 6, 7, 8]
mode = "constant"
value = 100
input_data = np.random.rand(*input_shape).astype(np.float32)
......@@ -379,6 +400,8 @@ class TestPadAPI(unittest.TestCase):
input_data, pad, mode, value, data_format="NCL")
np_out2 = self._get_numpy_out(
input_data, pad, mode, value, data_format="NLC")
np_out3 = self._get_numpy_out(
input_data, pad_3, mode, value, data_format="NCL")
tensor_data = paddle.to_tensor(input_data)
tensor_pad = paddle.to_tensor(pad, dtype="int32")
......@@ -392,9 +415,15 @@ class TestPadAPI(unittest.TestCase):
mode=mode,
value=value,
data_format="NLC")
y3 = F.pad(tensor_data,
pad=pad_3,
mode=mode,
value=value,
data_format="NCL")
self.assertTrue(np.allclose(y1.numpy(), np_out1))
self.assertTrue(np.allclose(y2.numpy(), np_out2))
self.assertTrue(np.allclose(y3.numpy(), np_out3))
class TestPad1dAPI(unittest.TestCase):
......
......@@ -1158,6 +1158,9 @@ def alpha_dropout(x, p=0.5, training=True, name=None):
def pad(x, pad, mode='constant', value=0, data_format="NCHW", name=None):
"""
Pad tensor according to 'pad' and 'mode'.
If mode is 'constant' and length of pad is twice as length of x dimension,
then the padding will be started from the first dimension and moved back onto x
according to 'pad' and 'value'.
If mode is 'reflect', pad[0] and pad[1] must be no greater
than width-1. The height and depth dimension has the same condition.
......@@ -1273,6 +1276,9 @@ def pad(x, pad, mode='constant', value=0, data_format="NCHW", name=None):
unsqueezed_dim = []
if mode == "constant" and isinstance(pad, list) and len(pad) == x_dim * 2:
return layers.pad(x, pad, pad_value=value)
if isinstance(pad, Variable):
if data_format in ["NCL", "NCHW", "NCDHW"]:
data_format = "NCDHW"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册