提交 25a8973d 编写于 作者: T Taehoon Lee 提交者: François Chollet

Increase test coverages by factorizing CNTK pads (#10259)

上级 abd02940
......@@ -2025,20 +2025,8 @@ def function(inputs, outputs, updates=[], **kwargs):
def temporal_padding(x, padding=(1, 1)):
assert len(padding) == 2
num_dynamic_axis = _get_dynamic_axis_num(x)
base_shape = x.shape
if num_dynamic_axis > 0:
assert len(base_shape) == 2
if hasattr(C, 'pad'):
x = C.pad(x, pattern=[padding, (0, 0)])
else:
x = _padding(x, padding, 0)
else:
assert len(base_shape) == 3
if hasattr(C, 'pad'):
x = C.pad(x, pattern=[(0, 0), padding, (0, 0)])
else:
x = _padding(x, padding, 1)
return x
assert len(x.shape) == 3 - (1 if num_dynamic_axis > 0 else 0)
return pad(x, [padding], 'channels_last', num_dynamic_axis)
def _padding(x, pattern, axis):
......@@ -2062,6 +2050,24 @@ def _padding(x, pattern, axis):
return x
def pad(x, pad_info, data_format, num_dynamic_axis):
if hasattr(C, 'pad'):
pattern = [list(p) for p in pad_info]
if data_format == 'channels_first':
pattern = [[0, 0]] + pattern
else:
pattern = pattern + [[0, 0]]
if num_dynamic_axis == 0:
pattern = [[0, 0]] + pattern
return C.pad(x, pattern=pattern)
else:
for (a, p) in enumerate(pad_info):
x = _padding(x, p,
a + (1 if num_dynamic_axis == 0 else 0) +
(1 if data_format == 'channels_first' else 0))
return x
def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
assert len(padding) == 2
assert len(padding[0]) == 2
......@@ -2072,38 +2078,8 @@ def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
raise ValueError('Unknown data_format ' + str(data_format))
num_dynamic_axis = _get_dynamic_axis_num(x)
base_shape = x.shape
if data_format == 'channels_first':
if num_dynamic_axis > 0:
assert len(base_shape) == 3
if hasattr(C, 'pad'):
x = C.pad(x, pattern=[[0, 0], list(padding[0]), list(padding[1])])
else:
x = _padding(x, padding[0], 1)
x = _padding(x, padding[1], 2)
else:
assert len(base_shape) == 4
if hasattr(C, 'pad'):
x = C.pad(x, pattern=[[0, 0], [0, 0], list(padding[0]), list(padding[1])])
else:
x = _padding(x, padding[0], 2)
x = _padding(x, padding[1], 3)
else:
if num_dynamic_axis > 0:
assert len(base_shape) == 3
if hasattr(C, 'pad'):
x = C.pad(x, pattern=[list(padding[0]), list(padding[1]), [0, 0]])
else:
x = _padding(x, padding[0], 0)
x = _padding(x, padding[1], 1)
else:
assert len(base_shape) == 4
if hasattr(C, 'pad'):
x = C.pad(x, pattern=[[0, 0], list(padding[0]), list(padding[1]), [0, 0]])
else:
x = _padding(x, padding[0], 1)
x = _padding(x, padding[1], 2)
return x
assert len(x.shape) == 4 - (1 if num_dynamic_axis > 0 else 0)
return pad(x, padding, data_format, num_dynamic_axis)
def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None):
......@@ -2117,42 +2093,8 @@ def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None):
raise ValueError('Unknown data_format ' + str(data_format))
num_dynamic_axis = _get_dynamic_axis_num(x)
base_shape = x.shape
if data_format == 'channels_first':
if num_dynamic_axis > 0:
assert len(base_shape) == 4
if hasattr(C, 'pad'):
x = C.pad(x, pattern=[[0, 0], list(padding[0]), list(padding[1]), list(padding[2])])
else:
x = _padding(x, padding[0], 1)
x = _padding(x, padding[1], 2)
x = _padding(x, padding[2], 3)
else:
assert len(base_shape) == 5
if hasattr(C, 'pad'):
x = C.pad(x, pattern=[[0, 0], [0, 0], list(padding[0]), list(padding[1]), list(padding[2])])
else:
x = _padding(x, padding[0], 2)
x = _padding(x, padding[1], 3)
x = _padding(x, padding[2], 4)
else:
if num_dynamic_axis > 0:
assert len(base_shape) == 4
if hasattr(C, 'pad'):
x = C.pad(x, pattern=[list(padding[0]), list(padding[1]), list(padding[2]), [0, 0]])
else:
x = _padding(x, padding[0], 0)
x = _padding(x, padding[1], 1)
x = _padding(x, padding[2], 2)
else:
assert len(base_shape) == 5
if hasattr(C, 'pad'):
x = C.pad(x, pattern=[[0, 0], list(padding[0]), list(padding[1]), list(padding[2]), [0, 0]])
else:
x = _padding(x, padding[0], 1)
x = _padding(x, padding[1], 2)
x = _padding(x, padding[2], 3)
return x
assert len(x.shape) == 5 - (1 if num_dynamic_axis > 0 else 0)
return pad(x, padding, data_format, num_dynamic_axis)
def one_hot(indices, num_classes):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册