未验证 提交 52672ea5 编写于 作者: R RedContritio 提交者: GitHub

Fix Python IndexError of case17: paddle.nn.functional.interpolate (#49992)

* add dimension check for interpolate

* modify dimension check for interpolate

* add unittest to size check for interpolate

* fix incorrect shape check for interpolate

* split size check and add unittests
上级 dbfdefa7
...@@ -390,6 +390,12 @@ class TestBicubicOpError(unittest.TestCase): ...@@ -390,6 +390,12 @@ class TestBicubicOpError(unittest.TestCase):
x, size=[12, 12], mode='BICUBIC', align_corners=False x, size=[12, 12], mode='BICUBIC', align_corners=False
) )
def test_size_shape():
x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32")
out = interpolate(
x, size=[12], mode='BICUBIC', align_corners=False
)
def test_align_corcers(): def test_align_corcers():
x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32") x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32")
interpolate(x, size=[12, 12], mode='BICUBIC', align_corners=3) interpolate(x, size=[12, 12], mode='BICUBIC', align_corners=3)
...@@ -481,6 +487,7 @@ class TestBicubicOpError(unittest.TestCase): ...@@ -481,6 +487,7 @@ class TestBicubicOpError(unittest.TestCase):
self.assertRaises(ValueError, test_mode_type) self.assertRaises(ValueError, test_mode_type)
self.assertRaises(ValueError, test_input_shape) self.assertRaises(ValueError, test_input_shape)
self.assertRaises(ValueError, test_size_shape)
self.assertRaises(TypeError, test_align_corcers) self.assertRaises(TypeError, test_align_corcers)
self.assertRaises(ValueError, test_attr_data_format) self.assertRaises(ValueError, test_attr_data_format)
self.assertRaises(TypeError, test_actual_shape) self.assertRaises(TypeError, test_actual_shape)
......
...@@ -610,6 +610,20 @@ class TestBicubicOpError(unittest.TestCase): ...@@ -610,6 +610,20 @@ class TestBicubicOpError(unittest.TestCase):
x, size={2, 2}, mode='bicubic', align_corners=False x, size={2, 2}, mode='bicubic', align_corners=False
) )
def test_size_length():
x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32")
out = interpolate(x, size=[2], mode='bicubic', align_corners=False)
def test_size_tensor_ndim():
x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32")
size = paddle.to_tensor(np.array([[2, 2]]))
out = interpolate(x, size=size, mode='bicubic', align_corners=False)
def test_size_tensor_length():
x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32")
size = paddle.to_tensor(np.array([2]))
out = interpolate(x, size=size, mode='bicubic', align_corners=False)
def test_input_shape_1(): def test_input_shape_1():
x = fluid.data(name="x", shape=[2, 1, 0, 0], dtype="float32") x = fluid.data(name="x", shape=[2, 1, 0, 0], dtype="float32")
out = interpolate( out = interpolate(
...@@ -633,6 +647,9 @@ class TestBicubicOpError(unittest.TestCase): ...@@ -633,6 +647,9 @@ class TestBicubicOpError(unittest.TestCase):
self.assertRaises(ValueError, test_size_and_scale) self.assertRaises(ValueError, test_size_and_scale)
self.assertRaises(ValueError, test_size_and_scale2) self.assertRaises(ValueError, test_size_and_scale2)
self.assertRaises(TypeError, test_size_type) self.assertRaises(TypeError, test_size_type)
self.assertRaises(ValueError, test_size_length)
self.assertRaises(ValueError, test_size_tensor_ndim)
self.assertRaises(ValueError, test_size_tensor_length)
self.assertRaises(ValueError, test_input_shape_1) self.assertRaises(ValueError, test_input_shape_1)
def test_errors(self): def test_errors(self):
......
...@@ -397,6 +397,23 @@ def interpolate( ...@@ -397,6 +397,23 @@ def interpolate(
if size is None and scale_factor is None: if size is None and scale_factor is None:
raise ValueError("One of size and scale_factor must not be None.") raise ValueError("One of size and scale_factor must not be None.")
if (isinstance(size, list) or isinstance(size, tuple)) and len(
size
) != x.ndim - 2:
raise ValueError(
'The x and size should satisfy rank(x) - 2 == len(size).'
)
if isinstance(size, Variable):
if size.ndim != 1:
raise ValueError(
f"If size is a tensor, it's rank must be 1, but received {size.ndim}."
)
if size.shape[0] != x.ndim - 2:
raise ValueError(
'The x and size should satisfy rank(x) - 2 == size.shape[0].'
)
if not isinstance(align_corners, bool): if not isinstance(align_corners, bool):
raise TypeError("Attr align_corners should be a bool value") raise TypeError("Attr align_corners should be a bool value")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册