提交 df798f75 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Use fix graph for dropout if is_training has a constant value.

Change: 125571683
上级 91757d7b
......@@ -402,11 +402,18 @@ def dropout(inputs,
a tensor representing the output of the operation.
"""
with ops.op_scope([inputs], scope, 'Dropout') as sc:
is_training = ops.convert_to_tensor(is_training)
outputs = control_flow_ops.cond(
is_training,
lambda: nn.dropout(inputs, keep_prob, noise_shape),
lambda: inputs)
is_training_value = utils.constant_value(is_training, dtypes.bool)
if is_training_value is not None:
if is_training_value:
outputs = nn.dropout(inputs, keep_prob, noise_shape)
else:
outputs = inputs
else:
def _dropout():
return nn.dropout(inputs, keep_prob, noise_shape)
outputs = control_flow_ops.cond(is_training,
_dropout,
lambda: inputs)
return utils.collect_named_outputs(outputs_collections, sc, outputs)
......
......@@ -327,6 +327,24 @@ class DropoutTest(tf.test.TestCase):
with self.test_session():
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.dropout(images)
self.assertEquals(output.op.name, 'Dropout/dropout/mul_1')
output.get_shape().assert_is_compatible_with(images.get_shape())
def testCreateDropoutWithConstant(self):
height, width = 3, 3
with self.test_session():
is_training = tf.constant(False)
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.dropout(images, is_training=is_training)
self.assertEquals(output.op.name, 'Dropout/dropout/mul_1')
output.get_shape().assert_is_compatible_with(images.get_shape())
def testCreateDropoutWithPlaceholder(self):
height, width = 3, 3
with self.test_session():
is_training = tf.placeholder(dtype=tf.bool, shape=[])
images = tf.random_uniform((5, height, width, 3), seed=1)
output = tf.contrib.layers.dropout(images, is_training=is_training)
self.assertEquals(output.op.name, 'Dropout/cond/Merge')
output.get_shape().assert_is_compatible_with(images.get_shape())
......
......@@ -52,6 +52,34 @@ def collect_named_outputs(collections, name, outputs):
return outputs
def constant_value(value_or_tensor, tensor_dtype=None):
"""Returns value if value_or_tensor has a constant value.
Args:
value_or_tensor: A value or a `Tensor`.
tensor_dtype: Optional `tf.dtype`, if set it would check the tensor type.
Returns:
The constant value or None if it not constant.
Raises:
ValueError: if value_or_tensor is None or the tensor has the wrong dtype.
"""
if value_or_tensor is None:
raise ValueError('value_or_tensor cannot be None')
value = value_or_tensor
if isinstance(value_or_tensor, ops.Tensor):
if tensor_dtype and value_or_tensor.dtype != tensor_dtype:
raise ValueError('The tensor has the wrong type %s instead of %s' % (
value_or_tensor.dtype, tensor_dtype))
if value_or_tensor.op.type == 'Const':
value_or_tensor.graph.prevent_feeding(value_or_tensor)
value = value_or_tensor.op.get_attr('value')
else:
value = None
return value
def get_variable_collections(variables_collections, name):
if isinstance(variables_collections, dict):
variable_collections = variables_collections[name]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册