diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index e35e28e2ad59653ced984820271fe40018a5992a..338fb9f25ba42956bb7fc1400195c715362d949c 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -402,11 +402,18 @@ def dropout(inputs, a tensor representing the output of the operation. """ with ops.op_scope([inputs], scope, 'Dropout') as sc: - is_training = ops.convert_to_tensor(is_training) - outputs = control_flow_ops.cond( - is_training, - lambda: nn.dropout(inputs, keep_prob, noise_shape), - lambda: inputs) + is_training_value = utils.constant_value(is_training, dtypes.bool) + if is_training_value is not None: + if is_training_value: + outputs = nn.dropout(inputs, keep_prob, noise_shape) + else: + outputs = inputs + else: + def _dropout(): + return nn.dropout(inputs, keep_prob, noise_shape) + outputs = control_flow_ops.cond(is_training, + _dropout, + lambda: inputs) return utils.collect_named_outputs(outputs_collections, sc, outputs) diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 9297c29f2233931ee8de470d0d3a3a824770b5dc..b9fc0e4fb5d03ae19bb4df97228948ecf88ac348 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -327,6 +327,24 @@ class DropoutTest(tf.test.TestCase): with self.test_session(): images = tf.random_uniform((5, height, width, 3), seed=1) output = tf.contrib.layers.dropout(images) + self.assertEquals(output.op.name, 'Dropout/dropout/mul_1') + output.get_shape().assert_is_compatible_with(images.get_shape()) + + def testCreateDropoutWithConstant(self): + height, width = 3, 3 + with self.test_session(): + is_training = tf.constant(False) + images = tf.random_uniform((5, height, width, 3), seed=1) + output = tf.contrib.layers.dropout(images, is_training=is_training) + self.assertEquals(output.op.name, 'Dropout/dropout/mul_1') + output.get_shape().assert_is_compatible_with(images.get_shape()) + + def testCreateDropoutWithPlaceholder(self): + height, width = 3, 3 + with self.test_session(): + is_training = tf.placeholder(dtype=tf.bool, shape=[]) + images = tf.random_uniform((5, height, width, 3), seed=1) + output = tf.contrib.layers.dropout(images, is_training=is_training) self.assertEquals(output.op.name, 'Dropout/cond/Merge') output.get_shape().assert_is_compatible_with(images.get_shape()) diff --git a/tensorflow/contrib/layers/python/layers/utils.py b/tensorflow/contrib/layers/python/layers/utils.py index 0b179159233350d8d5bd9d23267a6b1980fc2e4e..9ecd230a80b29a615200d6103c01622c0552aa44 100644 --- a/tensorflow/contrib/layers/python/layers/utils.py +++ b/tensorflow/contrib/layers/python/layers/utils.py @@ -52,6 +52,34 @@ def collect_named_outputs(collections, name, outputs): return outputs +def constant_value(value_or_tensor, tensor_dtype=None): + """Returns value if value_or_tensor has a constant value. + + Args: + value_or_tensor: A value or a `Tensor`. + tensor_dtype: Optional `tf.dtype`, if set it would check the tensor type. + + Returns: + The constant value or None if it not constant. + + Raises: + ValueError: if value_or_tensor is None or the tensor has the wrong dtype. + """ + if value_or_tensor is None: + raise ValueError('value_or_tensor cannot be None') + value = value_or_tensor + if isinstance(value_or_tensor, ops.Tensor): + if tensor_dtype and value_or_tensor.dtype != tensor_dtype: + raise ValueError('The tensor has the wrong type %s instead of %s' % ( + value_or_tensor.dtype, tensor_dtype)) + if value_or_tensor.op.type == 'Const': + value_or_tensor.graph.prevent_feeding(value_or_tensor) + value = value_or_tensor.op.get_attr('value') + else: + value = None + return value + + def get_variable_collections(variables_collections, name): if isinstance(variables_collections, dict): variable_collections = variables_collections[name]