提交 7b4cb655 编写于 作者: L liym27 提交者: 石晓伟

keep the size of symmetric padding is 2 for 2d and 3 for 3d. test=develop (#20903)

上级 8d1e9f0f
...@@ -2745,9 +2745,11 @@ def conv2d(input, ...@@ -2745,9 +2745,11 @@ def conv2d(input,
padding = padding[1:3] padding = padding[1:3]
padding = [ele for a_list in padding for ele in a_list] padding = [ele for a_list in padding for ele in a_list]
padding = utils.convert_to_list(padding, 4, 'padding') padding = utils.convert_to_list(padding, 4, 'padding')
if utils._is_symmetric_padding(padding, 2):
padding = [padding[0], padding[2]]
else: else:
padding = utils.convert_to_list(padding, 2, 'padding') padding = utils.convert_to_list(padding, 2, 'padding')
padding = [padding[0], padding[0], padding[1], padding[1]]
return padding return padding
...@@ -2760,10 +2762,10 @@ def conv2d(input, ...@@ -2760,10 +2762,10 @@ def conv2d(input,
str(padding)) str(padding))
if padding == "VALID": if padding == "VALID":
padding_algorithm = "VALID" padding_algorithm = "VALID"
padding = [0, 0, 0, 0] padding = [0, 0]
elif padding == "SAME": elif padding == "SAME":
padding_algorithm = "SAME" padding_algorithm = "SAME"
padding = [0, 0, 0, 0] padding = [0, 0]
padding = _update_padding(padding, data_format) padding = _update_padding(padding, data_format)
...@@ -2989,15 +2991,14 @@ def conv3d(input, ...@@ -2989,15 +2991,14 @@ def conv3d(input,
padding = padding[1:4] padding = padding[1:4]
padding = [ele for a_list in padding for ele in a_list] padding = [ele for a_list in padding for ele in a_list]
padding = utils.convert_to_list(padding, 6, 'padding') padding = utils.convert_to_list(padding, 6, 'padding')
if utils._is_symmetric_padding(padding, 3):
padding = [padding[0], padding[2], padding[4]]
elif is_list_or_tuple(padding) and len(padding) == 6: elif is_list_or_tuple(padding) and len(padding) == 6:
padding = utils.convert_to_list(padding, 6, 'padding') padding = utils.convert_to_list(padding, 6, 'padding')
if utils._is_symmetric_padding(padding, 3):
padding = [padding[0], padding[2], padding[4]]
else: else:
padding = utils.convert_to_list(padding, 3, 'padding') padding = utils.convert_to_list(padding, 3, 'padding')
padding = [
padding[0], padding[0], padding[1], padding[1], padding[2],
padding[2]
]
return padding return padding
...@@ -3010,10 +3011,10 @@ def conv3d(input, ...@@ -3010,10 +3011,10 @@ def conv3d(input,
str(padding)) str(padding))
if padding == "VALID": if padding == "VALID":
padding_algorithm = "VALID" padding_algorithm = "VALID"
padding = [0, 0, 0, 0, 0, 0] padding = [0, 0, 0]
elif padding == "SAME": elif padding == "SAME":
padding_algorithm = "SAME" padding_algorithm = "SAME"
padding = [0, 0, 0, 0, 0, 0] padding = [0, 0, 0]
padding = _update_padding(padding, data_format) padding = _update_padding(padding, data_format)
...@@ -3556,6 +3557,8 @@ def pool2d(input, ...@@ -3556,6 +3557,8 @@ def pool2d(input,
padding = [ele for a_list in padding for ele in a_list] padding = [ele for a_list in padding for ele in a_list]
padding = utils.convert_to_list(padding, 4, 'padding') padding = utils.convert_to_list(padding, 4, 'padding')
if utils._is_symmetric_padding(padding, 2):
padding = [padding[0], padding[2]]
else: else:
padding = utils.convert_to_list(padding, 2, 'padding') padding = utils.convert_to_list(padding, 2, 'padding')
...@@ -3570,14 +3573,14 @@ def pool2d(input, ...@@ -3570,14 +3573,14 @@ def pool2d(input,
% str(pool_padding)) % str(pool_padding))
if pool_padding == "VALID": if pool_padding == "VALID":
padding_algorithm = "VALID" padding_algorithm = "VALID"
pool_padding = [0, 0, 0, 0] pool_padding = [0, 0]
if ceil_mode != False: if ceil_mode != False:
raise ValueError( raise ValueError(
"When Attr(pool_padding) is \"VALID\", Attr(ceil_mode) must be False. " "When Attr(pool_padding) is \"VALID\", Attr(ceil_mode) must be False. "
"Received ceil_mode: True.") "Received ceil_mode: True.")
elif pool_padding == "SAME": elif pool_padding == "SAME":
padding_algorithm = "SAME" padding_algorithm = "SAME"
pool_padding = [0, 0, 0, 0] pool_padding = [0, 0]
pool_padding = update_padding(pool_padding, data_format) pool_padding = update_padding(pool_padding, data_format)
...@@ -3760,10 +3763,13 @@ def pool3d(input, ...@@ -3760,10 +3763,13 @@ def pool3d(input,
padding = padding[1:4] padding = padding[1:4]
padding = [ele for a_list in padding for ele in a_list] padding = [ele for a_list in padding for ele in a_list]
padding = utils.convert_to_list(padding, 6, 'padding') padding = utils.convert_to_list(padding, 6, 'padding')
if utils._is_symmetric_padding(padding, 3):
padding = [padding[0], padding[2], padding[4]]
elif is_list_or_tuple(padding) and len(padding) == 6: elif is_list_or_tuple(padding) and len(padding) == 6:
padding = utils.convert_to_list(padding, 6, 'padding') padding = utils.convert_to_list(padding, 6, 'padding')
if utils._is_symmetric_padding(padding, 3):
padding = [padding[0], padding[2], padding[4]]
else: else:
padding = utils.convert_to_list(padding, 3, 'padding') padding = utils.convert_to_list(padding, 3, 'padding')
...@@ -3778,14 +3784,14 @@ def pool3d(input, ...@@ -3778,14 +3784,14 @@ def pool3d(input,
% str(pool_padding)) % str(pool_padding))
if pool_padding == "VALID": if pool_padding == "VALID":
padding_algorithm = "VALID" padding_algorithm = "VALID"
pool_padding = [0, 0, 0, 0, 0, 0] pool_padding = [0, 0, 0]
if ceil_mode != False: if ceil_mode != False:
raise ValueError( raise ValueError(
"When Attr(pool_padding) is \"VALID\", ceil_mode must be False. " "When Attr(pool_padding) is \"VALID\", ceil_mode must be False. "
"Received ceil_mode: True.") "Received ceil_mode: True.")
elif pool_padding == "SAME": elif pool_padding == "SAME":
padding_algorithm = "SAME" padding_algorithm = "SAME"
pool_padding = [0, 0, 0, 0, 0, 0] pool_padding = [0, 0, 0]
pool_padding = update_padding(pool_padding, data_format) pool_padding = update_padding(pool_padding, data_format)
...@@ -5125,6 +5131,9 @@ def conv2d_transpose(input, ...@@ -5125,6 +5131,9 @@ def conv2d_transpose(input,
filter_size = utils.convert_to_list(filter_size, 2, filter_size = utils.convert_to_list(filter_size, 2,
'conv2d_transpose.filter_size') 'conv2d_transpose.filter_size')
if len(padding) == 4 and utils._is_symmetric_padding(padding, 2):
padding = [padding[0], padding[2]]
if output_size is None: if output_size is None:
output_size = [] output_size = []
elif isinstance(output_size, list) or isinstance(output_size, int): elif isinstance(output_size, list) or isinstance(output_size, int):
...@@ -5360,13 +5369,13 @@ def conv3d_transpose(input, ...@@ -5360,13 +5369,13 @@ def conv3d_transpose(input,
elif is_list_or_tuple(padding) and len(padding) == 6: elif is_list_or_tuple(padding) and len(padding) == 6:
padding = utils.convert_to_list(padding, 6, 'padding') padding = utils.convert_to_list(padding, 6, 'padding')
else: else:
padding = utils.convert_to_list(padding, 3, 'padding') padding = utils.convert_to_list(padding, 3, 'padding')
padding = [ padding = [
padding[0], padding[0], padding[1], padding[1], padding[2], padding[0], padding[0], padding[1], padding[1], padding[2],
padding[2] padding[2]
] ]
return padding return padding
padding_algorithm = "EXPLICIT" padding_algorithm = "EXPLICIT"
...@@ -5406,6 +5415,9 @@ def conv3d_transpose(input, ...@@ -5406,6 +5415,9 @@ def conv3d_transpose(input,
filter_size = utils.convert_to_list(filter_size, 3, filter_size = utils.convert_to_list(filter_size, 3,
'conv3d_transpose.filter_size') 'conv3d_transpose.filter_size')
if len(padding) == 6 and utils._is_symmetric_padding(padding, 3):
padding = [padding[0], padding[2], padding[4]]
groups = 1 if groups is None else groups groups = 1 if groups is None else groups
filter_shape = [input_channel, num_filters // groups] + filter_size filter_shape = [input_channel, num_filters // groups] + filter_size
img_filter = helper.create_parameter( img_filter = helper.create_parameter(
......
...@@ -231,3 +231,16 @@ def assert_same_structure(nest1, nest2, check_types=True): ...@@ -231,3 +231,16 @@ def assert_same_structure(nest1, nest2, check_types=True):
"Second structure (%i elements): %s" % "Second structure (%i elements): %s" %
(len_nest1, nest1, len_nest2, nest2)) (len_nest1, nest1, len_nest2, nest2))
_recursive_assert_same_structure(nest1, nest2, check_types) _recursive_assert_same_structure(nest1, nest2, check_types)
def _is_symmetric_padding(padding, data_dim):
"""
Check whether padding is symmetrical.
"""
assert len(padding) == data_dim * 2 or len(padding) == data_dim
is_sys = True
if len(padding) == data_dim * 2:
for i in range(data_dim):
if padding[i * 2] != padding[i * 2 + 1]:
is_sys = False
return is_sys
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册