提交 58c4a4cb 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Bugfix: number of input channels is not necessarily in the last dimension,...

Bugfix: number of input channels is not necessarily in the last dimension, after introduction of data_format param.

PiperOrigin-RevId: 164889729
上级 8f9b1af8
......@@ -2548,7 +2548,8 @@ def separable_convolution2d(
dtype = inputs.dtype.base_dtype
kernel_h, kernel_w = utils.two_element_tuple(kernel_size)
stride_h, stride_w = utils.two_element_tuple(stride)
num_filters_in = utils.last_dimension(inputs.get_shape(), min_rank=4)
num_filters_in = utils.channel_dimension(
inputs.get_shape(), df, min_rank=4)
weights_collections = utils.get_variable_collections(
variables_collections, 'weights')
......
......@@ -3230,12 +3230,13 @@ class SeparableConv2dTest(test.TestCase):
def testConvNCHW(self):
for num_filters, correct_output_filters in [(None, 6), (8, 8)]:
with self.test_session():
height, width = 3, 3
images = random_ops.random_uniform((5, 3, height, width), seed=1)
batch, height, width = 4, 5, 6
images = random_ops.random_uniform((batch, 3, height, width), seed=1)
output = layers_lib.separable_conv2d(
images, num_filters, [3, 3], 2, padding='VALID', data_format='NCHW')
self.assertListEqual(
output.get_shape().as_list(), [5, correct_output_filters, 1, 1])
output.get_shape().as_list(), [batch, correct_output_filters,
height - 2, width - 2])
class ScaleGradientTests(test.TestCase):
......
......@@ -33,8 +33,8 @@ __all__ = ['collect_named_outputs',
'get_variable_collections',
'two_element_tuple',
'n_positive_integers',
'last_dimension',
'first_dimension']
'channel_dimension',
'last_dimension']
NamedOutputs = namedtuple('NamedOutputs', ['name', 'outputs'])
......@@ -220,15 +220,16 @@ def get_variable_collections(variables_collections, name):
return variable_collections
def first_dimension(shape, min_rank=1):
"""Returns the first dimension of shape while checking it has min_rank.
def _get_dimension(shape, dim, min_rank=1):
"""Returns the `dim` dimension of `shape`, while checking it has `min_rank`.
Args:
shape: A `TensorShape`.
dim: Integer, which dimension to return.
min_rank: Integer, minimum rank of shape.
Returns:
The value of the first dimension.
The value of the `dim` dimension.
Raises:
ValueError: if inputs don't have at least min_rank dimensions, or if the
......@@ -240,12 +241,32 @@ def first_dimension(shape, min_rank=1):
if len(dims) < min_rank:
raise ValueError('rank of shape must be at least %d not: %d' % (min_rank,
len(dims)))
value = dims[0].value
value = dims[dim].value
if value is None:
raise ValueError('first dimension shape must be known but is None')
raise ValueError(
'dimension %d of shape must be known but is None: %s' % (dim, shape))
return value
def channel_dimension(shape, data_format, min_rank=1):
"""Returns the channel dimension of shape, while checking it has min_rank.
Args:
shape: A `TensorShape`.
data_format: `channels_first` or `channels_last`.
min_rank: Integer, minimum rank of shape.
Returns:
The value of the first dimension.
Raises:
ValueError: if inputs don't have at least min_rank dimensions, or if the
first dimension value is not defined.
"""
return _get_dimension(shape, 1 if data_format == 'channels_first' else -1,
min_rank=min_rank)
def last_dimension(shape, min_rank=1):
"""Returns the last dimension of shape while checking it has min_rank.
......@@ -260,16 +281,7 @@ def last_dimension(shape, min_rank=1):
ValueError: if inputs don't have at least min_rank dimensions, or if the
last dimension value is not defined.
"""
dims = shape.dims
if dims is None:
raise ValueError('dims of shape must be known but is None')
if len(dims) < min_rank:
raise ValueError('rank of shape must be at least %d not: %d' % (min_rank,
len(dims)))
value = dims[-1].value
if value is None:
raise ValueError('last dimension shape must be known but is None')
return value
return _get_dimension(shape, -1, min_rank=min_rank)
def two_element_tuple(int_or_tuple):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册