Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_51669992
tensorflow
提交
f6a8bd5d
T
tensorflow
项目概览
weixin_51669992
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
16
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
f6a8bd5d
编写于
8月 12, 2019
作者:
A
A. Unique TensorFlower
提交者:
TensorFlower Gardener
8月 12, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Automated rollback of commit
a216c03f
PiperOrigin-RevId: 262893594
上级
ccb8c64f
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
131 addition
and
83 deletion
+131
-83
tensorflow/python/kernel_tests/gather_op_test.py
tensorflow/python/kernel_tests/gather_op_test.py
+29
-25
tensorflow/python/ops/array_grad.py
tensorflow/python/ops/array_grad.py
+95
-34
tensorflow/python/ops/array_ops.py
tensorflow/python/ops/array_ops.py
+7
-24
未找到文件。
tensorflow/python/kernel_tests/gather_op_test.py
浏览文件 @
f6a8bd5d
...
...
@@ -21,7 +21,7 @@ from __future__ import print_function
from
absl.testing
import
parameterized
import
numpy
as
np
from
tensorflow.python.
compat
import
compat
from
tensorflow.python.
eager
import
backprop
from
tensorflow.python.eager
import
context
from
tensorflow.python.framework
import
constant_op
from
tensorflow.python.framework
import
dtypes
...
...
@@ -351,19 +351,31 @@ class GatherTest(test.TestCase, parameterized.TestCase):
result
=
array_ops
.
gather
(
params
,
indices
,
axis
=
axis
,
batch_dims
=
batch_dims
)
self
.
assertAllEqual
(
expected
,
result
)
with
compat
.
forward_compatibility_horizon
(
2019
,
9
,
11
):
# Test the gradients shape.
if
context
.
executing_eagerly
():
with
backprop
.
GradientTape
()
as
tape
:
zeros
=
array_ops
.
zeros_like
(
params
,
dtype
=
dtypes
.
float32
)
tape
.
watch
(
zeros
)
values
=
zeros
*
2
+
zeros
result
=
array_ops
.
gather
(
values
,
indices
,
axis
=
axis
,
batch_dims
=
batch_dims
)
gradients
=
tape
.
gradient
(
result
,
zeros
)
else
:
zeros
=
array_ops
.
zeros_like
(
params
,
dtype
=
dtypes
.
float32
)
values
=
zeros
*
2
+
zeros
result
=
array_ops
.
gather
(
params
,
indices
,
axis
=
axis
,
batch_dims
=
batch_dims
)
values
,
indices
,
axis
=
axis
,
batch_dims
=
batch_dims
)
gradients
=
gradients_impl
.
gradients
(
result
,
[
zeros
])[
0
]
self
.
assertAllEqual
(
expected
,
result
)
self
.
assertAllEqual
(
array_ops
.
shape
(
params
),
array_ops
.
shape
(
gradients
)
)
# Run the same test for strings.
params
=
_to_str_elements
(
params
)
expected
=
_to_str_elements
(
expected
)
result
=
array_ops
.
gather
(
params
,
indices
,
axis
=
axis
,
batch_dims
=
batch_dims
)
# Run the same test for strings.
params
=
_to_str_elements
(
params
)
expected
=
_to_str_elements
(
expected
)
result
=
array_ops
.
gather
(
params
,
indices
,
axis
=
axis
,
batch_dims
=
batch_dims
)
self
.
assertAllEqual
(
expected
,
result
)
self
.
assertAllEqual
(
expected
,
result
)
@
parameterized
.
parameters
([
dict
(
...
...
@@ -459,22 +471,14 @@ class GatherTest(test.TestCase, parameterized.TestCase):
self
.
assertAllEqual
(
output_shape
,
result
.
shape
.
as_list
())
self
.
assertAllEqual
(
expected
,
result
)
with
compat
.
forward_compatibility_horizon
(
2019
,
9
,
11
):
result
=
array_ops
.
gather
(
params
,
indices
,
axis
=
axis
,
batch_dims
=
batch_dims
)
self
.
assertAllEqual
(
output_shape
,
result
.
shape
.
as_list
())
self
.
assertAllEqual
(
expected
,
result
)
# Run the same test for strings.
params
=
_to_str_elements
(
params
)
expected
=
_to_str_elements
(
expected
.
tolist
())
result
=
array_ops
.
gather
(
params
,
indices
,
axis
=
axis
,
batch_dims
=
batch_dims
)
self
.
assertAllEqual
(
output_shape
,
result
.
shape
.
as_list
())
self
.
assertAllEqual
(
expected
,
result
)
# Run the same test for strings.
params
=
_to_str_elements
(
params
)
expected
=
_to_str_elements
(
expected
.
tolist
())
result
=
array_ops
.
gather
(
params
,
indices
,
axis
=
axis
,
batch_dims
=
batch_dims
)
self
.
assertAllEqual
(
output_shape
,
result
.
shape
.
as_list
())
self
.
assertAllEqual
(
expected
,
result
)
def
_batchNumpyGather
(
self
,
params
,
indices
,
axis
,
batch_dims
):
"""Performs a batch gather by making recursive calls to np.take().
...
...
tensorflow/python/ops/array_grad.py
浏览文件 @
f6a8bd5d
...
...
@@ -31,6 +31,7 @@ from tensorflow.python.ops import array_ops
from
tensorflow.python.ops
import
control_flow_ops
from
tensorflow.python.ops
import
control_flow_util
from
tensorflow.python.ops
import
gen_array_ops
from
tensorflow.python.ops
import
gen_math_ops
from
tensorflow.python.ops
import
gen_resource_variable_ops
from
tensorflow.python.ops
import
math_ops
from
tensorflow.python.ops
import
sparse_ops
...
...
@@ -477,6 +478,61 @@ def _GatherGrad(op, grad):
return
[
ops
.
IndexedSlices
(
values
,
indices
,
params_shape
),
None
]
def _GetBatchIndices(params_shape, indices, batch_dims):
  """Adds the batch offsets to the given indices and returns the results.

  Args:
    params_shape: Shape of the gathered params; it is cast to the dtype of
      `indices` and indexed per dimension.
    indices: Tensor of gather indices. Its static rank (`shape.ndims`) is used
      to build the broadcast shape of every offset.
    batch_dims: Python integer, the number of leading batch dimensions to
      iterate over.

  Returns:
    `indices` plus one offset term for each of the `batch_dims` leading
    dimensions.
  """
  index_dtype = indices.dtype.base_dtype
  index_rank = indices.shape.ndims
  shape_as_index = math_ops.cast(params_shape, index_dtype)
  # Running product of params_shape[dim] for the dimensions visited so far;
  # it scales the offsets of the batch dimension currently being handled.
  stride = array_ops.ones((), dtype=index_dtype)

  batch_indices = indices
  # Walk the batch dimensions from the innermost one outwards.
  for dim in range(batch_dims, 0, -1):
    extent = shape_as_index[dim - 1]
    stride = stride * shape_as_index[dim]
    range_start = array_ops.zeros((), dtype=index_dtype)
    range_step = array_ops.ones((), dtype=index_dtype)
    offsets = math_ops.range(range_start, extent, range_step)
    offsets = offsets * stride
    # Shape that is 1 everywhere except at position `dim - 1`, so the offsets
    # broadcast against `indices` along that batch dimension only.
    leading_ones = [1] * (dim - 1)
    trailing_ones = [1] * (index_rank - dim)
    broadcast_shape = array_ops.stack(
        leading_ones + [extent] + trailing_ones, axis=0)
    batch_indices += array_ops.reshape(offsets, broadcast_shape)

  return batch_indices
def _BatchGatherGrad(params_shape, values, indices, batch_dims,
                     gather_dim_size):
  """Returns the gradient of GatherV2 with batch dimensions.

  Args:
    params_shape: Shape of the gathered params, forwarded to
      `_GetBatchIndices` when there are batch dimensions.
    values: Incoming gradient values to be summed per gathered index.
    indices: The gather indices.
    batch_dims: Number of leading batch dimensions; a falsy value (0) selects
      the plain, non-batched path.
    gather_dim_size: Size of the gathered dimension, used as the number of
      segments (scaled by the batch size when batching).

  Returns:
    The result of `unsorted_segment_sum` over the flattened indices, reshaped
    to restore the batch dimensions when `batch_dims` is non-zero.
  """
  # Axis is the first non-batch dimension.
  indices_size = array_ops.expand_dims(array_ops.size(indices), 0)

  if batch_dims:
    values_shape = array_ops.shape(values)
    # Add the batch offsets to indices and flatten the batch dimensions.
    outer_shape = values_shape[:batch_dims]
    inner_shape = values_shape[batch_dims:][1:]
    batch_size = gen_math_ops.prod(outer_shape, [0], False)
    flat_values_shape = array_ops.concat([[-1], inner_shape], 0)
    gather_dim_size *= batch_size

    indices = _GetBatchIndices(params_shape, indices, batch_dims)

    with warnings.catch_warnings():
      warnings.filterwarnings(
          "ignore",
          message="Converting sparse IndexedSlices to a dense Tensor.*")
      values = array_ops.reshape(values, flat_values_shape)

  flat_indices = array_ops.reshape(indices, indices_size)
  summed_grad = math_ops.unsorted_segment_sum(
      values, flat_indices, gather_dim_size)

  if not batch_dims:
    return summed_grad

  # Put back the batch dimensions.
  restored_shape = array_ops.concat([outer_shape, flat_values_shape], 0)
  return array_ops.reshape(summed_grad, restored_shape)
@
ops
.
RegisterGradient
(
"GatherV2"
)
def
_GatherV2Grad
(
op
,
grad
):
"""Gradient for GatherV2 op."""
...
...
@@ -495,6 +551,10 @@ def _GatherV2Grad(op, grad):
indices_size
=
array_ops
.
expand_dims
(
array_ops
.
size
(
indices
),
0
)
axis
=
op
.
inputs
[
2
]
axis_static
=
tensor_util
.
constant_value
(
axis
)
batch_dims
=
int
(
op
.
get_attr
(
"batch_dims"
))
if
batch_dims
<
0
:
batch_dims
+=
indices
.
shape
.
ndims
# For axis 0 gathers, build an appropriately shaped IndexedSlices.
if
axis_static
==
0
:
...
...
@@ -509,44 +569,45 @@ def _GatherV2Grad(op, grad):
message
=
"Converting sparse IndexedSlices to a dense Tensor.*"
)
values
=
array_ops
.
reshape
(
grad
,
values_shape
)
indices
=
array_ops
.
reshape
(
indices
,
indices_size
)
return
[
ops
.
IndexedSlices
(
values
,
indices
,
params_shape
),
None
,
None
]
params_grad
=
ops
.
IndexedSlices
(
values
,
indices
,
params_shape
)
else
:
# Handle axis by transposing the axis dimension to be the first non-batch
# dimension, compute the gradiend and transpose the result back.
outer_shape
=
params_shape
[:
axis
]
inner_shape
=
params_shape
[
axis
:][
1
:]
values_shape
=
array_ops
.
concat
([
outer_shape
,
[
-
1
],
inner_shape
],
0
)
outer_shape
=
params_shape
[:
axis
]
outer_dims
=
array_ops
.
size
(
outer_shape
)
inner_shape
=
params_shape
[
axis
:][
1
:]
inner_dims
=
array_ops
.
size
(
inner_shape
)
values_dims
=
array_ops
.
size
(
values_shape
)
axis_dims
=
array_ops
.
size
(
outer_shape
)
outer_axes_indices
=
math_ops
.
range
(
outer
_dims
)
inner_axes_indices
=
math_ops
.
range
(
outer_dims
+
1
,
outer_dims
+
1
+
inner
_dims
)
outer_batches_indices
=
math_ops
.
range
(
batch
_dims
)
batch_axis_indices
=
math_ops
.
range
(
batch_dims
,
axis_dims
)
inner_axes_indices
=
math_ops
.
range
(
axis_dims
+
1
,
values
_dims
)
values_shape
=
array_ops
.
concat
([
outer_shape
,
indices_size
,
inner_shape
],
0
)
with
warnings
.
catch_warnings
():
warnings
.
filterwarnings
(
"ignore"
,
message
=
"Converting sparse IndexedSlices to a dense Tensor.*"
)
values
=
array_ops
.
reshape
(
grad
,
values_shape
)
indices
=
array_ops
.
reshape
(
indices
,
indices_size
)
with
warnings
.
catch_warnings
():
warnings
.
filterwarnings
(
"ignore"
,
message
=
"Converting sparse IndexedSlices to a dense Tensor.*"
)
values
=
array_ops
.
reshape
(
grad
,
values_shape
)
# Move values[axis] up to values[batch_dims]
transpose_dims
=
array_ops
.
concat
(
[
outer_batches_indices
,
[
axis_dims
],
batch_axis_indices
,
inner_axes_indices
],
0
)
values_transpose
=
array_ops
.
transpose
(
values
,
transpose_dims
)
params_grad
=
_BatchGatherGrad
(
params_shape
,
values_transpose
,
indices
,
batch_dims
,
params_shape
[
axis
])
# Inverts the above transpose by moving dimension batch_dims back to its
# original position.
invert_transpose_dims
=
array_ops
.
concat
(
[
outer_batches_indices
,
batch_axis_indices
+
1
,
[
batch_dims
],
inner_axes_indices
],
0
)
params_grad
=
array_ops
.
transpose
(
params_grad
,
invert_transpose_dims
)
# We need to sum up every slice `values[..., i, ....]` corresponding to
# `params[..., indices[i], ...]`. Since `unsorted_segment_sum` does not
# support an axis parameter, we transpose the gather dimension to the front,
# then use `unsorted_segment_sum` to build a
# [gather_axis, outer_axes, inner_axes] tensor with all the gradients
# affecting each index in `gather_axis` summed up.
transpose_dims
=
array_ops
.
concat
(
[[
outer_dims
],
outer_axes_indices
,
inner_axes_indices
],
0
)
values_transpose
=
array_ops
.
transpose
(
values
,
transpose_dims
)
num_segments
=
params_shape
[
axis
]
params_grad
=
math_ops
.
unsorted_segment_sum
(
values_transpose
,
indices
,
num_segments
)
# Inverts the above transpose by moving dimension 0 back to its original
# position.
invert_transpose_dims
=
array_ops
.
concat
(
[
outer_axes_indices
+
1
,
[
0
],
inner_axes_indices
],
0
)
params_grad
=
array_ops
.
transpose
(
params_grad
,
invert_transpose_dims
)
return
[
params_grad
,
None
,
None
]
...
...
tensorflow/python/ops/array_ops.py
浏览文件 @
f6a8bd5d
...
...
@@ -3952,36 +3952,19 @@ def gather(params,
A `Tensor`. Has the same type as `params`.
"""
del
validate_indices
if
compat
.
forward_compatible
(
2019
,
9
,
10
):
if
axis
is
None
:
axis
=
batch_dims
if
axis
!=
0
:
return
gen_array_ops
.
gather_v2
(
params
,
indices
,
axis
,
batch_dims
=
batch_dims
,
name
=
name
)
try
:
# TODO(apassos) find a less bad way of detecting resource variables
# without introducing a circular dependency.
return
params
.
sparse_read
(
indices
,
name
=
name
)
except
AttributeError
:
return
gen_array_ops
.
gather_v2
(
params
,
indices
,
axis
,
name
=
name
)
if
batch_dims
!=
0
:
with
ops
.
name_scope
(
name
,
"Gather"
,
[
params
,
indices
,
axis
]):
return
_batch_gather
(
params
,
indices
,
batch_dims
,
axis
)
if
axis
is
None
:
axis
=
batch_dims
if
axis
!=
0
:
# Note that we do a sparse_read here to avoid snapshotting the entire
# resource variable and doing a gather, which can be inefficient and lead to
# subtle race conditions. TODO(apassos) implement axis != 0 on sparse_read
return
gen_array_ops
.
gather_v2
(
params
,
indices
,
axis
,
name
=
name
)
return
gen_array_ops
.
gather_v2
(
params
,
indices
,
axis
,
batch_dims
=
batch_dims
,
name
=
name
)
try
:
# TODO(apassos) find a less bad way of detecting resource variables
without
# introducing a circular dependency.
# TODO(apassos) find a less bad way of detecting resource variables
#
without
introducing a circular dependency.
return
params
.
sparse_read
(
indices
,
name
=
name
)
except
AttributeError
:
return
gen_array_ops
.
gather_v2
(
params
,
indices
,
axis
,
name
=
name
)
return
gen_array_ops
.
gather_v2
(
params
,
indices
,
axis
,
name
=
name
)
@
tf_export
(
"gather"
,
v1
=
[])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录