未验证 提交 a34ecad8 编写于 作者: L LoneRanger 提交者: GitHub

Fix Python IndexError of case19: paddle.nn.functional.conv2d_transpose (#50006)

上级 f76a7c5a
...@@ -1013,6 +1013,17 @@ class TestConv2DTransposeRepr(unittest.TestCase): ...@@ -1013,6 +1013,17 @@ class TestConv2DTransposeRepr(unittest.TestCase):
paddle.enable_static() paddle.enable_static()
class TestConv2dTranspose(unittest.TestCase):
def error_weight_input(self):
array = np.array([1], dtype=np.float32)
x = paddle.to_tensor(np.reshape(array, [1, 1, 1, 1]), dtype='float32')
weight = paddle.to_tensor(np.reshape(array, [1]), dtype='float32')
paddle.nn.functional.conv2d_transpose(x, weight, bias=0)
def test_type_error(self):
self.assertRaises(ValueError, self.error_weight_input)
class TestTensorOutputSize1(UnittestBase): class TestTensorOutputSize1(UnittestBase):
def init_info(self): def init_info(self):
self.shapes = [[2, 3, 8, 8]] self.shapes = [[2, 3, 8, 8]]
......
...@@ -1204,6 +1204,12 @@ def conv2d_transpose( ...@@ -1204,6 +1204,12 @@ def conv2d_transpose(
x.shape x.shape
) )
) )
if len(weight.shape) != 4:
raise ValueError(
"Input weight should be 4D tensor, but received weight with the shape of {}".format(
weight.shape
)
)
num_channels = x.shape[channel_dim] num_channels = x.shape[channel_dim]
if num_channels < 0: if num_channels < 0:
raise ValueError( raise ValueError(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册