Commit df52532c authored by Vijay Vasudevan, committed by TensorFlower Gardener

Automated rollback of change 141622306

Change: 144145884
Parent b42ba8aa
@@ -1752,6 +1752,8 @@ cuda_py_test(
         "//tensorflow/python:nn_ops_gen",
         "//tensorflow/python:platform",
         "//tensorflow/python:sparse_ops",
+        "//tensorflow/python:random_ops",
+        "//tensorflow/python:variables",
     ],
 )
@@ -1916,6 +1918,8 @@ cuda_py_test(
         "//third_party/py/numpy",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:gradients",
+        "//tensorflow/python:math_ops",
         "//tensorflow/python:nn_grad",
         "//tensorflow/python:nn_ops",
         "//tensorflow/python:nn_ops_gen",
......
@@ -244,6 +244,22 @@ class CTCLossTest(test.TestCase):
       (tf_loss, tf_loss_transposed) = sess.run([loss, loss_transposed])
       self.assertAllEqual(tf_loss, tf_loss_transposed)
 
+  def testInvalidSecondGradient(self):
+    inputs = np.random.randn(2, 2, 3).astype(np.float32)
+    inputs_t = constant_op.constant(inputs)
+    labels = SimpleSparseTensorFrom([[0, 1], [1, 0]])
+    seq_lens = np.array([2, 2], dtype=np.int32)
+    v = [1.0]
+
+    with self.test_session(use_gpu=False):
+      loss = ctc_ops.ctc_loss(
+          inputs=inputs_t, labels=labels, sequence_length=seq_lens)
+      # Taking the second gradient should fail, since it is not
+      # yet supported.
+      with self.assertRaisesRegexp(LookupError,
+                                   ".*No gradient defined.*PreventGradient.*"):
+        _ = gradients_impl._hessian_vector_product(loss, [inputs_t], v)
+
 
 if __name__ == "__main__":
   test.main()
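The test above drives gradients_impl._hessian_vector_product, which builds a Hessian-vector product by calling tf.gradients twice: once for the loss gradient, and once more on the inner product of that gradient with v. A minimal standalone sketch of the pattern in TF 1.x graph mode (the toy loss and the names x, v, hvp are illustrative, not taken from the test):

import tensorflow as tf

x = tf.constant([1.0, 2.0, 3.0])
v = tf.constant([0.5, 0.5, 0.5])

loss = tf.reduce_sum(tf.square(x))  # toy loss; its Hessian is 2*I
grad = tf.gradients(loss, x)[0]     # first derivative: 2*x
grad_v = tf.reduce_sum(grad * v)    # inner product <grad, v>
hvp = tf.gradients(grad_v, x)[0]    # d<grad, v>/dx = Hessian @ v

with tf.Session() as sess:
  print(sess.run(hvp))              # [1.0, 1.0, 1.0], i.e. 2*v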
@@ -35,7 +35,9 @@ from tensorflow.python.ops import gradient_checker
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import sparse_ops
+from tensorflow.python.ops import variables
 import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
 from tensorflow.python.platform import app
 from tensorflow.python.platform import test
@@ -198,6 +200,23 @@ class SparseXentTest(test.TestCase):
       print("cross entropy gradient err = ", err)
       self.assertLess(err, 5e-8)
 
+  def testSecondGradient(self):
+    images_placeholder = array_ops.placeholder(dtypes.float32, shape=(3, 2))
+    labels_placeholder = array_ops.placeholder(dtypes.int32, shape=(3))
+    weights = variables.Variable(random_ops.truncated_normal([2], stddev=1.0))
+    weights_with_zeros = array_ops.stack([array_ops.zeros([2]), weights],
+                                         axis=1)
+    logits = math_ops.matmul(images_placeholder, weights_with_zeros)
+    cross_entropy = nn_ops.sparse_softmax_cross_entropy_with_logits(
+        labels=labels_placeholder, logits=logits)
+    loss = math_ops.reduce_mean(cross_entropy)
+
+    # Taking the second gradient should fail, since it is not
+    # yet supported.
+    with self.assertRaisesRegexp(LookupError,
+                                 ".*No gradient defined.*PreventGradient.*"):
+      _ = gradients_impl.hessians(loss, [weights])
+
   def _testHighDim(self, features, labels):
     np_loss, np_backprop = self._npXent(np.array(features), np.array(labels))
     # manually reshape loss
......
@@ -24,6 +24,8 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
 import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
 from tensorflow.python.platform import test
@@ -172,6 +174,26 @@ class XentTest(test.TestCase):
       print("cross entropy gradient err = ", err)
       self.assertLess(err, 5e-8)
 
+  def testSecondGradient(self):
+    with self.test_session():
+      l = constant_op.constant([0.0, 0.0, 1.0, 0.0,
+                                1.0, 0.0, 0.0, 0.0,
+                                0.0, 0.5, 0.0, 0.5], shape=[12],
+                               dtype=dtypes.float64, name="l")
+      f = constant_op.constant([0.1, 0.2, 0.3, 0.4,
+                                0.1, 0.4, 0.9, 1.6,
+                                0.1, 0.8, 2.7, 6.4], shape=[12],
+                               dtype=dtypes.float64, name="f")
+      x = nn_ops.softmax_cross_entropy_with_logits(labels=l, logits=f,
+                                                   name="xent")
+      loss = math_ops.reduce_mean(x)
+
+      # Taking the second gradient should fail, since it is not
+      # yet supported.
+      with self.assertRaisesRegexp(LookupError,
+                                   ".*No gradient defined.*PreventGradient.*"):
+        _ = gradients_impl.hessians(loss, [f])
+
   def testWrapper(self):
     features = np.array(
         [[[1., 1., 1., 1.], [1., 2., 3., 4.]],
......
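Both xent tests call gradients_impl.hessians, which differentiates the first gradient a second time to assemble the full Hessian. A minimal sketch of it succeeding on an unfused loss whose gradient is twice-differentiable (the toy loss and values are assumed for illustration):

import tensorflow as tf
from tensorflow.python.ops import gradients_impl

x = tf.Variable([1.0, 2.0, 3.0])
loss = tf.reduce_sum(x * x * x)               # toy loss; Hessian is diag(6*x)
hess = gradients_impl.hessians(loss, [x])[0]  # dense [3, 3] Hessian

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(hess))                       # diag(6.0, 12.0, 18.0)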
@@ -160,10 +160,15 @@ def _CTCLossGrad(op, grad_loss, _):
     The CTC Loss gradient.
   """
   # Outputs are: loss, grad
-  grad = op.outputs[1]
+  #
+  # Currently there is no way to take the second derivative of this op
+  # due to the fused implementation's interaction with tf.gradients(),
+  # so we make sure we prevent silently incorrect results by raising
+  # an error if the second derivative is requested via prevent_gradient.
+  grad_without_gradient = array_ops.prevent_gradient(op.outputs[1])
   # Return gradient for inputs and None for
   # labels_indices, labels_values and sequence_length
-  return [_BroadcastMul(grad_loss, grad), None, None, None]
+  return [_BroadcastMul(grad_loss, grad_without_gradient), None, None, None]
 
 
 def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
......
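The change routes the fused op's internal gradient tensor through array_ops.prevent_gradient, an identity in the forward pass whose output has no registered gradient, so a second differentiation fails at graph-construction time instead of silently producing wrong values. A minimal standalone sketch of that behavior (values are illustrative):

import tensorflow as tf
from tensorflow.python.ops import array_ops

x = tf.constant([1.0, 2.0])
y = array_ops.prevent_gradient(x) * 3.0  # forward pass is unaffected

try:
  tf.gradients(y, x)                     # differentiation fails here
except LookupError as e:
  print(e)  # matches "No gradient defined ... PreventGradient"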
@@ -411,6 +411,16 @@ class StopGradientTest(test_util.TensorFlowTestCase):
     assert igrad is None
 
 
+class PreventGradientTest(test_util.TensorFlowTestCase):
+
+  def testPreventGradient(self):
+    with ops.Graph().as_default():
+      inp = constant(1.0, shape=[100, 32], name="in")
+      out = array_ops.prevent_gradient(inp)
+      with self.assertRaisesRegexp(LookupError, "No gradient defined"):
+        _ = gradients.gradients(out, inp)
+
+
 class HessianVectorProductTest(test_util.TensorFlowTestCase):
 
   def testHessianVectorProduct(self):
......
@@ -322,18 +322,33 @@ def _BroadcastMul(vec, mat):
 
 @ops.RegisterGradient("SoftmaxCrossEntropyWithLogits")
 def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_0, _):
   """Gradient function for SoftmaxCrossEntropyWithLogits."""
   # grad_0 is the backprop for cost, and we multiply it with the gradients
   # (which is output[1])
   # There is no gradient for the labels
-  return _BroadcastMul(grad_0, op.outputs[1]), None
+  #
+  # Currently there is no way to take the second derivative of this op
+  # due to the fused implementation's interaction with tf.gradients(),
+  # so we make sure we prevent silently incorrect results by raising
+  # an error if the second derivative is requested via prevent_gradient.
+  softmax_grad_without_gradient = array_ops.prevent_gradient(op.outputs[1])
+  return _BroadcastMul(grad_0, softmax_grad_without_gradient), None
 
 
 @ops.RegisterGradient("SparseSoftmaxCrossEntropyWithLogits")
 def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_0, _):
   """Gradient function for SparseSoftmaxCrossEntropyWithLogits."""
   # grad_0 is the backprop for cost, and we multiply it with the gradients
   # (which is output[1])
   # There is no gradient for the labels
-  return _BroadcastMul(grad_0, op.outputs[1]), None
+  #
+  # Currently there is no way to take the second derivative of this op
+  # due to the fused implementation's interaction with tf.gradients(),
+  # so we make sure we prevent silently incorrect results by raising
+  # an error if the second derivative is requested via prevent_gradient.
+  sparse_softmax_grad_without_gradient = array_ops.prevent_gradient(
+      op.outputs[1])
+  return _BroadcastMul(grad_0, sparse_softmax_grad_without_gradient), None
 
 
 @ops.RegisterGradient("Conv2D")
......
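The _BroadcastMul helper named in the hunk header expands the per-example loss backprop (shape [batch]) to [batch, 1] so it scales every row of the fused op's softmax gradient (shape [batch, num_classes]). A NumPy sketch of that shape arithmetic (the values are illustrative):

import numpy as np

grad_loss = np.array([0.5, 2.0])       # vec: one backprop value per example
softmax_grad = np.ones((2, 3))         # mat: per-example, per-class gradients
out = grad_loss[:, np.newaxis] * softmax_grad  # vec expanded to [batch, 1]
print(out)  # row 0 scaled by 0.5, row 1 scaled by 2.0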