Commit 49760690 authored by Geoffrey Irving, committed by TensorFlower Gardener

Fix build issue with safety fix to gather and scatter

Change: 115495726
Parent 746ccc84
@@ -30,6 +30,15 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "bounds_check",
+    hdrs = ["bounds_check.h"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//third_party/eigen3",
+    ],
+)
+
 tf_kernel_library(
     name = "concat_lib",
     srcs = ["concat_lib_cpu.cc"],
@@ -226,6 +235,7 @@ tf_kernel_libraries(
         "where_op",
     ],
     deps = [
+        ":bounds_check",
         ":concat_lib",
         ":fill_functor",
         ":ops_util",
@@ -874,6 +884,7 @@ tf_kernel_libraries(
     ],
     deps = [
         ":assign_op",
+        ":bounds_check",
         "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:state_ops_op_lib",
@@ -955,6 +966,7 @@ filegroup(
         "assign_op.h",
         "bias_op.cc",
         "bias_op.h",
+        "bounds_check.h",
         "cast_op.cc",
         "cast_op.h",
         "concat_lib.h",
...
New file: tensorflow/core/kernels/bounds_check.h

/* Copyright 2015 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_UTIL_BOUNDS_CHECK_H_
#define TENSORFLOW_UTIL_BOUNDS_CHECK_H_

#include <type_traits>

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

// Check that 0 <= index < limit using a single comparison, assuming
// that 0 <= limit if Index is signed.  Intended for use in performance
// critical contexts where 0 <= index < limit is almost always true.
template <class Index>
EIGEN_ALWAYS_INLINE bool FastBoundsCheck(Index index, Index limit) {
  typedef typename std::make_unsigned<Index>::type UIndex;
  return TF_PREDICT_TRUE(static_cast<UIndex>(index) <
                         static_cast<UIndex>(limit));
}

}  // namespace tensorflow

#endif  // TENSORFLOW_UTIL_BOUNDS_CHECK_H_
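
Note: the single-comparison trick above relies on unsigned wraparound. Casting a negative signed index to the matching unsigned type yields a value larger than any valid limit, so one unsigned compare rejects both index < 0 and index >= limit at once, given 0 <= limit. A standalone sketch of the same idea, using plain C++ in place of the EIGEN_ALWAYS_INLINE and TF_PREDICT_TRUE macros; the name FastBoundsCheckSketch is illustrative, not part of this commit:

#include <cassert>
#include <type_traits>

// Standalone version of the unsigned-comparison bounds check; assumes
// 0 <= limit, like the header above.
template <class Index>
inline bool FastBoundsCheckSketch(Index index, Index limit) {
  typedef typename std::make_unsigned<Index>::type UIndex;
  // A negative index wraps to a huge unsigned value, so this one comparison
  // rejects both index < 0 and index >= limit.
  return static_cast<UIndex>(index) < static_cast<UIndex>(limit);
}

int main() {
  assert(FastBoundsCheckSketch(0, 3));    // in range
  assert(FastBoundsCheckSketch(2, 3));    // in range
  assert(!FastBoundsCheckSketch(3, 3));   // == limit: rejected
  assert(!FastBoundsCheckSketch(-1, 3));  // negative: rejected via wraparound
  return 0;
}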
@@ -18,36 +18,52 @@ limitations under the License.
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/bounds_check.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/mem.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/util.h"
 
 namespace tensorflow {
 namespace {
 
+// Returns -1 on success or a nonnegative i s.t. indices[i] is bad.
 template <typename T, typename Index, int static_slice_elems>
-void HandleCopies(const Tensor& Tparams,
-                  typename TTypes<Index>::ConstVec& Tindices, int slice_elems,
-                  typename TTypes<T>::Matrix Tout) {
-  const int N = Tindices.dimension(0);
-  const auto& Tparams_flat = Tparams.flat_outer_dims<T>();
-  T* Tout_base = &Tout(0, 0);
-  const T* Tparams_base = &Tparams_flat(0, 0);
-  const size_t slice_bytes = slice_elems * sizeof(T);
+Index HandleCopies(const Tensor& params,
+                   typename TTypes<Index>::ConstVec indices, Index slice_elems,
+                   typename TTypes<T>::Matrix out) {
+  const int N = indices.dimension(0);
+  const auto& params_flat = params.flat_outer_dims<T>();
+  const Index limit = params.dim_size(0);
+  T* out_base = &out(0, 0);
+  const T* params_base = &params_flat(0, 0);
   if (static_slice_elems >= 0) {
     // Give compiler static knowledge of the number of elements/bytes
     CHECK_EQ(static_slice_elems, slice_elems);
     slice_elems = static_slice_elems;
   }
+  // Compute slice_bytes here so that static knowledge is available
+  const size_t slice_bytes = slice_elems * sizeof(T);
   for (int i = 0; i < N; i++) {
-    int j = i + 1;
+    const int j = i + 1;
     if (j < N) {
-      port::prefetch<port::PREFETCH_HINT_T0>(&Tparams_flat(Tindices(j), 0));
-      port::prefetch<port::PREFETCH_HINT_T0>(&Tout(j, 0));
+      port::prefetch<port::PREFETCH_HINT_T0>(&params_flat(indices(j), 0));
+      port::prefetch<port::PREFETCH_HINT_T0>(&out(j, 0));
     }
-    memcpy(Tout_base + i * slice_elems,
-           Tparams_base + Tindices(i) * slice_elems, slice_bytes);
+    // Grab the index and check its validity.  An earlier version of the
+    // code checked it and then grabbed it from memory a second time, which
+    // was a security risk since it could have changed in between.
+    const Index index = indices(i);
+    if (!FastBoundsCheck(index, limit)) return i;
+    // Copy using memcpy if possible, otherwise an Eigen loop
+    if (Allocator::is_simple<T>::value) {
+      memcpy(out_base + i * slice_elems, params_base + index * slice_elems,
+             slice_bytes);
+    } else {
+      out.template chip<0>(i) = params_flat.template chip<0>(index);
+    }
   }
+  return -1;
 }
 
 }  // anonymous namespace
@@ -64,78 +80,67 @@ class GatherOp : public OpKernel {
     const DataType dt = DataTypeToEnum<T>::v();
     const DataType index_t = DataTypeToEnum<Index>::v();
     OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t}, {dt}));
-    OP_REQUIRES_OK(c, c->GetAttr("validate_indices", &validate_indices_));
+    // We used to grab the validate_indices attribute here, but now we
+    // always validate indices since the speed difference was only 1.5%.
+    // TODO(irving): Remove the validate_indices attribute once we have
+    // support for removing attrs in a backwards compatible way.
   }
 
   void Compute(OpKernelContext* c) override {
-    const Tensor& Tparams = c->input(0);
-    const Tensor& Tindices = c->input(1);
+    const Tensor& params = c->input(0);
+    const Tensor& indices = c->input(1);
     OP_REQUIRES(
-        c, TensorShapeUtils::IsVectorOrHigher(Tparams.shape()),
+        c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
         errors::InvalidArgument("params must be at least 1 dimensional"));
-    const int64 N = Tindices.NumElements();
-    const int64 first_dim_size = Tparams.dim_size(0);
 
-    // Validate all the indices are in range
-    auto Tindices_vec = Tindices.flat<Index>();
-    if (validate_indices_) {
-      for (int64 i = 0; i < N; i++) {
-        const Index index = Tindices_vec(i);
-        OP_REQUIRES(c, index >= 0 && index < first_dim_size,
-                    errors::InvalidArgument(
-                        strings::StrCat("Index ", index, " at offset ", i,
-                                        " in Tindices is out of range")));
-      }
-    }
+    // Check that we have enough index space
+    const int64 N_big = indices.NumElements();
+    OP_REQUIRES(c, N_big <= std::numeric_limits<int>::max(),
+                errors::InvalidArgument(
+                    "indices has too many elements for int indexing: ", N_big,
+                    " > ", std::numeric_limits<int>::max()));
+    const int N = indices.NumElements();
+    OP_REQUIRES(
+        c, params.dim_size(0) <= std::numeric_limits<Index>::max(),
+        errors::InvalidArgument("params.shape[0] too large for ",
+                                DataTypeString(DataTypeToEnum<Index>::v()),
+                                " indexing: ", params.dim_size(0), " > ",
+                                std::numeric_limits<Index>::max()));
 
     // The result shape is indices.shape + params.shape[1:].
-    TensorShape result_shape = Tindices.shape();
-    for (int i = 1; i < Tparams.dims(); i++) {
-      result_shape.AddDim(Tparams.dim_size(i));
+    TensorShape result_shape = indices.shape();
+    for (int i = 1; i < params.dims(); i++) {
+      result_shape.AddDim(params.dim_size(i));
     }
 
-    Tensor* Tout = nullptr;
-    OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &Tout));
-    const auto& Tparams_flat = Tparams.flat_outer_dims<T>();
+    Tensor* out = nullptr;
+    OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
    if (N > 0) {
-      auto Tindices_flat = Tindices.flat<Index>();
-      auto Tout_flat = Tout->shaped<T, 2>({N, Tout->NumElements() / N});
-      if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
-        const int64 slice_size = Tout->NumElements() / N;
-#define SPECIALIZE(elems)                                               \
-  do {                                                                  \
-    if (slice_size == elems) {                                          \
-      HandleCopies<T, Index, elems>(Tparams, Tindices_flat, slice_size, \
-                                    Tout_flat);                         \
-      return;                                                           \
-    }                                                                   \
-  } while (0)
-
-        SPECIALIZE(10);
-        SPECIALIZE(20);
-#undef SPECIALIZE
-        HandleCopies<T, Index, -1>(Tparams, Tindices_flat, slice_size,
-                                   Tout_flat);
-      } else {
-        for (int i = 0; i < N; i++) {
-          int j = i + 1;
-          if (j < N) {
-            port::prefetch<port::PREFETCH_HINT_T0>(
-                &Tparams_flat(Tindices_vec(j), 0));
-            port::prefetch<port::PREFETCH_HINT_T0>(&Tout_flat(j, 0));
-          }
-          // Copy last Ndim-1 dimensions of Tparams[Tindices[i]] to Tout[i]
-          Tout_flat.template chip<0>(i) =
-              Tparams_flat.template chip<0>(Tindices_vec(i));
-        }
-      }
+      auto indices_flat = indices.flat<Index>();
+      auto out_flat = out->shaped<T, 2>({N, out->NumElements() / N});
+      const int64 slice_size = out->NumElements() / N;
+      Index bad_i;
+
+#define CALL(elems)                                                       \
+  bad_i = HandleCopies<T, Index, elems>(params, indices_flat, slice_size, \
+                                        out_flat)
+
+      if (slice_size == 10)
+        CALL(10);
+      else if (slice_size == 20)
+        CALL(20);
+      else
+        CALL(-1);
+
+#undef CALL
+
+      OP_REQUIRES(
+          c, bad_i < 0,
+          errors::InvalidArgument(
              "indices", SliceDebugString(indices.shape(), bad_i), " = ",
              indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
    }
  }
-
- private:
-  bool validate_indices_;
};

#define REGISTER_GATHER(type, index_type) \
...
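
Note: the comment in HandleCopies explains the security motivation. The old kernel read each index once to validate it and again to use it, so an indices tensor mutated between the two reads could pass the check and still index out of bounds. A minimal sketch of the read-once pattern over raw arrays; GatherRowsSketch and its signature are illustrative, not part of this commit:

#include <cstddef>
#include <cstring>

// Copies row src[indices[i]] into dst[i] for each i.  Returns -1 on success,
// or the first i whose index falls outside [0, limit), mirroring the
// HandleCopies contract above.
template <typename T, typename Index>
Index GatherRowsSketch(const T* src, Index limit, std::size_t row_elems,
                       const Index* indices, Index n, T* dst) {
  for (Index i = 0; i < n; ++i) {
    // Read the index exactly once; validating one read and dereferencing a
    // second read is the time-of-check/time-of-use race the commit removes.
    const Index index = indices[i];
    if (index < 0 || index >= limit) return i;
    std::memcpy(dst + i * row_elems, src + index * row_elems,
                row_elems * sizeof(T));
  }
  return -1;
}

int main() {
  const float src[3][2] = {{0, 1}, {2, 3}, {4, 5}};
  float dst[2][2];
  const int idx[2] = {2, 0};
  const int bad = GatherRowsSketch(&src[0][0], 3, 2, idx, 2, &dst[0][0]);
  return bad >= 0 ? 1 : 0;  // returns -1 here: both indices are in [0, 3)
}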
@@ -20,8 +20,10 @@ limitations under the License.
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/bounds_check.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/util.h"
 
 namespace tensorflow {
@@ -99,36 +101,54 @@ class ScatterUpdateOp : public OpKernel {
   }
 
   void DoCompute(OpKernelContext* c) {
-    Tensor Tparams = c->mutable_input(0, use_exclusive_lock_);
-    OP_REQUIRES(c, Tparams.IsInitialized(),
+    Tensor params = c->mutable_input(0, use_exclusive_lock_);
+    OP_REQUIRES(c, params.IsInitialized(),
                 errors::FailedPrecondition("Null ref for params"));
-    const Tensor& Tindices = c->input(1);
-    const Tensor& Tupdates = c->input(2);
+    const Tensor& indices = c->input(1);
+    const Tensor& updates = c->input(2);
     OP_REQUIRES(
-        c, TensorShapeUtils::IsVectorOrHigher(Tparams.shape()),
+        c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
         errors::InvalidArgument("params must be at least 1-D, got shape ",
-                                Tparams.shape().DebugString()));
+                                params.shape().DebugString()));
     OP_REQUIRES(
-        c, ValidShapes(Tparams, Tupdates, Tindices),
+        c, ValidShapes(params, updates, indices),
         errors::InvalidArgument(
             "Must have updates.shape = indices.shape + params.shape[1:], got ",
-            "updates.shape ", Tupdates.shape().DebugString(),
-            ", indices.shape ", Tindices.shape().DebugString(),
-            ", params.shape ", Tparams.shape().DebugString()));
+            "updates.shape ", updates.shape().DebugString(), ", indices.shape ",
+            indices.shape().DebugString(), ", params.shape ",
+            params.shape().DebugString()));
+
+    // Check that we have enough index space
+    const int64 N_big = indices.NumElements();
+    OP_REQUIRES(c, N_big <= std::numeric_limits<Index>::max(),
+                errors::InvalidArgument(
+                    "indices has too many elements for ",
+                    DataTypeString(DataTypeToEnum<Index>::v()), " indexing: ",
+                    N_big, " > ", std::numeric_limits<Index>::max()));
+    const Index N = indices.NumElements();
+    OP_REQUIRES(
+        c, params.dim_size(0) <= std::numeric_limits<Index>::max(),
+        errors::InvalidArgument("params.shape[0] too large for ",
+                                DataTypeString(DataTypeToEnum<Index>::v()),
+                                " indexing: ", params.dim_size(0), " > ",
+                                std::numeric_limits<Index>::max()));
 
     // We always return the input ref.
     c->forward_ref_input_to_ref_output(0, 0);
 
-    const Index N = Tindices.NumElements();
     if (N > 0) {
-      auto Tindices_flat = Tindices.flat<Index>();
-      auto Tparams_flat = Tparams.flat_outer_dims<T>();
-      auto Tupdates_flat =
-          Tupdates.shaped<T, 2>({N, Tupdates.NumElements() / N});
+      auto indices_flat = indices.flat<Index>();
+      auto params_flat = params.flat_outer_dims<T>();
+      auto updates_flat = updates.shaped<T, 2>({N, updates.NumElements() / N});
+
       functor::ScatterFunctor<Device, T, Index, op> functor;
-      functor(c, c->template eigen_device<Device>(),
-              Tparams_flat, Tupdates_flat, Tindices_flat);
+      const Index bad_i = functor(c, c->template eigen_device<Device>(),
+                                  params_flat, updates_flat, indices_flat);
+      OP_REQUIRES(
+          c, bad_i < 0,
+          errors::InvalidArgument(
+              "indices", SliceDebugString(indices.shape(), bad_i), " = ",
+              indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
     }
   }
 };
@@ -137,26 +157,23 @@ namespace functor {
 // Implementation of update functor for CPU.
 template <typename T, typename Index, scatter_op::UpdateOp op>
 struct ScatterFunctor<CPUDevice, T, Index, op> {
-  void operator()(OpKernelContext* c, const CPUDevice& d,
-                  typename TTypes<T>::Matrix params,
-                  typename TTypes<T>::ConstMatrix updates,
-                  typename TTypes<Index>::ConstFlat indices) {
-    Index N = indices.size();
-    // Validate all the indices are in range
-    Index first_dim_size = params.dimension(0);
+  Index operator()(OpKernelContext* c, const CPUDevice& d,
+                   typename TTypes<T>::Matrix params,
+                   typename TTypes<T>::ConstMatrix updates,
+                   typename TTypes<Index>::ConstFlat indices) {
+    const Index N = indices.size();
+    const Index limit = params.dimension(0);
     for (Index i = 0; i < N; i++) {
+      // Grab the index and check its validity.  An earlier version of the
+      // code checked it and then grabbed it from memory a second time, which
+      // was a security risk since it could have changed in between.
       const Index index = indices(i);
-      OP_REQUIRES(c, index >= 0 && index < first_dim_size,
-                  errors::InvalidArgument(
-                      strings::StrCat("Index ", index, " at offset ", i,
-                                      " in indices is out of range")));
-    }
-    for (Index i = 0; i < N; i++) {
-      // Copy last Ndim-1 dimensions of Tupdates[i] to
-      // Tparams[Tindices[i]]
-      Assign<op>::Run(params.template chip<0>(indices(i)),
+      if (!FastBoundsCheck(index, limit)) return i;
+      // Copy last Ndim-1 dimensions of updates[i] to params[index]
+      Assign<op>::Run(params.template chip<0>(index),
                       updates.template chip<0>(i));
     }
+    return -1;
  }
};
}  // namespace functor
@@ -220,13 +237,13 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_UPDATE_GPU);
 // Forward declarations of the functor specializations for GPU.
 namespace functor {
 #define DECLARE_GPU_SPECS_OP(T, Index, op)                    \
   template <>                                                 \
-  void ScatterFunctor<GPUDevice, T, Index, op>::operator()(   \
+  Index ScatterFunctor<GPUDevice, T, Index, op>::operator()(  \
       OpKernelContext* c, const GPUDevice& d,                 \
       typename TTypes<T>::Matrix params,                      \
      typename TTypes<T>::ConstMatrix updates,                \
      typename TTypes<Index>::ConstFlat indices);             \
  extern template struct ScatterFunctor<GPUDevice, T, Index, op>;

#define DECLARE_GPU_SPECS_INDEX(T, Index) \
...
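
Note: rather than calling OP_REQUIRES inside the hot loop, the functor now returns the first bad index (-1 on success) and lets the caller format a single InvalidArgument error. A toy sketch of that contract over raw arrays, echoing the indices = [2, 0, 6] case from the tests below; ScatterAssignSketch is illustrative, not TensorFlow API:

#include <cstdio>

// Applies dst[indices[i]] = src[i].  Returns -1 on success, or the first i
// with an out-of-range index, matching the ScatterFunctor contract above.
template <typename T, typename Index>
Index ScatterAssignSketch(T* dst, Index limit, const T* src,
                          const Index* indices, Index n) {
  for (Index i = 0; i < n; ++i) {
    const Index index = indices[i];  // read once, as in the CPU functor
    if (index < 0 || index >= limit) return i;
    dst[index] = src[i];
  }
  return -1;
}

int main() {
  float params[6] = {};
  const float updates[3] = {10.f, 20.f, 30.f};
  const int indices[3] = {2, 0, 6};  // 6 is out of range for limit 6
  const int bad_i = ScatterAssignSketch(params, 6, updates, indices, 3);
  if (bad_i >= 0) {
    // The caller owns the error message, like the OP_REQUIRES in DoCompute.
    std::printf("indices[%d] = %d is not in [0, 6)\n", bad_i, indices[bad_i]);
  }
  return 0;
}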
@@ -36,10 +36,11 @@ namespace functor {
 // Functor used by ScatterOp to do the computations.
 template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
 struct ScatterFunctor {
-  void operator()(OpKernelContext* c, const Device& d,
-                  typename TTypes<T>::Matrix params,
-                  typename TTypes<T>::ConstMatrix updates,
-                  typename TTypes<Index>::ConstFlat indices);
+  // Returns -1 on success or a nonnegative i s.t. indices[i] is a bad index.
+  Index operator()(OpKernelContext* c, const Device& d,
+                   typename TTypes<T>::Matrix params,
+                   typename TTypes<T>::ConstMatrix updates,
+                   typename TTypes<Index>::ConstFlat indices);
 };
 
 }  // namespace functor
...
@@ -62,10 +62,10 @@ namespace functor {
 // Specialization for a GPU device.
 template <typename T, typename Index, scatter_op::UpdateOp op>
 struct ScatterFunctor<GPUDevice, T, Index, op> {
-  void operator()(OpKernelContext* c, const GPUDevice& d,
-                  typename TTypes<T>::Matrix params,
-                  typename TTypes<T>::ConstMatrix updates,
-                  typename TTypes<Index>::ConstFlat indices) {
+  Index operator()(OpKernelContext* c, const GPUDevice& d,
+                   typename TTypes<T>::Matrix params,
+                   typename TTypes<T>::ConstMatrix updates,
+                   typename TTypes<Index>::ConstFlat indices) {
     // TODO: Implement indices range check.  The hardest part is with returning
     // a value after the range check, as we do not want to do device to host
     // memcpy during a stream.
@@ -77,6 +77,7 @@ struct ScatterFunctor<GPUDevice, T, Index, op> {
         <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
             params.data(), updates.data(), indices.data(),
             first_dim_size, updates_size, indices_size);
+    return -1;
   }
 };
 
...
@@ -83,6 +83,14 @@ class GatherTest(tf.test.TestCase):
       gather_t = tf.gather(params, indices)
       self.assertEqual(None, gather_t.get_shape())
 
+  def testBadIndices(self):
+    with self.test_session():
+      params = [0, 1, 2]
+      indices = [[7]]
+      gather = tf.gather(params, indices)
+      with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"):
+        gather.eval()
+
 
 if __name__ == "__main__":
   tf.test.main()
...
@@ -128,11 +128,11 @@ class ScatterTest(tf.test.TestCase):
 
       # Test some out of range errors.
       indices = np.array([-1, 0, 5])
-      with self.assertRaisesOpError('indices is out of range'):
+      with self.assertRaisesOpError(r'indices\[0\] = -1 is not in \[0, 6\)'):
         op(ref, indices, updates).eval()
 
       indices = np.array([2, 0, 6])
-      with self.assertRaisesOpError('indices is out of range'):
+      with self.assertRaisesOpError(r'indices\[2\] = 6 is not in \[0, 6\)'):
         op(ref, indices, updates).eval()
 
 # TODO(fpmc): Re-enable this test when gpu_pip test actually runs on a GPU.
...