Commit 49760690
Authored Feb 24, 2016 by Geoffrey Irving
Committed by TensorFlower Gardener on Feb 24, 2016

Fix build issue with safety fix to gather and scatter

Change: 115495726
Parent: 746ccc84
Showing 8 changed files with 205 additions and 123 deletions (+205, -123):
tensorflow/core/kernels/BUILD                        +12   -0
tensorflow/core/kernels/bounds_check.h               +38   -0
tensorflow/core/kernels/gather_op.cc                 +78  -73
tensorflow/core/kernels/scatter_op.cc                +57  -40
tensorflow/core/kernels/scatter_op.h                  +5   -4
tensorflow/core/kernels/scatter_op_gpu.cu.cc          +5   -4
tensorflow/python/kernel_tests/gather_op_test.py      +8   -0
tensorflow/python/kernel_tests/scatter_ops_test.py    +2   -2
tensorflow/core/kernels/BUILD

@@ -30,6 +30,15 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "bounds_check",
+    hdrs = ["bounds_check.h"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//third_party/eigen3",
+    ],
+)
+
 tf_kernel_library(
     name = "concat_lib",
     srcs = ["concat_lib_cpu.cc"],
...
@@ -226,6 +235,7 @@ tf_kernel_libraries(
         "where_op",
     ],
     deps = [
+        ":bounds_check",
         ":concat_lib",
         ":fill_functor",
         ":ops_util",
...
@@ -874,6 +884,7 @@ tf_kernel_libraries(
     ],
     deps = [
         ":assign_op",
+        ":bounds_check",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:state_ops_op_lib",
...
@@ -955,6 +966,7 @@ filegroup(
         "assign_op.h",
         "bias_op.cc",
         "bias_op.h",
+        "bounds_check.h",
         "cast_op.cc",
         "cast_op.h",
         "concat_lib.h",
...
tensorflow/core/kernels/bounds_check.h (new file, mode 100644)

/* Copyright 2015 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_UTIL_BOUNDS_CHECK_H_
#define TENSORFLOW_UTIL_BOUNDS_CHECK_H_

#include <type_traits>

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

// Check that 0 <= index < limit using a single comparison, assuming
// that 0 <= limit if Index is signed.  Intended for use in performance
// critical contexts where 0 <= index < limit is almost always true.
template <class Index>
EIGEN_ALWAYS_INLINE bool FastBoundsCheck(Index index, Index limit) {
  typedef typename std::make_unsigned<Index>::type UIndex;
  return TF_PREDICT_TRUE(static_cast<UIndex>(index) <
                         static_cast<UIndex>(limit));
}

}  // namespace tensorflow

#endif  // TENSORFLOW_UTIL_BOUNDS_CHECK_H_
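
The single comparison works because casting to the unsigned counterpart of Index folds both range tests into one: a negative index wraps around to a very large unsigned value, which can never be below a nonnegative limit. A minimal standalone sketch of the same trick, with EIGEN_ALWAYS_INLINE and TF_PREDICT_TRUE dropped since they only affect inlining and branch-prediction hints:

#include <cassert>
#include <cstdint>
#include <type_traits>

// Plain-C++ rendition of FastBoundsCheck: one unsigned comparison covers
// both "index >= 0" and "index < limit", assuming 0 <= limit.
template <class Index>
bool BoundsCheckSketch(Index index, Index limit) {
  typedef typename std::make_unsigned<Index>::type UIndex;
  return static_cast<UIndex>(index) < static_cast<UIndex>(limit);
}

int main() {
  assert(BoundsCheckSketch<int32_t>(0, 6));    // in range
  assert(BoundsCheckSketch<int32_t>(5, 6));    // in range
  assert(!BoundsCheckSketch<int32_t>(6, 6));   // index == limit: out of range
  assert(!BoundsCheckSketch<int32_t>(-1, 6));  // -1 wraps to 0xFFFFFFFF >= 6
  return 0;
}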
tensorflow/core/kernels/gather_op.cc

@@ -18,36 +18,52 @@ limitations under the License.
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/bounds_check.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/mem.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/util.h"
 
 namespace tensorflow {
 
 namespace {
 
+// Returns -1 on success or a nonnegative i s.t. indices[i] is bad.
 template <typename T, typename Index, int static_slice_elems>
-void HandleCopies(const Tensor& Tparams,
-                  typename TTypes<Index>::ConstVec& Tindices, int slice_elems,
-                  typename TTypes<T>::Matrix Tout) {
-  const int N = Tindices.dimension(0);
-  const auto& Tparams_flat = Tparams.flat_outer_dims<T>();
-  T* Tout_base = &Tout(0, 0);
-  const T* Tparams_base = &Tparams_flat(0, 0);
-  const size_t slice_bytes = slice_elems * sizeof(T);
+Index HandleCopies(const Tensor& params,
+                   typename TTypes<Index>::ConstVec indices, Index slice_elems,
+                   typename TTypes<T>::Matrix out) {
+  const int N = indices.dimension(0);
+  const auto& params_flat = params.flat_outer_dims<T>();
+  const Index limit = params.dim_size(0);
+  T* out_base = &out(0, 0);
+  const T* params_base = &params_flat(0, 0);
   if (static_slice_elems >= 0) {
     // Give compiler static knowledge of the number of elements/bytes
     CHECK_EQ(static_slice_elems, slice_elems);
     slice_elems = static_slice_elems;
   }
+  // Compute slice_bytes here so that static knowledge is available
+  const size_t slice_bytes = slice_elems * sizeof(T);
   for (int i = 0; i < N; i++) {
-    int j = i + 1;
+    const int j = i + 1;
     if (j < N) {
-      port::prefetch<port::PREFETCH_HINT_T0>(&Tparams_flat(Tindices(j), 0));
-      port::prefetch<port::PREFETCH_HINT_T0>(&Tout(j, 0));
+      port::prefetch<port::PREFETCH_HINT_T0>(&params_flat(indices(j), 0));
+      port::prefetch<port::PREFETCH_HINT_T0>(&out(j, 0));
     }
-    memcpy(Tout_base + i * slice_elems,
-           Tparams_base + Tindices(i) * slice_elems, slice_bytes);
+    // Grab the index and check its validity.  An earlier version of the
+    // code checked it and then grabbed it from memory a second time, which
+    // was a security risk since it could have changed in between.
+    const Index index = indices(i);
+    if (!FastBoundsCheck(index, limit)) return i;
+    // Copy using memcpy if possible, otherwise an Eigen loop
+    if (Allocator::is_simple<T>::value) {
+      memcpy(out_base + i * slice_elems, params_base + index * slice_elems,
+             slice_bytes);
+    } else {
+      out.template chip<0>(i) = params_flat.template chip<0>(index);
+    }
   }
+  return -1;
 }
 
 }  // anonymous namespace
...
@@ -64,78 +80,67 @@ class GatherOp : public OpKernel {
     const DataType dt = DataTypeToEnum<T>::v();
     const DataType index_t = DataTypeToEnum<Index>::v();
     OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t}, {dt}));
-    OP_REQUIRES_OK(c, c->GetAttr("validate_indices", &validate_indices_));
+    // We used to grab the validate_indices attribute here, but now we
+    // always validate indices since the speed difference was only 1.5%.
+    // TODO(irving): Remove the validate_indices attribute once we have
+    // support for removing attrs in a backwards compatible way.
   }
 
   void Compute(OpKernelContext* c) override {
-    const Tensor& Tparams = c->input(0);
-    const Tensor& Tindices = c->input(1);
-    OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(Tparams.shape()),
-                errors::InvalidArgument("params must be at least 1 dimensional"));
-    const int64 N = Tindices.NumElements();
-    const int64 first_dim_size = Tparams.dim_size(0);
-
-    // Validate all the indices are in range
-    auto Tindices_vec = Tindices.flat<Index>();
-    if (validate_indices_) {
-      for (int64 i = 0; i < N; i++) {
-        const Index index = Tindices_vec(i);
-        OP_REQUIRES(c, index >= 0 && index < first_dim_size,
-                    errors::InvalidArgument(
-                        strings::StrCat("Index ", index, " at offset ", i,
-                                        " in Tindices is out of range")));
-      }
-    }
+    const Tensor& params = c->input(0);
+    const Tensor& indices = c->input(1);
+    OP_REQUIRES(
+        c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
+        errors::InvalidArgument("params must be at least 1 dimensional"));
+
+    // Check that we have enough index space
+    const int64 N_big = indices.NumElements();
+    OP_REQUIRES(c, N_big <= std::numeric_limits<int>::max(),
+                errors::InvalidArgument(
+                    "indices has too many elements for int indexing: ", N_big,
+                    " > ", std::numeric_limits<int>::max()));
+    const int N = indices.NumElements();
+    OP_REQUIRES(
+        c, params.dim_size(0) <= std::numeric_limits<Index>::max(),
+        errors::InvalidArgument("params.shape[0] too large for ",
+                                DataTypeString(DataTypeToEnum<Index>::v()),
+                                " indexing: ", params.dim_size(0), " > ",
+                                std::numeric_limits<Index>::max()));
 
     // The result shape is indices.shape + params.shape[1:].
-    TensorShape result_shape = Tindices.shape();
-    for (int i = 1; i < Tparams.dims(); i++) {
-      result_shape.AddDim(Tparams.dim_size(i));
+    TensorShape result_shape = indices.shape();
+    for (int i = 1; i < params.dims(); i++) {
+      result_shape.AddDim(params.dim_size(i));
     }
 
-    Tensor* Tout = nullptr;
-    OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &Tout));
-    const auto& Tparams_flat = Tparams.flat_outer_dims<T>();
+    Tensor* out = nullptr;
+    OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
     if (N > 0) {
-      auto Tindices_flat = Tindices.flat<Index>();
-      auto Tout_flat = Tout->shaped<T, 2>({N, Tout->NumElements() / N});
-      if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
-        const int64 slice_size = Tout->NumElements() / N;
-#define SPECIALIZE(elems)                                               \
-  do {                                                                  \
-    if (slice_size == elems) {                                          \
-      HandleCopies<T, Index, elems>(Tparams, Tindices_flat, slice_size, \
-                                    Tout_flat);                         \
-      return;                                                           \
-    }                                                                   \
-  } while (0)
-        SPECIALIZE(10);
-        SPECIALIZE(20);
-#undef SPECIALIZE
-        HandleCopies<T, Index, -1>(Tparams, Tindices_flat, slice_size,
-                                   Tout_flat);
-      } else {
-        for (int i = 0; i < N; i++) {
-          int j = i + 1;
-          if (j < N) {
-            port::prefetch<port::PREFETCH_HINT_T0>(
-                &Tparams_flat(Tindices_vec(j), 0));
-            port::prefetch<port::PREFETCH_HINT_T0>(&Tout_flat(j, 0));
-          }
-          // Copy last Ndim-1 dimensions of Tparams[Tindices[i]] to Tout[i]
-          Tout_flat.template chip<0>(i) =
-              Tparams_flat.template chip<0>(Tindices_vec(i));
-        }
-      }
+      auto indices_flat = indices.flat<Index>();
+      auto out_flat = out->shaped<T, 2>({N, out->NumElements() / N});
+      const int64 slice_size = out->NumElements() / N;
+      Index bad_i;
+#define CALL(elems)                                                       \
+  bad_i = HandleCopies<T, Index, elems>(params, indices_flat, slice_size, \
+                                        out_flat)
+      if (slice_size == 10)
+        CALL(10);
+      else if (slice_size == 20)
+        CALL(20);
+      else
+        CALL(-1);
+#undef CALL
+      OP_REQUIRES(
+          c, bad_i < 0,
+          errors::InvalidArgument(
+              "indices", SliceDebugString(indices.shape(), bad_i), " = ",
+              indices_flat(bad_i), " is not in [0, ", params.dim_size(0),
+              ")"));
    }
   }
-
- private:
-  bool validate_indices_;
 };
 
 #define REGISTER_GATHER(type, index_type) \
...
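
The contract the new HandleCopies establishes — validate each index exactly once, copy the slice, and report the first bad position instead of failing a CHECK — can be shown with a dependency-free sketch; std::vector stands in for Tensor, and the names here are illustrative, not TensorFlow API:

#include <cstring>
#include <vector>

// Gather rows params[indices[i]] into out[i].  Returns -1 on success, or
// the first i such that indices[i] is out of [0, num_rows), mirroring the
// HandleCopies contract above.
template <typename T, typename Index>
Index GatherRowsSketch(const std::vector<T>& params, Index num_rows,
                       Index slice_elems, const std::vector<Index>& indices,
                       std::vector<T>* out) {
  const Index N = static_cast<Index>(indices.size());
  out->assign(N * slice_elems, T());
  for (Index i = 0; i < N; i++) {
    const Index index = indices[i];  // read once, then check and use
    if (index < 0 || index >= num_rows) return i;
    std::memcpy(out->data() + i * slice_elems,
                params.data() + index * slice_elems, slice_elems * sizeof(T));
  }
  return -1;
}

The caller turns a nonnegative return value into an InvalidArgument error naming the offending position, exactly as GatherOp::Compute does with bad_i above.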
tensorflow/core/kernels/scatter_op.cc

@@ -20,8 +20,10 @@ limitations under the License.
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/bounds_check.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/util.h"
 
 namespace tensorflow {
...
@@ -99,36 +101,54 @@ class ScatterUpdateOp : public OpKernel {
   }
 
   void DoCompute(OpKernelContext* c) {
-    Tensor Tparams = c->mutable_input(0, use_exclusive_lock_);
-    OP_REQUIRES(c, Tparams.IsInitialized(),
+    Tensor params = c->mutable_input(0, use_exclusive_lock_);
+    OP_REQUIRES(c, params.IsInitialized(),
                 errors::FailedPrecondition("Null ref for params"));
-    const Tensor& Tindices = c->input(1);
-    const Tensor& Tupdates = c->input(2);
+    const Tensor& indices = c->input(1);
+    const Tensor& updates = c->input(2);
     OP_REQUIRES(
-        c, TensorShapeUtils::IsVectorOrHigher(Tparams.shape()),
+        c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
         errors::InvalidArgument("params must be at least 1-D, got shape ",
-                                Tparams.shape().DebugString()));
+                                params.shape().DebugString()));
     OP_REQUIRES(
-        c, ValidShapes(Tparams, Tupdates, Tindices),
+        c, ValidShapes(params, updates, indices),
         errors::InvalidArgument(
             "Must have updates.shape = indices.shape + params.shape[1:], got ",
-            "updates.shape ", Tupdates.shape().DebugString(),
-            ", indices.shape ", Tindices.shape().DebugString(),
-            ", params.shape ", Tparams.shape().DebugString()));
+            "updates.shape ", updates.shape().DebugString(),
+            ", indices.shape ", indices.shape().DebugString(),
+            ", params.shape ", params.shape().DebugString()));
+
+    // Check that we have enough index space
+    const int64 N_big = indices.NumElements();
+    OP_REQUIRES(c, N_big <= std::numeric_limits<Index>::max(),
+                errors::InvalidArgument(
+                    "indices has too many elements for ",
+                    DataTypeString(DataTypeToEnum<Index>::v()), " indexing: ",
+                    N_big, " > ", std::numeric_limits<Index>::max()));
+    const Index N = indices.NumElements();
+    OP_REQUIRES(
+        c, params.dim_size(0) <= std::numeric_limits<Index>::max(),
+        errors::InvalidArgument("params.shape[0] too large for ",
+                                DataTypeString(DataTypeToEnum<Index>::v()),
+                                " indexing: ", params.dim_size(0), " > ",
+                                std::numeric_limits<Index>::max()));
 
     // We always return the input ref.
     c->forward_ref_input_to_ref_output(0, 0);
 
-    const Index N = Tindices.NumElements();
     if (N > 0) {
-      auto Tindices_flat = Tindices.flat<Index>();
-      auto Tparams_flat = Tparams.flat_outer_dims<T>();
-      auto Tupdates_flat =
-          Tupdates.shaped<T, 2>({N, Tupdates.NumElements() / N});
+      auto indices_flat = indices.flat<Index>();
+      auto params_flat = params.flat_outer_dims<T>();
+      auto updates_flat = updates.shaped<T, 2>({N, updates.NumElements() / N});
 
       functor::ScatterFunctor<Device, T, Index, op> functor;
-      functor(c, c->template eigen_device<Device>(), Tparams_flat,
-              Tupdates_flat, Tindices_flat);
+      const Index bad_i = functor(c, c->template eigen_device<Device>(),
+                                  params_flat, updates_flat, indices_flat);
+      OP_REQUIRES(
+          c, bad_i < 0,
+          errors::InvalidArgument(
+              "indices", SliceDebugString(indices.shape(), bad_i), " = ",
+              indices_flat(bad_i), " is not in [0, ", params.dim_size(0),
+              ")"));
     }
   }
 };
...
@@ -137,26 +157,23 @@ namespace functor {
 // Implementation of update functor for CPU.
 template <typename T, typename Index, scatter_op::UpdateOp op>
 struct ScatterFunctor<CPUDevice, T, Index, op> {
-  void operator()(OpKernelContext* c, const CPUDevice& d,
-                  typename TTypes<T>::Matrix params,
-                  typename TTypes<T>::ConstMatrix updates,
-                  typename TTypes<Index>::ConstFlat indices) {
-    Index N = indices.size();
-    // Validate all the indices are in range
-    Index first_dim_size = params.dimension(0);
-    for (Index i = 0; i < N; i++) {
-      const Index index = indices(i);
-      OP_REQUIRES(c, index >= 0 && index < first_dim_size,
-                  errors::InvalidArgument(
-                      strings::StrCat("Index ", index, " at offset ", i,
-                                      " in indices is out of range")));
-    }
-    for (Index i = 0; i < N; i++) {
-      // Copy last Ndim-1 dimensions of Tupdates[i] to
-      // Tparams[Tindices[i]]
-      Assign<op>::Run(params.template chip<0>(indices(i)),
-                      updates.template chip<0>(i));
+  Index operator()(OpKernelContext* c, const CPUDevice& d,
+                   typename TTypes<T>::Matrix params,
+                   typename TTypes<T>::ConstMatrix updates,
+                   typename TTypes<Index>::ConstFlat indices) {
+    const Index N = indices.size();
+    const Index limit = params.dimension(0);
+    for (Index i = 0; i < N; i++) {
+      // Grab the index and check its validity.  An earlier version of the
+      // code checked it and then grabbed it from memory a second time, which
+      // was a security risk since it could have changed in between.
+      const Index index = indices(i);
+      if (!FastBoundsCheck(index, limit)) return i;
+      // Copy last Ndim-1 dimensions of updates[i] to params[index]
+      Assign<op>::Run(params.template chip<0>(index),
+                      updates.template chip<0>(i));
     }
+    return -1;
   }
 };
 }  // namespace functor
...
@@ -220,13 +237,13 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_UPDATE_GPU);
 // Forward declarations of the functor specializations for GPU.
 namespace functor {
-#define DECLARE_GPU_SPECS_OP(T, Index, op)                    \
-  template <>                                                 \
-  void ScatterFunctor<GPUDevice, T, Index, op>::operator()(   \
-      OpKernelContext* c, const GPUDevice& d,                 \
-      typename TTypes<T>::Matrix params,                      \
-      typename TTypes<T>::ConstMatrix updates,                \
-      typename TTypes<Index>::ConstFlat indices);             \
+#define DECLARE_GPU_SPECS_OP(T, Index, op)                    \
+  template <>                                                 \
+  Index ScatterFunctor<GPUDevice, T, Index, op>::operator()(  \
+      OpKernelContext* c, const GPUDevice& d,                 \
+      typename TTypes<T>::Matrix params,                      \
+      typename TTypes<T>::ConstMatrix updates,                \
+      typename TTypes<Index>::ConstFlat indices);             \
   extern template struct ScatterFunctor<GPUDevice, T, Index, op>;
 
 #define DECLARE_GPU_SPECS_INDEX(T, Index) \
...
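
The comment about the earlier security risk describes a classic time-of-check-to-time-of-use bug: the old code validated indices(i) in one loop and re-read it from the tensor in a second loop, so a concurrent writer to the indices buffer could slip an out-of-range value between the check and the use. A hypothetical sketch of the fixed loop shape (illustrative names, not TensorFlow API):

#include <vector>

// Scatter updates[i] into params[indices[i]].  The index is copied into a
// local exactly once, so the value that is bounds-checked is the same value
// used for addressing, even if the indices buffer changes underneath us.
// Returns -1 on success or the first bad position.
template <typename T, typename Index>
Index ScatterRowsSketch(std::vector<T>* params, Index num_rows,
                        Index slice_elems, const std::vector<Index>& indices,
                        const std::vector<T>& updates) {
  const Index N = static_cast<Index>(indices.size());
  for (Index i = 0; i < N; i++) {
    const Index index = indices[i];  // single read; check and use this copy
    if (index < 0 || index >= num_rows) return i;
    for (Index k = 0; k < slice_elems; k++)
      (*params)[index * slice_elems + k] = updates[i * slice_elems + k];
  }
  return -1;
}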
tensorflow/core/kernels/scatter_op.h

@@ -36,10 +36,11 @@ namespace functor {
 // Functor used by ScatterOp to do the computations.
 template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
 struct ScatterFunctor {
-  void operator()(OpKernelContext* c, const Device& d,
-                  typename TTypes<T>::Matrix params,
-                  typename TTypes<T>::ConstMatrix updates,
-                  typename TTypes<Index>::ConstFlat indices);
+  // Returns -1 on success or a nonnegative i s.t. indices[i] is a bad index.
+  Index operator()(OpKernelContext* c, const Device& d,
+                   typename TTypes<T>::Matrix params,
+                   typename TTypes<T>::ConstMatrix updates,
+                   typename TTypes<Index>::ConstFlat indices);
 };
 }  // namespace functor
...
tensorflow/core/kernels/scatter_op_gpu.cu.cc

@@ -62,10 +62,10 @@ namespace functor {
 // Specialization for a GPU device.
 template <typename T, typename Index, scatter_op::UpdateOp op>
 struct ScatterFunctor<GPUDevice, T, Index, op> {
-  void operator()(OpKernelContext* c, const GPUDevice& d,
-                  typename TTypes<T>::Matrix params,
-                  typename TTypes<T>::ConstMatrix updates,
-                  typename TTypes<Index>::ConstFlat indices) {
+  Index operator()(OpKernelContext* c, const GPUDevice& d,
+                   typename TTypes<T>::Matrix params,
+                   typename TTypes<T>::ConstMatrix updates,
+                   typename TTypes<Index>::ConstFlat indices) {
+    // TODO: Implement indices range check. The hardest part is with returning
+    // a value after the range check, as we do not want to do device to host
+    // memcpy during a stream.
...
@@ -77,6 +77,7 @@ struct ScatterFunctor<GPUDevice, T, Index, op> {
         <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
             params.data(), updates.data(), indices.data(), first_dim_size,
             updates_size, indices_size);
+    return -1;
   }
 };
...
tensorflow/python/kernel_tests/gather_op_test.py

@@ -83,6 +83,14 @@ class GatherTest(tf.test.TestCase):
       gather_t = tf.gather(params, indices)
       self.assertEqual(None, gather_t.get_shape())
 
+  def testBadIndices(self):
+    with self.test_session():
+      params = [0, 1, 2]
+      indices = [[7]]
+      gather = tf.gather(params, indices)
+      with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"):
+        gather.eval()
+
 
 if __name__ == "__main__":
   tf.test.main()
tensorflow/python/kernel_tests/scatter_ops_test.py

@@ -128,11 +128,11 @@ class ScatterTest(tf.test.TestCase):
       # Test some out of range errors.
       indices = np.array([-1, 0, 5])
-      with self.assertRaisesOpError('indices is out of range'):
+      with self.assertRaisesOpError(r'indices\[0\] = -1 is not in \[0, 6\)'):
         op(ref, indices, updates).eval()
 
       indices = np.array([2, 0, 6])
-      with self.assertRaisesOpError('indices is out of range'):
+      with self.assertRaisesOpError(r'indices\[2\] = 6 is not in \[0, 6\)'):
         op(ref, indices, updates).eval()
 
 # TODO(fpmc): Re-enable this test when gpu_pip test actually runs on a GPU.
...