Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
qq_38905368
tensorflow
提交
49760690
T
tensorflow
项目概览
qq_38905368
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
5
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
49760690
编写于
2月 24, 2016
作者:
G
Geoffrey Irving
提交者:
TensorFlower Gardener
2月 24, 2016
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix build issue with safety fix to gather and scatter
Change: 115495726
上级
746ccc84
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
205 addition
and
123 deletion
+205
-123
tensorflow/core/kernels/BUILD
tensorflow/core/kernels/BUILD
+12
-0
tensorflow/core/kernels/bounds_check.h
tensorflow/core/kernels/bounds_check.h
+38
-0
tensorflow/core/kernels/gather_op.cc
tensorflow/core/kernels/gather_op.cc
+78
-73
tensorflow/core/kernels/scatter_op.cc
tensorflow/core/kernels/scatter_op.cc
+57
-40
tensorflow/core/kernels/scatter_op.h
tensorflow/core/kernels/scatter_op.h
+5
-4
tensorflow/core/kernels/scatter_op_gpu.cu.cc
tensorflow/core/kernels/scatter_op_gpu.cu.cc
+5
-4
tensorflow/python/kernel_tests/gather_op_test.py
tensorflow/python/kernel_tests/gather_op_test.py
+8
-0
tensorflow/python/kernel_tests/scatter_ops_test.py
tensorflow/python/kernel_tests/scatter_ops_test.py
+2
-2
未找到文件。
tensorflow/core/kernels/BUILD
浏览文件 @
49760690
...
@@ -30,6 +30,15 @@ cc_library(
...
@@ -30,6 +30,15 @@ cc_library(
],
],
)
)
cc_library
(
name
=
"bounds_check"
,
hdrs
=
[
"bounds_check.h"
],
deps
=
[
"//tensorflow/core:framework"
,
"//third_party/eigen3"
,
],
)
tf_kernel_library
(
tf_kernel_library
(
name
=
"concat_lib"
,
name
=
"concat_lib"
,
srcs
=
[
"concat_lib_cpu.cc"
],
srcs
=
[
"concat_lib_cpu.cc"
],
...
@@ -226,6 +235,7 @@ tf_kernel_libraries(
...
@@ -226,6 +235,7 @@ tf_kernel_libraries(
"where_op"
,
"where_op"
,
],
],
deps
=
[
deps
=
[
":bounds_check"
,
":concat_lib"
,
":concat_lib"
,
":fill_functor"
,
":fill_functor"
,
":ops_util"
,
":ops_util"
,
...
@@ -874,6 +884,7 @@ tf_kernel_libraries(
...
@@ -874,6 +884,7 @@ tf_kernel_libraries(
],
],
deps
=
[
deps
=
[
":assign_op"
,
":assign_op"
,
":bounds_check"
,
"//tensorflow/core:framework"
,
"//tensorflow/core:framework"
,
"//tensorflow/core:lib"
,
"//tensorflow/core:lib"
,
"//tensorflow/core:state_ops_op_lib"
,
"//tensorflow/core:state_ops_op_lib"
,
...
@@ -955,6 +966,7 @@ filegroup(
...
@@ -955,6 +966,7 @@ filegroup(
"assign_op.h"
,
"assign_op.h"
,
"bias_op.cc"
,
"bias_op.cc"
,
"bias_op.h"
,
"bias_op.h"
,
"bounds_check.h"
,
"cast_op.cc"
,
"cast_op.cc"
,
"cast_op.h"
,
"cast_op.h"
,
"concat_lib.h"
,
"concat_lib.h"
,
...
...
tensorflow/core/kernels/bounds_check.h
0 → 100644
浏览文件 @
49760690
/* Copyright 2015 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_UTIL_BOUNDS_CHECK_H_
#define TENSORFLOW_UTIL_BOUNDS_CHECK_H_
#include <type_traits>
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/platform/macros.h"
namespace
tensorflow
{
// Returns true iff 0 <= index < limit, evaluated with one unsigned comparison.
//
// Both arguments are reinterpreted as the unsigned counterpart of Index. A
// negative index therefore wraps around to a very large value and fails the
// `<` test, so no separate `index >= 0` check is needed. This only holds when
// limit itself is non-negative (for signed Index); callers must guarantee that.
//
// Meant for performance-critical code paths in which the index is almost
// always in range; the result is wrapped in TF_PREDICT_TRUE accordingly.
template <class Index>
EIGEN_ALWAYS_INLINE bool FastBoundsCheck(Index index, Index limit) {
  using Unsigned = typename std::make_unsigned<Index>::type;
  const Unsigned wrapped_index = static_cast<Unsigned>(index);
  const Unsigned wrapped_limit = static_cast<Unsigned>(limit);
  return TF_PREDICT_TRUE(wrapped_index < wrapped_limit);
}
}
// namespace tensorflow
#endif // TENSORFLOW_UTIL_BOUNDS_CHECK_H_
tensorflow/core/kernels/gather_op.cc
浏览文件 @
49760690
...
@@ -18,36 +18,52 @@ limitations under the License.
...
@@ -18,36 +18,52 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"
namespace
tensorflow
{
namespace
tensorflow
{
namespace
{
namespace
{
// Returns -1 on success, or a nonnegative i such that indices[i] is out of range.
template
<
typename
T
,
typename
Index
,
int
static_slice_elems
>
template
<
typename
T
,
typename
Index
,
int
static_slice_elems
>
void
HandleCopies
(
const
Tensor
&
T
params
,
Index
HandleCopies
(
const
Tensor
&
params
,
typename
TTypes
<
Index
>::
ConstVec
&
Tindices
,
int
slice_elems
,
typename
TTypes
<
Index
>::
ConstVec
indices
,
Index
slice_elems
,
typename
TTypes
<
T
>::
Matrix
T
out
)
{
typename
TTypes
<
T
>::
Matrix
out
)
{
const
int
N
=
T
indices
.
dimension
(
0
);
const
int
N
=
indices
.
dimension
(
0
);
const
auto
&
Tparams_flat
=
T
params
.
flat_outer_dims
<
T
>
();
const
auto
&
params_flat
=
params
.
flat_outer_dims
<
T
>
();
T
*
Tout_base
=
&
Tout
(
0
,
0
);
const
Index
limit
=
params
.
dim_size
(
0
);
const
T
*
Tparams_base
=
&
Tparams_fla
t
(
0
,
0
);
T
*
out_base
=
&
ou
t
(
0
,
0
);
const
size_t
slice_bytes
=
slice_elems
*
sizeof
(
T
);
const
T
*
params_base
=
&
params_flat
(
0
,
0
);
if
(
static_slice_elems
>=
0
)
{
if
(
static_slice_elems
>=
0
)
{
// Give compiler static knowledge of the number of elements/bytes
// Give compiler static knowledge of the number of elements/bytes
CHECK_EQ
(
static_slice_elems
,
slice_elems
);
CHECK_EQ
(
static_slice_elems
,
slice_elems
);
slice_elems
=
static_slice_elems
;
slice_elems
=
static_slice_elems
;
}
}
// Compute slice_bytes here so that static knowledge is available
const
size_t
slice_bytes
=
slice_elems
*
sizeof
(
T
);
for
(
int
i
=
0
;
i
<
N
;
i
++
)
{
for
(
int
i
=
0
;
i
<
N
;
i
++
)
{
int
j
=
i
+
1
;
const
int
j
=
i
+
1
;
if
(
j
<
N
)
{
if
(
j
<
N
)
{
port
::
prefetch
<
port
::
PREFETCH_HINT_T0
>
(
&
Tparams_flat
(
Tindices
(
j
),
0
));
port
::
prefetch
<
port
::
PREFETCH_HINT_T0
>
(
&
params_flat
(
indices
(
j
),
0
));
port
::
prefetch
<
port
::
PREFETCH_HINT_T0
>
(
&
Tout
(
j
,
0
));
port
::
prefetch
<
port
::
PREFETCH_HINT_T0
>
(
&
out
(
j
,
0
));
}
// Grab the index and check its validity. An earlier version of the
// code checked it and then grabbed it from memory a second time, which
// was a security risk since it could have changed in between.
const
Index
index
=
indices
(
i
);
if
(
!
FastBoundsCheck
(
index
,
limit
))
return
i
;
// Copy using memcpy if possible, otherwise an Eigen loop
if
(
Allocator
::
is_simple
<
T
>::
value
)
{
memcpy
(
out_base
+
i
*
slice_elems
,
params_base
+
index
*
slice_elems
,
slice_bytes
);
}
else
{
out
.
template
chip
<
0
>(
i
)
=
params_flat
.
template
chip
<
0
>(
index
);
}
}
memcpy
(
Tout_base
+
i
*
slice_elems
,
Tparams_base
+
Tindices
(
i
)
*
slice_elems
,
slice_bytes
);
}
}
return
-
1
;
}
}
}
// anonymous namespace
}
// anonymous namespace
...
@@ -64,78 +80,67 @@ class GatherOp : public OpKernel {
...
@@ -64,78 +80,67 @@ class GatherOp : public OpKernel {
const
DataType
dt
=
DataTypeToEnum
<
T
>::
v
();
const
DataType
dt
=
DataTypeToEnum
<
T
>::
v
();
const
DataType
index_t
=
DataTypeToEnum
<
Index
>::
v
();
const
DataType
index_t
=
DataTypeToEnum
<
Index
>::
v
();
OP_REQUIRES_OK
(
c
,
c
->
MatchSignature
({
dt
,
index_t
},
{
dt
}));
OP_REQUIRES_OK
(
c
,
c
->
MatchSignature
({
dt
,
index_t
},
{
dt
}));
OP_REQUIRES_OK
(
c
,
c
->
GetAttr
(
"validate_indices"
,
&
validate_indices_
));
// We used to grab the validate_indices attribute here, but now we
// always validate indices since the speed difference was only 1.5%.
// TODO(irving): Remove the validate_indices attribute once we have
// support for removing attrs in a backwards compatible way.
}
}
void
Compute
(
OpKernelContext
*
c
)
override
{
void
Compute
(
OpKernelContext
*
c
)
override
{
const
Tensor
&
T
params
=
c
->
input
(
0
);
const
Tensor
&
params
=
c
->
input
(
0
);
const
Tensor
&
T
indices
=
c
->
input
(
1
);
const
Tensor
&
indices
=
c
->
input
(
1
);
OP_REQUIRES
(
OP_REQUIRES
(
c
,
TensorShapeUtils
::
IsVectorOrHigher
(
T
params
.
shape
()),
c
,
TensorShapeUtils
::
IsVectorOrHigher
(
params
.
shape
()),
errors
::
InvalidArgument
(
"params must be at least 1 dimensional"
));
errors
::
InvalidArgument
(
"params must be at least 1 dimensional"
));
const
int64
N
=
Tindices
.
NumElements
();
const
int64
first_dim_size
=
Tparams
.
dim_size
(
0
);
// Check that we have enough index space
const
int64
N_big
=
indices
.
NumElements
();
// Validate all the indices are in range
OP_REQUIRES
(
c
,
N_big
<=
std
::
numeric_limits
<
int
>::
max
(),
auto
Tindices_vec
=
Tindices
.
flat
<
Index
>
();
errors
::
InvalidArgument
(
if
(
validate_indices_
)
{
"indices has too many elements for int indexing: "
,
N_big
,
for
(
int64
i
=
0
;
i
<
N
;
i
++
)
{
" > "
,
std
::
numeric_limits
<
int
>::
max
()));
const
Index
index
=
Tindices_vec
(
i
);
const
int
N
=
indices
.
NumElements
(
);
OP_REQUIRES
(
c
,
index
>=
0
&&
index
<
first_dim_size
,
OP_REQUIRES
(
errors
::
InvalidArgument
(
c
,
params
.
dim_size
(
0
)
<=
std
::
numeric_limits
<
Index
>::
max
(),
strings
::
StrCat
(
"Index "
,
index
,
" at offset "
,
i
,
errors
::
InvalidArgument
(
"params.shape[0] too large for "
,
" in Tindices is out of range"
)));
DataTypeString
(
DataTypeToEnum
<
Index
>::
v
()),
}
" indexing: "
,
params
.
dim_size
(
0
),
" > "
,
}
std
::
numeric_limits
<
Index
>::
max
()));
// The result shape is indices.shape + params.shape[1:].
// The result shape is indices.shape + params.shape[1:].
TensorShape
result_shape
=
T
indices
.
shape
();
TensorShape
result_shape
=
indices
.
shape
();
for
(
int
i
=
1
;
i
<
T
params
.
dims
();
i
++
)
{
for
(
int
i
=
1
;
i
<
params
.
dims
();
i
++
)
{
result_shape
.
AddDim
(
T
params
.
dim_size
(
i
));
result_shape
.
AddDim
(
params
.
dim_size
(
i
));
}
}
Tensor
*
Tout
=
nullptr
;
Tensor
*
out
=
nullptr
;
OP_REQUIRES_OK
(
c
,
c
->
allocate_output
(
0
,
result_shape
,
&
Tout
));
OP_REQUIRES_OK
(
c
,
c
->
allocate_output
(
0
,
result_shape
,
&
out
));
const
auto
&
Tparams_flat
=
Tparams
.
flat_outer_dims
<
T
>
();
if
(
N
>
0
)
{
if
(
N
>
0
)
{
auto
Tindices_flat
=
Tindices
.
flat
<
Index
>
();
auto
indices_flat
=
indices
.
flat
<
Index
>
();
auto
Tout_flat
=
Tout
->
shaped
<
T
,
2
>
({
N
,
Tout
->
NumElements
()
/
N
});
auto
out_flat
=
out
->
shaped
<
T
,
2
>
({
N
,
out
->
NumElements
()
/
N
});
if
(
DataTypeCanUseMemcpy
(
DataTypeToEnum
<
T
>::
v
()))
{
const
int64
slice_size
=
out
->
NumElements
()
/
N
;
const
int64
slice_size
=
Tout
->
NumElements
()
/
N
;
Index
bad_i
;
#define SPECIALIZE(elems) \
do { \
#define CALL(elems) \
if (slice_size == elems) { \
bad_i = HandleCopies<T, Index, elems>(params, indices_flat, slice_size, \
HandleCopies<T, Index, elems>(Tparams, Tindices_flat, slice_size, \
out_flat)
Tout_flat); \
return; \
if
(
slice_size
==
10
)
} \
CALL
(
10
);
} while (0)
else
if
(
slice_size
==
20
)
CALL
(
20
);
SPECIALIZE
(
10
);
else
SPECIALIZE
(
20
);
CALL
(
-
1
);
#undef SPECIALIZE
#undef CALL
HandleCopies
<
T
,
Index
,
-
1
>
(
Tparams
,
Tindices_flat
,
slice_size
,
OP_REQUIRES
(
Tout_flat
);
c
,
bad_i
<
0
,
}
else
{
errors
::
InvalidArgument
(
for
(
int
i
=
0
;
i
<
N
;
i
++
)
{
"indices"
,
SliceDebugString
(
indices
.
shape
(),
bad_i
),
" = "
,
int
j
=
i
+
1
;
indices_flat
(
bad_i
),
" is not in [0, "
,
params
.
dim_size
(
0
),
")"
));
if
(
j
<
N
)
{
port
::
prefetch
<
port
::
PREFETCH_HINT_T0
>
(
&
Tparams_flat
(
Tindices_vec
(
j
),
0
));
port
::
prefetch
<
port
::
PREFETCH_HINT_T0
>
(
&
Tout_flat
(
j
,
0
));
}
// Copy last Ndim-1 dimensions of Tparams[Tindices[i]] to Tout[i]
Tout_flat
.
template
chip
<
0
>(
i
)
=
Tparams_flat
.
template
chip
<
0
>(
Tindices_vec
(
i
));
}
}
}
}
}
}
private:
bool
validate_indices_
;
};
};
#define REGISTER_GATHER(type, index_type) \
#define REGISTER_GATHER(type, index_type) \
...
...
tensorflow/core/kernels/scatter_op.cc
浏览文件 @
49760690
...
@@ -20,8 +20,10 @@ limitations under the License.
...
@@ -20,8 +20,10 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"
namespace
tensorflow
{
namespace
tensorflow
{
...
@@ -99,36 +101,54 @@ class ScatterUpdateOp : public OpKernel {
...
@@ -99,36 +101,54 @@ class ScatterUpdateOp : public OpKernel {
}
}
void
DoCompute
(
OpKernelContext
*
c
)
{
void
DoCompute
(
OpKernelContext
*
c
)
{
Tensor
T
params
=
c
->
mutable_input
(
0
,
use_exclusive_lock_
);
Tensor
params
=
c
->
mutable_input
(
0
,
use_exclusive_lock_
);
OP_REQUIRES
(
c
,
T
params
.
IsInitialized
(),
OP_REQUIRES
(
c
,
params
.
IsInitialized
(),
errors
::
FailedPrecondition
(
"Null ref for params"
));
errors
::
FailedPrecondition
(
"Null ref for params"
));
const
Tensor
&
T
indices
=
c
->
input
(
1
);
const
Tensor
&
indices
=
c
->
input
(
1
);
const
Tensor
&
T
updates
=
c
->
input
(
2
);
const
Tensor
&
updates
=
c
->
input
(
2
);
OP_REQUIRES
(
OP_REQUIRES
(
c
,
TensorShapeUtils
::
IsVectorOrHigher
(
T
params
.
shape
()),
c
,
TensorShapeUtils
::
IsVectorOrHigher
(
params
.
shape
()),
errors
::
InvalidArgument
(
"params must be at least 1-D, got shape "
,
errors
::
InvalidArgument
(
"params must be at least 1-D, got shape "
,
T
params
.
shape
().
DebugString
()));
params
.
shape
().
DebugString
()));
OP_REQUIRES
(
OP_REQUIRES
(
c
,
ValidShapes
(
Tparams
,
Tupdates
,
T
indices
),
c
,
ValidShapes
(
params
,
updates
,
indices
),
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"Must have updates.shape = indices.shape + params.shape[1:], got "
,
"Must have updates.shape = indices.shape + params.shape[1:], got "
,
"updates.shape "
,
Tupdates
.
shape
().
DebugString
(),
"updates.shape "
,
updates
.
shape
().
DebugString
(),
", indices.shape "
,
", indices.shape "
,
Tindices
.
shape
().
DebugString
(),
indices
.
shape
().
DebugString
(),
", params.shape "
,
", params.shape "
,
Tparams
.
shape
().
DebugString
()));
params
.
shape
().
DebugString
()));
// Check that we have enough index space
const
int64
N_big
=
indices
.
NumElements
();
OP_REQUIRES
(
c
,
N_big
<=
std
::
numeric_limits
<
Index
>::
max
(),
errors
::
InvalidArgument
(
"indices has too many elements for "
,
DataTypeString
(
DataTypeToEnum
<
Index
>::
v
()),
" indexing: "
,
N_big
,
" > "
,
std
::
numeric_limits
<
Index
>::
max
()));
const
Index
N
=
indices
.
NumElements
();
OP_REQUIRES
(
c
,
params
.
dim_size
(
0
)
<=
std
::
numeric_limits
<
Index
>::
max
(),
errors
::
InvalidArgument
(
"params.shape[0] too large for "
,
DataTypeString
(
DataTypeToEnum
<
Index
>::
v
()),
" indexing: "
,
params
.
dim_size
(
0
),
" > "
,
std
::
numeric_limits
<
Index
>::
max
()));
// We always return the input ref.
// We always return the input ref.
c
->
forward_ref_input_to_ref_output
(
0
,
0
);
c
->
forward_ref_input_to_ref_output
(
0
,
0
);
const
Index
N
=
Tindices
.
NumElements
();
if
(
N
>
0
)
{
if
(
N
>
0
)
{
auto
Tindices_flat
=
Tindices
.
flat
<
Index
>
();
auto
indices_flat
=
indices
.
flat
<
Index
>
();
auto
Tparams_flat
=
Tparams
.
flat_outer_dims
<
T
>
();
auto
params_flat
=
params
.
flat_outer_dims
<
T
>
();
auto
Tupdates_flat
=
auto
updates_flat
=
updates
.
shaped
<
T
,
2
>
({
N
,
updates
.
NumElements
()
/
N
});
Tupdates
.
shaped
<
T
,
2
>
({
N
,
Tupdates
.
NumElements
()
/
N
});
functor
::
ScatterFunctor
<
Device
,
T
,
Index
,
op
>
functor
;
functor
::
ScatterFunctor
<
Device
,
T
,
Index
,
op
>
functor
;
functor
(
c
,
c
->
template
eigen_device
<
Device
>(),
const
Index
bad_i
=
functor
(
c
,
c
->
template
eigen_device
<
Device
>(),
Tparams_flat
,
Tupdates_flat
,
Tindices_flat
);
params_flat
,
updates_flat
,
indices_flat
);
OP_REQUIRES
(
c
,
bad_i
<
0
,
errors
::
InvalidArgument
(
"indices"
,
SliceDebugString
(
indices
.
shape
(),
bad_i
),
" = "
,
indices_flat
(
bad_i
),
" is not in [0, "
,
params
.
dim_size
(
0
),
")"
));
}
}
}
}
};
};
...
@@ -137,26 +157,23 @@ namespace functor {
...
@@ -137,26 +157,23 @@ namespace functor {
// Implementation of update functor for CPU.
// Implementation of update functor for CPU.
template
<
typename
T
,
typename
Index
,
scatter_op
::
UpdateOp
op
>
template
<
typename
T
,
typename
Index
,
scatter_op
::
UpdateOp
op
>
struct
ScatterFunctor
<
CPUDevice
,
T
,
Index
,
op
>
{
struct
ScatterFunctor
<
CPUDevice
,
T
,
Index
,
op
>
{
void
operator
()(
OpKernelContext
*
c
,
const
CPUDevice
&
d
,
Index
operator
()(
OpKernelContext
*
c
,
const
CPUDevice
&
d
,
typename
TTypes
<
T
>::
Matrix
params
,
typename
TTypes
<
T
>::
Matrix
params
,
typename
TTypes
<
T
>::
ConstMatrix
updates
,
typename
TTypes
<
T
>::
ConstMatrix
updates
,
typename
TTypes
<
Index
>::
ConstFlat
indices
)
{
typename
TTypes
<
Index
>::
ConstFlat
indices
)
{
Index
N
=
indices
.
size
();
const
Index
N
=
indices
.
size
();
// Validate all the indices are in range
const
Index
limit
=
params
.
dimension
(
0
);
Index
first_dim_size
=
params
.
dimension
(
0
);
for
(
Index
i
=
0
;
i
<
N
;
i
++
)
{
for
(
Index
i
=
0
;
i
<
N
;
i
++
)
{
// Grab the index and check its validity. An earlier version of the
// code checked it and then grabbed it from memory a second time, which
// was a security risk since it could have changed in between.
const
Index
index
=
indices
(
i
);
const
Index
index
=
indices
(
i
);
OP_REQUIRES
(
c
,
index
>=
0
&&
index
<
first_dim_size
,
if
(
!
FastBoundsCheck
(
index
,
limit
))
return
i
;
errors
::
InvalidArgument
(
// Copy last Ndim-1 dimensions of updates[i] to params[index]
strings
::
StrCat
(
"Index "
,
index
,
" at offset "
,
i
,
Assign
<
op
>::
Run
(
params
.
template
chip
<
0
>(
index
),
" in indices is out of range"
)));
}
for
(
Index
i
=
0
;
i
<
N
;
i
++
)
{
// Copy last Ndim-1 dimensions of Tupdates[i] to
// Tparams[Tindices[i]]
Assign
<
op
>::
Run
(
params
.
template
chip
<
0
>(
indices
(
i
)),
updates
.
template
chip
<
0
>(
i
));
updates
.
template
chip
<
0
>(
i
));
}
}
return
-
1
;
}
}
};
};
}
// namespace functor
}
// namespace functor
...
@@ -220,13 +237,13 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_UPDATE_GPU);
...
@@ -220,13 +237,13 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_UPDATE_GPU);
// Forward declarations of the functor specializations for GPU.
// Forward declarations of the functor specializations for GPU.
namespace
functor
{
namespace
functor
{
#define DECLARE_GPU_SPECS_OP(T, Index, op)
\
#define DECLARE_GPU_SPECS_OP(T, Index, op) \
template <>
\
template <> \
void ScatterFunctor<GPUDevice, T, Index, op>::operator()(
\
Index ScatterFunctor<GPUDevice, T, Index, op>::operator()(
\
OpKernelContext* c, const GPUDevice& d,
\
OpKernelContext* c, const GPUDevice& d, \
typename TTypes<T>::Matrix params,
\
typename TTypes<T>::Matrix params, \
typename TTypes<T>::ConstMatrix updates,
\
typename TTypes<T>::ConstMatrix updates, \
typename TTypes<Index>::ConstFlat indices);
\
typename TTypes<Index>::ConstFlat indices); \
extern template struct ScatterFunctor<GPUDevice, T, Index, op>;
extern template struct ScatterFunctor<GPUDevice, T, Index, op>;
#define DECLARE_GPU_SPECS_INDEX(T, Index) \
#define DECLARE_GPU_SPECS_INDEX(T, Index) \
...
...
tensorflow/core/kernels/scatter_op.h
浏览文件 @
49760690
...
@@ -36,10 +36,11 @@ namespace functor {
...
@@ -36,10 +36,11 @@ namespace functor {
// Functor used by ScatterOp to do the computations.
// Functor used by ScatterOp to do the computations.
template
<
typename
Device
,
typename
T
,
typename
Index
,
scatter_op
::
UpdateOp
op
>
template
<
typename
Device
,
typename
T
,
typename
Index
,
scatter_op
::
UpdateOp
op
>
struct
ScatterFunctor
{
struct
ScatterFunctor
{
void
operator
()(
OpKernelContext
*
c
,
const
Device
&
d
,
// Returns -1 on success or a nonnegative i s.t. indices[i] is a bad index.
typename
TTypes
<
T
>::
Matrix
params
,
Index
operator
()(
OpKernelContext
*
c
,
const
Device
&
d
,
typename
TTypes
<
T
>::
ConstMatrix
updates
,
typename
TTypes
<
T
>::
Matrix
params
,
typename
TTypes
<
Index
>::
ConstFlat
indices
);
typename
TTypes
<
T
>::
ConstMatrix
updates
,
typename
TTypes
<
Index
>::
ConstFlat
indices
);
};
};
}
// namespace functor
}
// namespace functor
...
...
tensorflow/core/kernels/scatter_op_gpu.cu.cc
浏览文件 @
49760690
...
@@ -62,10 +62,10 @@ namespace functor {
...
@@ -62,10 +62,10 @@ namespace functor {
// Specialization for a GPU device.
// Specialization for a GPU device.
template
<
typename
T
,
typename
Index
,
scatter_op
::
UpdateOp
op
>
template
<
typename
T
,
typename
Index
,
scatter_op
::
UpdateOp
op
>
struct
ScatterFunctor
<
GPUDevice
,
T
,
Index
,
op
>
{
struct
ScatterFunctor
<
GPUDevice
,
T
,
Index
,
op
>
{
void
operator
()(
OpKernelContext
*
c
,
const
GPUDevice
&
d
,
Index
operator
()(
OpKernelContext
*
c
,
const
GPUDevice
&
d
,
typename
TTypes
<
T
>::
Matrix
params
,
typename
TTypes
<
T
>::
Matrix
params
,
typename
TTypes
<
T
>::
ConstMatrix
updates
,
typename
TTypes
<
T
>::
ConstMatrix
updates
,
typename
TTypes
<
Index
>::
ConstFlat
indices
)
{
typename
TTypes
<
Index
>::
ConstFlat
indices
)
{
// TODO: Implement indices range check. The hardest part is with returning
// TODO: Implement indices range check. The hardest part is with returning
// a value after the range check, as we do not want to do device to host
// a value after the range check, as we do not want to do device to host
// memcpy during a stream.
// memcpy during a stream.
...
@@ -77,6 +77,7 @@ struct ScatterFunctor<GPUDevice, T, Index, op> {
...
@@ -77,6 +77,7 @@ struct ScatterFunctor<GPUDevice, T, Index, op> {
<<<
config
.
block_count
,
config
.
thread_per_block
,
0
,
d
.
stream
()
>>>
(
<<<
config
.
block_count
,
config
.
thread_per_block
,
0
,
d
.
stream
()
>>>
(
params
.
data
(),
updates
.
data
(),
indices
.
data
(),
params
.
data
(),
updates
.
data
(),
indices
.
data
(),
first_dim_size
,
updates_size
,
indices_size
);
first_dim_size
,
updates_size
,
indices_size
);
return
-
1
;
}
}
};
};
...
...
tensorflow/python/kernel_tests/gather_op_test.py
浏览文件 @
49760690
...
@@ -83,6 +83,14 @@ class GatherTest(tf.test.TestCase):
...
@@ -83,6 +83,14 @@ class GatherTest(tf.test.TestCase):
gather_t
=
tf
.
gather
(
params
,
indices
)
gather_t
=
tf
.
gather
(
params
,
indices
)
self
.
assertEqual
(
None
,
gather_t
.
get_shape
())
self
.
assertEqual
(
None
,
gather_t
.
get_shape
())
def testBadIndices(self):
  """Gathering with an out-of-range index raises an op error naming it."""
  with self.test_session():
    # params holds 3 elements, so index 7 falls outside [0, 3).
    out_of_range_gather = tf.gather([0, 1, 2], [[7]])
    expected_error = r"indices\[0,0\] = 7 is not in \[0, 3\)"
    with self.assertRaisesOpError(expected_error):
      out_of_range_gather.eval()
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
tf
.
test
.
main
()
tensorflow/python/kernel_tests/scatter_ops_test.py
浏览文件 @
49760690
...
@@ -128,11 +128,11 @@ class ScatterTest(tf.test.TestCase):
...
@@ -128,11 +128,11 @@ class ScatterTest(tf.test.TestCase):
# Test some out of range errors.
# Test some out of range errors.
indices
=
np
.
array
([
-
1
,
0
,
5
])
indices
=
np
.
array
([
-
1
,
0
,
5
])
with
self
.
assertRaisesOpError
(
'indices is out of range
'
):
with
self
.
assertRaisesOpError
(
r
'indices\[0\] = -1 is not in \[0, 6\)
'
):
op
(
ref
,
indices
,
updates
).
eval
()
op
(
ref
,
indices
,
updates
).
eval
()
indices
=
np
.
array
([
2
,
0
,
6
])
indices
=
np
.
array
([
2
,
0
,
6
])
with
self
.
assertRaisesOpError
(
'indices is out of range
'
):
with
self
.
assertRaisesOpError
(
r
'indices\[2\] = 6 is not in \[0, 6\)
'
):
op
(
ref
,
indices
,
updates
).
eval
()
op
(
ref
,
indices
,
updates
).
eval
()
# TODO(fpmc): Re-enable this test when gpu_pip test actually runs on a GPU.
# TODO(fpmc): Re-enable this test when gpu_pip test actually runs on a GPU.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录