未验证 提交 52a6ca0c 编写于 作者: L littletomatodonkey 提交者: GitHub

test=develop, improve pad assertion error (#26748)

上级 c03092b1
......@@ -166,7 +166,11 @@ class TestPadAPI(unittest.TestCase):
value = 100
input_data = np.random.rand(*input_shape).astype(np.float32)
x = paddle.data(name="x", shape=input_shape)
result = F.pad(x=x, pad=pad, value=value, mode=mode)
result = F.pad(x=x,
pad=pad,
value=value,
mode=mode,
data_format="NCDHW")
exe = Executor(place)
fetches = exe.run(default_main_program(),
feed={"x": input_data},
......@@ -666,5 +670,44 @@ class TestPad3dOpError(unittest.TestCase):
self.assertRaises(Exception, test_reflect_3)
class TestPadDataformatError(unittest.TestCase):
def test_errors(self):
def test_ncl():
paddle.disable_static(paddle.CPUPlace())
input_shape = (1, 2, 3, 4)
pad = paddle.to_tensor(np.array([2, 1, 2, 1]).astype('int32'))
data = np.arange(
np.prod(input_shape), dtype=np.float64).reshape(input_shape) + 1
my_pad = nn.ReplicationPad1d(padding=pad, data_format="NCL")
data = paddle.to_tensor(data)
result = my_pad(data)
def test_nchw():
paddle.disable_static(paddle.CPUPlace())
input_shape = (1, 2, 4)
pad = paddle.to_tensor(np.array([2, 1, 2, 1]).astype('int32'))
data = np.arange(
np.prod(input_shape), dtype=np.float64).reshape(input_shape) + 1
my_pad = nn.ReplicationPad1d(padding=pad, data_format="NCHW")
data = paddle.to_tensor(data)
result = my_pad(data)
def test_ncdhw():
paddle.disable_static(paddle.CPUPlace())
input_shape = (1, 2, 3, 4)
pad = paddle.to_tensor(np.array([2, 1, 2, 1]).astype('int32'))
data = np.arange(
np.prod(input_shape), dtype=np.float64).reshape(input_shape) + 1
my_pad = nn.ReplicationPad1d(padding=pad, data_format="NCDHW")
data = paddle.to_tensor(data)
result = my_pad(data)
self.assertRaises(AssertionError, test_ncl)
self.assertRaises(AssertionError, test_nchw)
self.assertRaises(AssertionError, test_ncdhw)
if __name__ == '__main__':
unittest.main()
......@@ -1230,7 +1230,19 @@ def pad(x, pad, mode='constant', value=0, data_format="NCHW", name=None):
x_dim = len(x.shape)
original_data_format = data_format
assert x_dim in [
3, 4, 5
], "input tesor dimension must be in [3, 4, 5] but got {}".format(x_dim)
supported_format_map = {
3: ["NCL", "NLC"],
4: ["NCHW", "NHWC"],
5: ["NCDHW", "NDHWC"],
}
assert data_format in supported_format_map[x_dim], \
"input tensor dimension is {}, it's data format should be in {} but got {}".format(
x_dim, supported_format_map[x_dim], data_format)
unsqueezed_dim = []
if isinstance(pad, Variable):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册