Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
xxadev
tensorflow
提交
df52532c
T
tensorflow
项目概览
xxadev
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
3
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
df52532c
编写于
1月 10, 2017
作者:
V
Vijay Vasudevan
提交者:
TensorFlower Gardener
1月 10, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Automated rollback of change 141622306
Change: 144145884
上级
b42ba8aa
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
95 addition
and
4 deletion
+95
-4
tensorflow/python/kernel_tests/BUILD
tensorflow/python/kernel_tests/BUILD
+4
-0
tensorflow/python/kernel_tests/ctc_loss_op_test.py
tensorflow/python/kernel_tests/ctc_loss_op_test.py
+16
-0
tensorflow/python/kernel_tests/sparse_xent_op_test.py
tensorflow/python/kernel_tests/sparse_xent_op_test.py
+19
-0
tensorflow/python/kernel_tests/xent_op_test.py
tensorflow/python/kernel_tests/xent_op_test.py
+22
-0
tensorflow/python/ops/ctc_ops.py
tensorflow/python/ops/ctc_ops.py
+7
-2
tensorflow/python/ops/gradients_test.py
tensorflow/python/ops/gradients_test.py
+10
-0
tensorflow/python/ops/nn_grad.py
tensorflow/python/ops/nn_grad.py
+17
-2
未找到文件。
tensorflow/python/kernel_tests/BUILD
浏览文件 @
df52532c
...
...
@@ -1752,6 +1752,8 @@ cuda_py_test(
"//tensorflow/python:nn_ops_gen"
,
"//tensorflow/python:platform"
,
"//tensorflow/python:sparse_ops"
,
"//tensorflow/python:random_ops"
,
"//tensorflow/python:variables"
,
],
)
...
...
@@ -1916,6 +1918,8 @@ cuda_py_test(
"//third_party/py/numpy"
,
"//tensorflow/python:client_testlib"
,
"//tensorflow/python:framework_for_generated_wrappers"
,
"//tensorflow/python:gradients"
,
"//tensorflow/python:math_ops"
,
"//tensorflow/python:nn_grad"
,
"//tensorflow/python:nn_ops"
,
"//tensorflow/python:nn_ops_gen"
,
...
...
tensorflow/python/kernel_tests/ctc_loss_op_test.py
浏览文件 @
df52532c
...
...
@@ -244,6 +244,22 @@ class CTCLossTest(test.TestCase):
(
tf_loss
,
tf_loss_transposed
)
=
sess
.
run
([
loss
,
loss_transposed
])
self
.
assertAllEqual
(
tf_loss
,
tf_loss_transposed
)
def
testInvalidSecondGradient
(
self
):
inputs
=
np
.
random
.
randn
(
2
,
2
,
3
).
astype
(
np
.
float32
)
inputs_t
=
constant_op
.
constant
(
inputs
)
labels
=
SimpleSparseTensorFrom
([[
0
,
1
],
[
1
,
0
]])
seq_lens
=
np
.
array
([
2
,
2
],
dtype
=
np
.
int32
)
v
=
[
1.0
]
with
self
.
test_session
(
use_gpu
=
False
):
loss
=
ctc_ops
.
ctc_loss
(
inputs
=
inputs_t
,
labels
=
labels
,
sequence_length
=
seq_lens
)
# Taking ths second gradient should fail, since it is not
# yet supported.
with
self
.
assertRaisesRegexp
(
LookupError
,
".*No gradient defined.*PreventGradient.*"
):
_
=
gradients_impl
.
_hessian_vector_product
(
loss
,
[
inputs_t
],
v
)
if
__name__
==
"__main__"
:
test
.
main
()
tensorflow/python/kernel_tests/sparse_xent_op_test.py
浏览文件 @
df52532c
...
...
@@ -35,7 +35,9 @@ from tensorflow.python.ops import gradient_checker
from
tensorflow.python.ops
import
gradients_impl
from
tensorflow.python.ops
import
math_ops
from
tensorflow.python.ops
import
nn_ops
from
tensorflow.python.ops
import
random_ops
from
tensorflow.python.ops
import
sparse_ops
from
tensorflow.python.ops
import
variables
import
tensorflow.python.ops.nn_grad
# pylint: disable=unused-import
from
tensorflow.python.platform
import
app
from
tensorflow.python.platform
import
test
...
...
@@ -198,6 +200,23 @@ class SparseXentTest(test.TestCase):
print
(
"cross entropy gradient err = "
,
err
)
self
.
assertLess
(
err
,
5e-8
)
def
testSecondGradient
(
self
):
images_placeholder
=
array_ops
.
placeholder
(
dtypes
.
float32
,
shape
=
(
3
,
2
))
labels_placeholder
=
array_ops
.
placeholder
(
dtypes
.
int32
,
shape
=
(
3
))
weights
=
variables
.
Variable
(
random_ops
.
truncated_normal
([
2
],
stddev
=
1.0
))
weights_with_zeros
=
array_ops
.
stack
([
array_ops
.
zeros
([
2
]),
weights
],
axis
=
1
)
logits
=
math_ops
.
matmul
(
images_placeholder
,
weights_with_zeros
)
cross_entropy
=
nn_ops
.
sparse_softmax_cross_entropy_with_logits
(
labels
=
labels_placeholder
,
logits
=
logits
)
loss
=
math_ops
.
reduce_mean
(
cross_entropy
)
# Taking ths second gradient should fail, since it is not
# yet supported.
with
self
.
assertRaisesRegexp
(
LookupError
,
".*No gradient defined.*PreventGradient.*"
):
_
=
gradients_impl
.
hessians
(
loss
,
[
weights
])
def
_testHighDim
(
self
,
features
,
labels
):
np_loss
,
np_backprop
=
self
.
_npXent
(
np
.
array
(
features
),
np
.
array
(
labels
))
# manually reshape loss
...
...
tensorflow/python/kernel_tests/xent_op_test.py
浏览文件 @
df52532c
...
...
@@ -24,6 +24,8 @@ from tensorflow.python.framework import constant_op
from
tensorflow.python.framework
import
dtypes
from
tensorflow.python.ops
import
gen_nn_ops
from
tensorflow.python.ops
import
gradient_checker
from
tensorflow.python.ops
import
gradients_impl
from
tensorflow.python.ops
import
math_ops
from
tensorflow.python.ops
import
nn_ops
import
tensorflow.python.ops.nn_grad
# pylint: disable=unused-import
from
tensorflow.python.platform
import
test
...
...
@@ -172,6 +174,26 @@ class XentTest(test.TestCase):
print
(
"cross entropy gradient err = "
,
err
)
self
.
assertLess
(
err
,
5e-8
)
def
testSecondGradient
(
self
):
with
self
.
test_session
():
l
=
constant_op
.
constant
([
0.0
,
0.0
,
1.0
,
0.0
,
1.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.5
,
0.0
,
0.5
],
shape
=
[
12
],
dtype
=
dtypes
.
float64
,
name
=
"l"
)
f
=
constant_op
.
constant
([
0.1
,
0.2
,
0.3
,
0.4
,
0.1
,
0.4
,
0.9
,
1.6
,
0.1
,
0.8
,
2.7
,
6.4
],
shape
=
[
12
],
dtype
=
dtypes
.
float64
,
name
=
"f"
)
x
=
nn_ops
.
softmax_cross_entropy_with_logits
(
labels
=
l
,
logits
=
f
,
name
=
"xent"
)
loss
=
math_ops
.
reduce_mean
(
x
)
# Taking ths second gradient should fail, since it is not
# yet supported.
with
self
.
assertRaisesRegexp
(
LookupError
,
".*No gradient defined.*PreventGradient.*"
):
_
=
gradients_impl
.
hessians
(
loss
,
[
f
])
def
testWrapper
(
self
):
features
=
np
.
array
(
[[[
1.
,
1.
,
1.
,
1.
],
[
1.
,
2.
,
3.
,
4.
]],
...
...
tensorflow/python/ops/ctc_ops.py
浏览文件 @
df52532c
...
...
@@ -160,10 +160,15 @@ def _CTCLossGrad(op, grad_loss, _):
The CTC Loss gradient.
"""
# Outputs are: loss, grad
grad
=
op
.
outputs
[
1
]
#
# Currently there is no way to take the second derivative of this op
# due to the fused implementation's interaction with tf.gradients(),
# so we make sure we prevent silently incorrect results by raising
# an error if the second derivative is requested via prevent_gradient.
grad_without_gradient
=
array_ops
.
prevent_gradient
(
op
.
outputs
[
1
])
# Return gradient for inputs and None for
# labels_indices, labels_values and sequence_length
return
[
_BroadcastMul
(
grad_loss
,
grad
),
None
,
None
,
None
]
return
[
_BroadcastMul
(
grad_loss
,
grad
_without_gradient
),
None
,
None
,
None
]
def
ctc_greedy_decoder
(
inputs
,
sequence_length
,
merge_repeated
=
True
):
...
...
tensorflow/python/ops/gradients_test.py
浏览文件 @
df52532c
...
...
@@ -411,6 +411,16 @@ class StopGradientTest(test_util.TensorFlowTestCase):
assert
igrad
is
None
class
PreventGradientTest
(
test_util
.
TensorFlowTestCase
):
def
testPreventGradient
(
self
):
with
ops
.
Graph
().
as_default
():
inp
=
constant
(
1.0
,
shape
=
[
100
,
32
],
name
=
"in"
)
out
=
array_ops
.
prevent_gradient
(
inp
)
with
self
.
assertRaisesRegexp
(
LookupError
,
"No gradient defined"
):
_
=
gradients
.
gradients
(
out
,
inp
)
class
HessianVectorProductTest
(
test_util
.
TensorFlowTestCase
):
def
testHessianVectorProduct
(
self
):
...
...
tensorflow/python/ops/nn_grad.py
浏览文件 @
df52532c
...
...
@@ -322,18 +322,33 @@ def _BroadcastMul(vec, mat):
@
ops
.
RegisterGradient
(
"SoftmaxCrossEntropyWithLogits"
)
def
_SoftmaxCrossEntropyWithLogitsGrad
(
op
,
grad_0
,
_
):
"""Gradient function for SoftmaxCrossEntropyWithLogits."""
# grad_0 is the backprop for cost, and we multiply it with the gradients
# (which is output[1])
# There is no gradient for the labels
return
_BroadcastMul
(
grad_0
,
op
.
outputs
[
1
]),
None
#
# Currently there is no way to take the second derivative of this op
# due to the fused implementation's interaction with tf.gradients(),
# so we make sure we prevent silently incorrect results by raising
# an error if the second derivative is requested via prevent_gradient.
softmax_grad_without_gradient
=
array_ops
.
prevent_gradient
(
op
.
outputs
[
1
])
return
_BroadcastMul
(
grad_0
,
softmax_grad_without_gradient
),
None
@
ops
.
RegisterGradient
(
"SparseSoftmaxCrossEntropyWithLogits"
)
def
_SparseSoftmaxCrossEntropyWithLogitsGrad
(
op
,
grad_0
,
_
):
"""Gradient function for SparseSoftmaxCrossEntropyWithLogits."""
# grad_0 is the backprop for cost, and we multiply it with the gradients
# (which is output[1])
# There is no gradient for the labels
return
_BroadcastMul
(
grad_0
,
op
.
outputs
[
1
]),
None
#
# Currently there is no way to take the second derivative of this op
# due to the fused implementation's interaction with tf.gradients(),
# so we make sure we prevent silently incorrect results by raising
# an error if the second derivative is requested via prevent_gradient.
sparse_softmax_grad_without_gradient
=
array_ops
.
prevent_gradient
(
op
.
outputs
[
1
])
return
_BroadcastMul
(
grad_0
,
sparse_softmax_grad_without_gradient
),
None
@
ops
.
RegisterGradient
(
"Conv2D"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录