Commit 49760690 authored by Geoffrey Irving, committed by TensorFlower Gardener

Fix build issue with safety fix to gather and scatter

Change: 115495726
Parent 746ccc84
@@ -30,6 +30,15 @@ cc_library(
],
)
cc_library(
name = "bounds_check",
hdrs = ["bounds_check.h"],
deps = [
"//tensorflow/core:framework",
"//third_party/eigen3",
],
)
tf_kernel_library(
name = "concat_lib",
srcs = ["concat_lib_cpu.cc"],
@@ -226,6 +235,7 @@ tf_kernel_libraries(
"where_op",
],
deps = [
":bounds_check",
":concat_lib",
":fill_functor",
":ops_util",
@@ -874,6 +884,7 @@ tf_kernel_libraries(
],
deps = [
":assign_op",
":bounds_check",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:state_ops_op_lib",
@@ -955,6 +966,7 @@ filegroup(
"assign_op.h",
"bias_op.cc",
"bias_op.h",
"bounds_check.h",
"cast_op.cc",
"cast_op.h",
"concat_lib.h",
......
/* Copyright 2015 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_UTIL_BOUNDS_CHECK_H_
#define TENSORFLOW_UTIL_BOUNDS_CHECK_H_
#include <type_traits>
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
// Check that 0 <= index < limit using a single comparison, assuming
// that 0 <= limit if Index is signed. Intended for use in performance
// critical contexts where 0 <= index < limit is almost always true.
template <class Index>
EIGEN_ALWAYS_INLINE bool FastBoundsCheck(Index index, Index limit) {
typedef typename std::make_unsigned<Index>::type UIndex;
return TF_PREDICT_TRUE(static_cast<UIndex>(index) <
static_cast<UIndex>(limit));
}
} // namespace tensorflow
#endif // TENSORFLOW_UTIL_BOUNDS_CHECK_H_
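For context, a standalone sketch (not part of this commit) of why the single unsigned comparison in FastBoundsCheck is equivalent to checking 0 <= index < limit when 0 <= limit: casting a signed index to unsigned maps every negative value to a very large value, so one comparison rejects both negative and too-large indices.

#include <cassert>
#include <cstdint>
#include <type_traits>

// Same trick as FastBoundsCheck, minus the TF_PREDICT_TRUE branch hint.
template <class Index>
bool FastBoundsCheckSketch(Index index, Index limit) {
  typedef typename std::make_unsigned<Index>::type UIndex;
  return static_cast<UIndex>(index) < static_cast<UIndex>(limit);
}

int main() {
  assert(FastBoundsCheckSketch<int32_t>(5, 10));    // 0 <= 5 < 10
  assert(!FastBoundsCheckSketch<int32_t>(10, 10));  // index == limit fails
  assert(!FastBoundsCheckSketch<int32_t>(-1, 10));  // -1 wraps to 4294967295
  return 0;
}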
@@ -18,36 +18,52 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"
namespace tensorflow {
namespace {
// Returns -1 on success or a nonnegative i s.t. indices[i] is a bad index.
template <typename T, typename Index, int static_slice_elems>
void HandleCopies(const Tensor& Tparams,
typename TTypes<Index>::ConstVec& Tindices, int slice_elems,
typename TTypes<T>::Matrix Tout) {
const int N = Tindices.dimension(0);
const auto& Tparams_flat = Tparams.flat_outer_dims<T>();
T* Tout_base = &Tout(0, 0);
const T* Tparams_base = &Tparams_flat(0, 0);
const size_t slice_bytes = slice_elems * sizeof(T);
Index HandleCopies(const Tensor& params,
typename TTypes<Index>::ConstVec indices, Index slice_elems,
typename TTypes<T>::Matrix out) {
const int N = indices.dimension(0);
const auto& params_flat = params.flat_outer_dims<T>();
const Index limit = params.dim_size(0);
T* out_base = &out(0, 0);
const T* params_base = &params_flat(0, 0);
if (static_slice_elems >= 0) {
// Give compiler static knowledge of the number of elements/bytes
CHECK_EQ(static_slice_elems, slice_elems);
slice_elems = static_slice_elems;
}
// Compute slice_bytes here so that static knowledge is available
const size_t slice_bytes = slice_elems * sizeof(T);
for (int i = 0; i < N; i++) {
int j = i + 1;
const int j = i + 1;
if (j < N) {
port::prefetch<port::PREFETCH_HINT_T0>(&Tparams_flat(Tindices(j), 0));
port::prefetch<port::PREFETCH_HINT_T0>(&Tout(j, 0));
port::prefetch<port::PREFETCH_HINT_T0>(&params_flat(indices(j), 0));
port::prefetch<port::PREFETCH_HINT_T0>(&out(j, 0));
}
// Grab the index and check its validity. An earlier version of the
// code checked it and then grabbed it from memory a second time, which
// was a security risk since it could have changed in between.
const Index index = indices(i);
if (!FastBoundsCheck(index, limit)) return i;
// Copy using memcpy if possible, otherwise an Eigen loop
if (Allocator::is_simple<T>::value) {
memcpy(out_base + i * slice_elems, params_base + index * slice_elems,
slice_bytes);
} else {
out.template chip<0>(i) = params_flat.template chip<0>(index);
}
memcpy(Tout_base + i * slice_elems,
Tparams_base + Tindices(i) * slice_elems, slice_bytes);
}
return -1;
}
} // anonymous namespace
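The static_slice_elems template parameter exists so that, for the common slice sizes dispatched below (10 and 20), slice_bytes becomes a compile-time constant and the compiler can emit a fixed-size memcpy. A minimal sketch of that pattern, with hypothetical names:

#include <cstring>

// kStaticElems >= 0 pins the slice size at compile time; -1 keeps it dynamic.
template <typename T, int kStaticElems>
void CopySliceSketch(T* dst, const T* src, int slice_elems) {
  if (kStaticElems >= 0) slice_elems = kStaticElems;
  std::memcpy(dst, src, slice_elems * sizeof(T));  // constant size when pinned
}

void DispatchSketch(float* dst, const float* src, int slice_elems) {
  if (slice_elems == 10) {
    CopySliceSketch<float, 10>(dst, src, slice_elems);
  } else if (slice_elems == 20) {
    CopySliceSketch<float, 20>(dst, src, slice_elems);
  } else {
    CopySliceSketch<float, -1>(dst, src, slice_elems);  // dynamic fallback
  }
}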
@@ -64,78 +80,67 @@ class GatherOp : public OpKernel {
const DataType dt = DataTypeToEnum<T>::v();
const DataType index_t = DataTypeToEnum<Index>::v();
OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t}, {dt}));
OP_REQUIRES_OK(c, c->GetAttr("validate_indices", &validate_indices_));
// We used to grab the validate_indices attribute here, but now we
// always validate indices since the speed difference was only 1.5%.
// TODO(irving): Remove the validate_indices attribute once we have
// support for removing attrs in a backwards compatible way.
}
void Compute(OpKernelContext* c) override {
const Tensor& Tparams = c->input(0);
const Tensor& Tindices = c->input(1);
const Tensor& params = c->input(0);
const Tensor& indices = c->input(1);
OP_REQUIRES(
c, TensorShapeUtils::IsVectorOrHigher(Tparams.shape()),
c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
errors::InvalidArgument("params must be at least 1 dimensional"));
const int64 N = Tindices.NumElements();
const int64 first_dim_size = Tparams.dim_size(0);
// Validate all the indices are in range
auto Tindices_vec = Tindices.flat<Index>();
if (validate_indices_) {
for (int64 i = 0; i < N; i++) {
const Index index = Tindices_vec(i);
OP_REQUIRES(c, index >= 0 && index < first_dim_size,
errors::InvalidArgument(
strings::StrCat("Index ", index, " at offset ", i,
" in Tindices is out of range")));
}
}
// Check that we have enough index space
const int64 N_big = indices.NumElements();
OP_REQUIRES(c, N_big <= std::numeric_limits<int>::max(),
errors::InvalidArgument(
"indices has too many elements for int indexing: ", N_big,
" > ", std::numeric_limits<int>::max()));
const int N = indices.NumElements();
OP_REQUIRES(
c, params.dim_size(0) <= std::numeric_limits<Index>::max(),
errors::InvalidArgument("params.shape[0] too large for ",
DataTypeString(DataTypeToEnum<Index>::v()),
" indexing: ", params.dim_size(0), " > ",
std::numeric_limits<Index>::max()));
// The result shape is indices.shape + params.shape[1:].
TensorShape result_shape = Tindices.shape();
for (int i = 1; i < Tparams.dims(); i++) {
result_shape.AddDim(Tparams.dim_size(i));
TensorShape result_shape = indices.shape();
for (int i = 1; i < params.dims(); i++) {
result_shape.AddDim(params.dim_size(i));
}
Tensor* Tout = nullptr;
OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &Tout));
const auto& Tparams_flat = Tparams.flat_outer_dims<T>();
Tensor* out = nullptr;
OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
if (N > 0) {
auto Tindices_flat = Tindices.flat<Index>();
auto Tout_flat = Tout->shaped<T, 2>({N, Tout->NumElements() / N});
if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
const int64 slice_size = Tout->NumElements() / N;
#define SPECIALIZE(elems) \
do { \
if (slice_size == elems) { \
HandleCopies<T, Index, elems>(Tparams, Tindices_flat, slice_size, \
Tout_flat); \
return; \
} \
} while (0)
SPECIALIZE(10);
SPECIALIZE(20);
#undef SPECIALIZE
HandleCopies<T, Index, -1>(Tparams, Tindices_flat, slice_size,
Tout_flat);
} else {
for (int i = 0; i < N; i++) {
int j = i + 1;
if (j < N) {
port::prefetch<port::PREFETCH_HINT_T0>(
&Tparams_flat(Tindices_vec(j), 0));
port::prefetch<port::PREFETCH_HINT_T0>(&Tout_flat(j, 0));
}
// Copy last Ndim-1 dimensions of Tparams[Tindices[i]] to Tout[i]
Tout_flat.template chip<0>(i) =
Tparams_flat.template chip<0>(Tindices_vec(i));
}
}
auto indices_flat = indices.flat<Index>();
auto out_flat = out->shaped<T, 2>({N, out->NumElements() / N});
const int64 slice_size = out->NumElements() / N;
Index bad_i;
#define CALL(elems) \
bad_i = HandleCopies<T, Index, elems>(params, indices_flat, slice_size, \
out_flat)
if (slice_size == 10)
CALL(10);
else if (slice_size == 20)
CALL(20);
else
CALL(-1);
#undef CALL
OP_REQUIRES(
c, bad_i < 0,
errors::InvalidArgument(
"indices", SliceDebugString(indices.shape(), bad_i), " = ",
indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
}
}
private:
bool validate_indices_;
};
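A standalone sketch (hypothetical, not part of this commit) of the failure mode the new params.shape[0] check guards against: HandleCopies stores the limit as the index type, so a first dimension above that type's maximum would truncate on narrowing and make the bounds check unsound.

#include <cstdint>
#include <iostream>

int main() {
  const int64_t dim0 = (int64_t{1} << 31) + 5;       // params.shape[0], > INT32_MAX
  const int32_t limit = static_cast<int32_t>(dim0);  // wraps negative on narrowing
  std::cout << "int32 limit after truncation: " << limit << "\n";
  // An unsigned bounds check against this wrapped limit is unsound: e.g.
  // index = INT32_MIN casts to 2147483648u, which is below the wrapped
  // limit's unsigned value 2147483653u, so a negative index would pass.
  // Hence the OP_REQUIRES on params.dim_size(0) <= numeric_limits<Index>::max().
  return 0;
}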
#define REGISTER_GATHER(type, index_type) \
......
@@ -20,8 +20,10 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"
namespace tensorflow {
@@ -99,36 +101,54 @@ class ScatterUpdateOp : public OpKernel {
}
void DoCompute(OpKernelContext* c) {
Tensor Tparams = c->mutable_input(0, use_exclusive_lock_);
OP_REQUIRES(c, Tparams.IsInitialized(),
Tensor params = c->mutable_input(0, use_exclusive_lock_);
OP_REQUIRES(c, params.IsInitialized(),
errors::FailedPrecondition("Null ref for params"));
const Tensor& Tindices = c->input(1);
const Tensor& Tupdates = c->input(2);
const Tensor& indices = c->input(1);
const Tensor& updates = c->input(2);
OP_REQUIRES(
c, TensorShapeUtils::IsVectorOrHigher(Tparams.shape()),
c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
errors::InvalidArgument("params must be at least 1-D, got shape ",
Tparams.shape().DebugString()));
params.shape().DebugString()));
OP_REQUIRES(
c, ValidShapes(Tparams, Tupdates, Tindices),
c, ValidShapes(params, updates, indices),
errors::InvalidArgument(
"Must have updates.shape = indices.shape + params.shape[1:], got ",
"updates.shape ", Tupdates.shape().DebugString(),
", indices.shape ", Tindices.shape().DebugString(),
", params.shape ", Tparams.shape().DebugString()));
"updates.shape ", updates.shape().DebugString(), ", indices.shape ",
indices.shape().DebugString(), ", params.shape ",
params.shape().DebugString()));
// Check that we have enough index space
const int64 N_big = indices.NumElements();
OP_REQUIRES(c, N_big <= std::numeric_limits<Index>::max(),
errors::InvalidArgument(
"indices has too many elements for ",
DataTypeString(DataTypeToEnum<Index>::v()), " indexing: ",
N_big, " > ", std::numeric_limits<Index>::max()));
const Index N = indices.NumElements();
OP_REQUIRES(
c, params.dim_size(0) <= std::numeric_limits<Index>::max(),
errors::InvalidArgument("params.shape[0] too large for ",
DataTypeString(DataTypeToEnum<Index>::v()),
" indexing: ", params.dim_size(0), " > ",
std::numeric_limits<Index>::max()));
// We always return the input ref.
c->forward_ref_input_to_ref_output(0, 0);
const Index N = Tindices.NumElements();
if (N > 0) {
auto Tindices_flat = Tindices.flat<Index>();
auto Tparams_flat = Tparams.flat_outer_dims<T>();
auto Tupdates_flat =
Tupdates.shaped<T, 2>({N, Tupdates.NumElements() / N});
auto indices_flat = indices.flat<Index>();
auto params_flat = params.flat_outer_dims<T>();
auto updates_flat = updates.shaped<T, 2>({N, updates.NumElements() / N});
functor::ScatterFunctor<Device, T, Index, op> functor;
functor(c, c->template eigen_device<Device>(),
Tparams_flat, Tupdates_flat, Tindices_flat);
const Index bad_i = functor(c, c->template eigen_device<Device>(),
params_flat, updates_flat, indices_flat);
OP_REQUIRES(
c, bad_i < 0,
errors::InvalidArgument(
"indices", SliceDebugString(indices.shape(), bad_i), " = ",
indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
}
}
};
@@ -137,26 +157,23 @@ namespace functor {
// Implementation of update functor for CPU.
template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctor<CPUDevice, T, Index, op> {
void operator()(OpKernelContext* c, const CPUDevice& d,
typename TTypes<T>::Matrix params,
typename TTypes<T>::ConstMatrix updates,
typename TTypes<Index>::ConstFlat indices) {
Index N = indices.size();
// Validate all the indices are in range
Index first_dim_size = params.dimension(0);
Index operator()(OpKernelContext* c, const CPUDevice& d,
typename TTypes<T>::Matrix params,
typename TTypes<T>::ConstMatrix updates,
typename TTypes<Index>::ConstFlat indices) {
const Index N = indices.size();
const Index limit = params.dimension(0);
for (Index i = 0; i < N; i++) {
// Grab the index and check its validity. An earlier version of the
// code checked it and then grabbed it from memory a second time, which
// was a security risk since it could have changed in between.
const Index index = indices(i);
OP_REQUIRES(c, index >= 0 && index < first_dim_size,
errors::InvalidArgument(
strings::StrCat("Index ", index, " at offset ", i,
" in indices is out of range")));
}
for (Index i = 0; i < N; i++) {
// Copy last Ndim-1 dimensions of Tupdates[i] to
// Tparams[Tindices[i]]
Assign<op>::Run(params.template chip<0>(indices(i)),
if (!FastBoundsCheck(index, limit)) return i;
// Copy last Ndim-1 dimensions of updates[i] to params[index]
Assign<op>::Run(params.template chip<0>(index),
updates.template chip<0>(i));
}
return -1;
}
};
} // namespace functor
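Both HandleCopies and the scatter functor above rely on the read-once pattern the comments describe. A minimal standalone sketch (hypothetical names, not part of this commit): reading the index into a local, validating that copy, and indexing with the same copy means a concurrent writer to the indices buffer cannot change the value between the check and the use (a time-of-check/time-of-use bug).

#include <cstdint>

// Returns -1 on success, or the offset of the first bad index.
int64_t ApplyUpdatesSketch(float* params, int64_t limit,
                           const int64_t* indices, const float* updates,
                           int64_t n) {
  for (int64_t i = 0; i < n; ++i) {
    const int64_t index = indices[i];            // read exactly once
    if (index < 0 || index >= limit) return i;   // validate the local copy
    params[index] = updates[i];                  // use the same copy
  }
  return -1;
}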
@@ -220,13 +237,13 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_UPDATE_GPU);
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPECS_OP(T, Index, op) \
template <> \
void ScatterFunctor<GPUDevice, T, Index, op>::operator()( \
OpKernelContext* c, const GPUDevice& d, \
typename TTypes<T>::Matrix params, \
typename TTypes<T>::ConstMatrix updates, \
typename TTypes<Index>::ConstFlat indices); \
#define DECLARE_GPU_SPECS_OP(T, Index, op) \
template <> \
Index ScatterFunctor<GPUDevice, T, Index, op>::operator()( \
OpKernelContext* c, const GPUDevice& d, \
typename TTypes<T>::Matrix params, \
typename TTypes<T>::ConstMatrix updates, \
typename TTypes<Index>::ConstFlat indices); \
extern template struct ScatterFunctor<GPUDevice, T, Index, op>;
#define DECLARE_GPU_SPECS_INDEX(T, Index) \
......
@@ -36,10 +36,11 @@ namespace functor {
// Functor used by ScatterOp to do the computations.
template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctor {
void operator()(OpKernelContext* c, const Device& d,
typename TTypes<T>::Matrix params,
typename TTypes<T>::ConstMatrix updates,
typename TTypes<Index>::ConstFlat indices);
// Returns -1 on success or a nonnegative i s.t. indices[i] is a bad index.
Index operator()(OpKernelContext* c, const Device& d,
typename TTypes<T>::Matrix params,
typename TTypes<T>::ConstMatrix updates,
typename TTypes<Index>::ConstFlat indices);
};
} // namespace functor
......
@@ -62,10 +62,10 @@ namespace functor {
// Specialization for a GPU device.
template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctor<GPUDevice, T, Index, op> {
void operator()(OpKernelContext* c, const GPUDevice& d,
typename TTypes<T>::Matrix params,
typename TTypes<T>::ConstMatrix updates,
typename TTypes<Index>::ConstFlat indices) {
Index operator()(OpKernelContext* c, const GPUDevice& d,
typename TTypes<T>::Matrix params,
typename TTypes<T>::ConstMatrix updates,
typename TTypes<Index>::ConstFlat indices) {
// TODO: Implement an indices range check. The hard part is returning a
// value after the range check, since we do not want to do a device-to-host
// memcpy while the stream is running.
@@ -77,6 +77,7 @@ struct ScatterFunctor<GPUDevice, T, Index, op> {
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
params.data(), updates.data(), indices.data(),
first_dim_size, updates_size, indices_size);
return -1;
}
};
......
@@ -83,6 +83,14 @@ class GatherTest(tf.test.TestCase):
gather_t = tf.gather(params, indices)
self.assertEqual(None, gather_t.get_shape())
def testBadIndices(self):
with self.test_session():
params = [0, 1, 2]
indices = [[7]]
gather = tf.gather(params, indices)
with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"):
gather.eval()
if __name__ == "__main__":
tf.test.main()
@@ -128,11 +128,11 @@ class ScatterTest(tf.test.TestCase):
# Test some out of range errors.
indices = np.array([-1, 0, 5])
with self.assertRaisesOpError('indices is out of range'):
with self.assertRaisesOpError(r'indices\[0\] = -1 is not in \[0, 6\)'):
op(ref, indices, updates).eval()
indices = np.array([2, 0, 6])
with self.assertRaisesOpError('indices is out of range'):
with self.assertRaisesOpError(r'indices\[2\] = 6 is not in \[0, 6\)'):
op(ref, indices, updates).eval()
# TODO(fpmc): Re-enable this test when gpu_pip test actually runs on a GPU.
......