未验证 提交 24c7087f 编写于 作者: W wangguanzhong 提交者: GitHub

fix index erro in conv2d_transpose (#34270)

上级 82339ed1
...@@ -3924,6 +3924,10 @@ def conv2d_transpose(input, ...@@ -3924,6 +3924,10 @@ def conv2d_transpose(input,
print(conv2d_transpose.shape) # [-1, 2, 34, 34] print(conv2d_transpose.shape) # [-1, 2, 34, 34]
""" """
assert param_attr is not False, "param_attr should not be False in conv2d_transpose." assert param_attr is not False, "param_attr should not be False in conv2d_transpose."
if len(input.shape) != 4:
raise ValueError("Input size should be 4, "
"but received {}".format(len(input.shape)))
if data_format not in ['NCHW', 'NHWC']: if data_format not in ['NCHW', 'NHWC']:
raise ValueError( raise ValueError(
"Attr(data_format) of Op(fluid.layers.conv2d_transpose) got wrong value: received " "Attr(data_format) of Op(fluid.layers.conv2d_transpose) got wrong value: received "
...@@ -4015,7 +4019,14 @@ def conv2d_transpose(input, ...@@ -4015,7 +4019,14 @@ def conv2d_transpose(input,
output_size = utils.convert_to_list(output_size, 2, 'output_size') output_size = utils.convert_to_list(output_size, 2, 'output_size')
else: else:
raise ValueError("output_size should be int, list[int] or tuple[int]") raise ValueError("output_size should be int, list[int] or tuple[int]")
groups = 1 if groups is None else groups
if groups is None:
groups = 1
elif groups <= 0:
raise ValueError("the groups of input must be greater than 0, "
"but received the groups of input is {}".format(
groups))
filter_shape = [input_channel, num_filters // groups] + filter_size filter_shape = [input_channel, num_filters // groups] + filter_size
img_filter = helper.create_parameter( img_filter = helper.create_parameter(
......
...@@ -898,6 +898,25 @@ class TestConv2DTransposeOpException(unittest.TestCase): ...@@ -898,6 +898,25 @@ class TestConv2DTransposeOpException(unittest.TestCase):
self.assertRaises(ValueError, attr_padding_with_data_format) self.assertRaises(ValueError, attr_padding_with_data_format)
error_input = fluid.layers.data(
name='error_data', shape=[1], dtype="float32")
def error_input_size():
out = fluid.layers.conv2d_transpose(
input=error_input, groups=1, num_filters=6, filter_size=3)
self.assertRaises(ValueError, error_input_size)
def error_groups():
out = fluid.layers.conv2d_transpose(
input=data,
groups=0,
num_filters=6,
filter_size=3,
data_format='NHWC')
self.assertRaises(ValueError, error_groups)
class TestConv2DTransposeRepr(unittest.TestCase): class TestConv2DTransposeRepr(unittest.TestCase):
def test_case(self): def test_case(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册