diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 26306259bc2879c5343dddd1f28b04eb530b975a..0499b1155427151e841e47a0113cdf8630a6e34d 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -2548,7 +2548,8 @@ def separable_convolution2d(
     dtype = inputs.dtype.base_dtype
     kernel_h, kernel_w = utils.two_element_tuple(kernel_size)
     stride_h, stride_w = utils.two_element_tuple(stride)
-    num_filters_in = utils.last_dimension(inputs.get_shape(), min_rank=4)
+    num_filters_in = utils.channel_dimension(
+        inputs.get_shape(), df, min_rank=4)
     weights_collections = utils.get_variable_collections(
         variables_collections, 'weights')

diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index 4e19da4ec9de9804177092789d87301812faf056..2b6da3df26c0f70bf7fedf71d459ebf4ead9714a 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -3230,12 +3230,13 @@ class SeparableConv2dTest(test.TestCase):
   def testConvNCHW(self):
     for num_filters, correct_output_filters in [(None, 6), (8, 8)]:
       with self.test_session():
-        height, width = 3, 3
-        images = random_ops.random_uniform((5, 3, height, width), seed=1)
+        batch, height, width = 4, 5, 6
+        images = random_ops.random_uniform((batch, 3, height, width), seed=1)
         output = layers_lib.separable_conv2d(
             images, num_filters, [3, 3], 2, padding='VALID', data_format='NCHW')
         self.assertListEqual(
-            output.get_shape().as_list(), [5, correct_output_filters, 1, 1])
+            output.get_shape().as_list(), [batch, correct_output_filters,
+                                           height - 2, width - 2])


 class ScaleGradientTests(test.TestCase):
diff --git a/tensorflow/contrib/layers/python/layers/utils.py b/tensorflow/contrib/layers/python/layers/utils.py
index 9165cfa26767d746d67428550f276dbf2258ceb8..00d7c55a72cf53f009533a3112bcfbaba1dd0900 100644
--- a/tensorflow/contrib/layers/python/layers/utils.py
+++ b/tensorflow/contrib/layers/python/layers/utils.py
@@ -33,8 +33,8 @@ __all__ = ['collect_named_outputs',
            'get_variable_collections',
            'two_element_tuple',
            'n_positive_integers',
-           'last_dimension',
-           'first_dimension']
+           'channel_dimension',
+           'last_dimension']

 NamedOutputs = namedtuple('NamedOutputs', ['name', 'outputs'])

@@ -220,15 +220,16 @@ def get_variable_collections(variables_collections, name):
   return variable_collections


-def first_dimension(shape, min_rank=1):
-  """Returns the first dimension of shape while checking it has min_rank.
+def _get_dimension(shape, dim, min_rank=1):
+  """Returns the `dim` dimension of `shape`, while checking it has `min_rank`.

   Args:
     shape: A `TensorShape`.
+    dim: Integer, which dimension to return.
     min_rank: Integer, minimum rank of shape.

   Returns:
-    The value of the first dimension.
+    The value of the `dim` dimension.

   Raises:
     ValueError: if inputs don't have at least min_rank dimensions, or if the
@@ -240,12 +241,32 @@ def first_dimension(shape, min_rank=1):
   if len(dims) < min_rank:
     raise ValueError('rank of shape must be at least %d not: %d' % (min_rank,
                                                                     len(dims)))
-  value = dims[0].value
+  value = dims[dim].value
   if value is None:
-    raise ValueError('first dimension shape must be known but is None')
+    raise ValueError(
+        'dimension %d of shape must be known but is None: %s' % (dim, shape))
   return value


+def channel_dimension(shape, data_format, min_rank=1):
+  """Returns the channel dimension of shape, while checking it has min_rank.
+
+  Args:
+    shape: A `TensorShape`.
+    data_format: `channels_first` or `channels_last`.
+    min_rank: Integer, minimum rank of shape.
+
+  Returns:
+    The value of the channel dimension.
+
+  Raises:
+    ValueError: if inputs don't have at least min_rank dimensions, or if the
+      channel dimension value is not defined.
+  """
+  return _get_dimension(shape, 1 if data_format == 'channels_first' else -1,
+                        min_rank=min_rank)
+
+
 def last_dimension(shape, min_rank=1):
   """Returns the last dimension of shape while checking it has min_rank.

@@ -260,16 +281,7 @@ def last_dimension(shape, min_rank=1):
     ValueError: if inputs don't have at least min_rank dimensions, or if the
       last dimension value is not defined.
   """
-  dims = shape.dims
-  if dims is None:
-    raise ValueError('dims of shape must be known but is None')
-  if len(dims) < min_rank:
-    raise ValueError('rank of shape must be at least %d not: %d' % (min_rank,
-                                                                    len(dims)))
-  value = dims[-1].value
-  if value is None:
-    raise ValueError('last dimension shape must be known but is None')
-  return value
+  return _get_dimension(shape, -1, min_rank=min_rank)


 def two_element_tuple(int_or_tuple):
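
Note (illustration, not part of the diff above): a minimal sketch of how the new utils.channel_dimension helper is expected to resolve the channel axis for the two data formats, assuming a TF 1.x environment where tensorflow.contrib is available; the placeholder names and shapes below are invented for the example.

# Minimal sketch; names and shapes are invented for illustration (TF 1.x assumed).
import tensorflow as tf
from tensorflow.contrib.layers.python.layers import utils

# NHWC ("channels_last"): channels sit on the last axis.
nhwc = tf.placeholder(tf.float32, shape=(4, 5, 6, 3))
# NCHW ("channels_first"): channels sit on axis 1.
nchw = tf.placeholder(tf.float32, shape=(4, 3, 5, 6))

# channel_dimension picks dims[-1] or dims[1] depending on data_format,
# after checking the shape has at least min_rank dimensions.
assert utils.channel_dimension(nhwc.get_shape(), 'channels_last', min_rank=4) == 3
assert utils.channel_dimension(nchw.get_shape(), 'channels_first', min_rank=4) == 3

This is why separable_convolution2d now passes the resolved data format (df) to the helper instead of unconditionally reading the last dimension, which is what made the NCHW test case above come out with the correct number of input filters.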