提交 e863eed8 编写于 作者: A Antonio Sanchez 提交者: TensorFlower Gardener

Add index check to ValidateSparseTensor.

The validation check now verifies all index entries are within correct bounds.
This is to prevent memory access errors.

PiperOrigin-RevId: 450695916
上级 7f900232
......@@ -144,8 +144,12 @@ bool ContainsEmptyRows(const std::vector<Tindices>& row_start_indices) {
return false;
}
Status ValidateSparseTensor(const Tensor& indices, const Tensor& values,
const Tensor& shape) {
namespace {
// Ensures indices, values, shape are all of the proper ranks and are
// compatible.
Status ValidateSparseTensorShape(const Tensor& indices, const Tensor& values,
const Tensor& shape) {
// Indices must be a matrix, and values/shape must be a vector.
if (!TensorShapeUtils::IsMatrix(indices.shape())) {
return errors::InvalidArgument("Sparse indices must be rank 2 but is rank ",
......@@ -175,6 +179,45 @@ Status ValidateSparseTensor(const Tensor& indices, const Tensor& values,
return Status::OK();
}
// Checks that every coordinate of every sparse index tuple lies inside the
// dense shape, i.e. 0 <= indices[i, d] < shape[d] for all rows i and dims d.
//
// Rows are scanned in order and, within a row, dimensions are scanned in
// order; the first out-of-range coordinate yields an InvalidArgument status
// whose message contains the whole offending tuple. Returns OK otherwise.
template <typename Tindices>
Status ValidateSparseTensorIndices(const Tensor& indices, const Tensor& shape) {
  const auto index_matrix = indices.flat_inner_dims<Tindices>();
  const auto dense_shape = shape.flat<Tindices>();
  const int64_t num_entries = indices.dim_size(0);
  const int64_t rank = indices.dim_size(1);

  // Renders one row as "indices[row, :] = [a, b, ...]" for the error message.
  // Only invoked on the failure path, so the string is never built for valid
  // input.
  const auto describe_row = [&](int64_t row) {
    string text = strings::StrCat("indices[", row, ", :] = [");
    for (int64_t d = 0; d < rank; ++d) {
      strings::StrAppend(&text, index_matrix(row, d),
                         d < rank - 1 ? ", " : "]");
    }
    return text;
  };

  for (int64_t row = 0; row < num_entries; ++row) {
    for (int64_t d = 0; d < rank; ++d) {
      const Tindices coord = index_matrix(row, d);
      const bool in_bounds = coord >= 0 && coord < dense_shape(d);
      if (TF_PREDICT_FALSE(!in_bounds)) {
        return errors::InvalidArgument("Sparse index tuple ",
                                       describe_row(row),
                                       " is out of bounds");
      }
    }
  }
  return Status::OK();
}
} // namespace
// Validates a sparse tensor given as its (indices, values, shape) components.
//
// The rank/compatibility check on the three tensors always runs. The
// per-entry bounds check on `indices` runs only when `validate_indices` is
// true. Returns the first error encountered, or OK if all requested checks
// pass.
template <typename Tindices>
Status ValidateSparseTensor(const Tensor& indices, const Tensor& values,
                            const Tensor& shape, bool validate_indices) {
  TF_RETURN_IF_ERROR(ValidateSparseTensorShape(indices, values, shape));
  if (!validate_indices) {
    // Caller opted out of the bounds scan; the shape check alone suffices.
    return Status::OK();
  }
  return ValidateSparseTensorIndices<Tindices>(indices, shape);
}
#define REGISTER_SPARSE_UTIL_FUNCTIONS(TypeIndex) \
template TypeIndex FindNextDenseRowStartIndex<TypeIndex>( \
const TypeIndex sparse_index_begin, \
......@@ -186,7 +229,10 @@ Status ValidateSparseTensor(const Tensor& indices, const Tensor& values,
const std::vector<TypeIndex>& row_start_indices); \
template std::vector<TypeIndex> ParseRowStartIndices<TypeIndex>( \
const tensorflow::Tensor& tensor, \
const TypeIndex num_nonzero_entries_in_sparse_mat)
const TypeIndex num_nonzero_entries_in_sparse_mat); \
template Status ValidateSparseTensor<TypeIndex>( \
const Tensor& indices, const Tensor& values, const Tensor& shape, \
bool validate_indices)
REGISTER_SPARSE_UTIL_FUNCTIONS(int32);
REGISTER_SPARSE_UTIL_FUNCTIONS(int64);
......
......@@ -66,9 +66,11 @@ template <typename Tindices>
bool ContainsEmptyRows(const std::vector<Tindices>& row_start_indices);
// Validates the three component tensors of a sparse tensor have the proper
// shapes.
// shapes. If validate_indices is true, also checks that all indices are within
// correct bounds (i.e. are safe to access).
template <typename Tindices>
Status ValidateSparseTensor(const Tensor& indices, const Tensor& values,
const Tensor& shape);
const Tensor& shape, bool validate_indices);
} // namespace sparse_utils
} // namespace tensorflow
......
......@@ -334,18 +334,21 @@ void GenerateRandomSparseTensor(int64_t max_nnz, const TensorShape& shape,
// Verifies that sparse tensors generated for several dense shapes (including
// the rank-0 shape {}) are accepted by ValidateSparseTensor with index
// validation enabled.
TEST(ValidateSparseTensorTest, ValidSparseTensorPasses) {
constexpr int kNumNonZeros = 1000;
constexpr bool kValidateIndices = true;
const TensorShape kTensorShapes[] = {
{}, {3}, {4, 5}, {6, 7, 8}, {9, 10, 11, 12}};
for (const TensorShape& tshape : kTensorShapes) {
Tensor indices, values, shape;
// NOTE(review): GenerateRandomSparseTensor is defined outside this view;
// presumably it fills the three tensors with a valid, in-bounds sparse
// tensor for `tshape` -- confirm against its definition.
GenerateRandomSparseTensor(kNumNonZeros, tshape, indices, values, shape);
TF_EXPECT_OK((ValidateSparseTensor(indices, values, shape)));
TF_EXPECT_OK((ValidateSparseTensor<int64_t>(indices, values, shape,
kValidateIndices)));
}
}
TEST(ValidateSparseTensorTest, InvalidIndicesRankFails) {
constexpr int kNumNonZeros = 1000;
constexpr int kNumDims = 3;
constexpr bool kValidateIndices = false;
// Indices tensor must be rank 2, so try rank 0, 1, 3.
const TensorShape kInvalidIndicesShapes[] = {
{}, {kNumNonZeros}, {kNumNonZeros, kNumDims, 4}};
......@@ -353,7 +356,8 @@ TEST(ValidateSparseTensorTest, InvalidIndicesRankFails) {
const Tensor indices = Tensor(DT_INT64, invalid_shape);
const Tensor values = Tensor(DT_FLOAT, TensorShape({kNumNonZeros}));
const Tensor shape = Tensor(DT_INT64, TensorShape({kNumDims}));
EXPECT_THAT((ValidateSparseTensor(indices, values, shape)),
EXPECT_THAT((ValidateSparseTensor<int64_t>(indices, values, shape,
kValidateIndices)),
StatusIs(error::INVALID_ARGUMENT,
MatchesRegex("Sparse indices must be rank 2 .*")));
}
......@@ -362,6 +366,7 @@ TEST(ValidateSparseTensorTest, InvalidIndicesRankFails) {
TEST(ValidateSparseTensorTest, InvalidValuesRankFails) {
constexpr int kNumNonZeros = 1000;
constexpr int kNumDims = 3;
constexpr bool kValidateIndices = false;
// Values tensor must be rank 1, so try rank 0, 2.
const TensorShape kInvalidValuesShapes[] = {{}, {kNumNonZeros, 2}};
for (const TensorShape& invalid_shape : kInvalidValuesShapes) {
......@@ -369,7 +374,8 @@ TEST(ValidateSparseTensorTest, InvalidValuesRankFails) {
Tensor(DT_INT64, TensorShape({kNumNonZeros, kNumDims}));
const Tensor values = Tensor(DT_FLOAT, invalid_shape);
const Tensor shape = Tensor(DT_INT64, TensorShape({kNumDims}));
EXPECT_THAT((ValidateSparseTensor(indices, values, shape)),
EXPECT_THAT((ValidateSparseTensor<int64_t>(indices, values, shape,
kValidateIndices)),
StatusIs(error::INVALID_ARGUMENT,
MatchesRegex("Sparse values must be rank 1 .*")));
}
......@@ -378,6 +384,7 @@ TEST(ValidateSparseTensorTest, InvalidValuesRankFails) {
TEST(ValidateSparseTensorTest, InvalidShapeRankFails) {
constexpr int kNumNonZeros = 1000;
constexpr int kNumDims = 3;
constexpr bool kValidateIndices = false;
// Shape tensor must be rank 1, so try rank 0, 2.
const TensorShape kInvalidShapeShapes[] = {{}, {kNumDims, 2}};
for (const TensorShape& invalid_shape : kInvalidShapeShapes) {
......@@ -385,7 +392,8 @@ TEST(ValidateSparseTensorTest, InvalidShapeRankFails) {
Tensor(DT_INT64, TensorShape({kNumNonZeros, kNumDims}));
const Tensor values = Tensor(DT_FLOAT, TensorShape({kNumNonZeros}));
const Tensor shape = Tensor(DT_INT64, invalid_shape);
EXPECT_THAT((ValidateSparseTensor(indices, values, shape)),
EXPECT_THAT((ValidateSparseTensor<int64_t>(indices, values, shape,
kValidateIndices)),
StatusIs(error::INVALID_ARGUMENT,
MatchesRegex("Sparse shape must be rank 1 .*")));
}
......@@ -394,6 +402,7 @@ TEST(ValidateSparseTensorTest, InvalidShapeRankFails) {
TEST(ValidateSparseTensorTest, IncompatibleShapesFails) {
constexpr int kNumNonZeros = 1000;
constexpr int kNumDims = 3;
constexpr bool kValidateIndices = false;
const Tensor values = Tensor(DT_FLOAT, TensorShape({kNumNonZeros}));
const Tensor shape = Tensor(DT_INT64, TensorShape({kNumDims}));
......@@ -402,7 +411,8 @@ TEST(ValidateSparseTensorTest, IncompatibleShapesFails) {
{
const Tensor indices =
Tensor(DT_INT64, TensorShape({kNumNonZeros + 1, kNumDims}));
EXPECT_THAT((ValidateSparseTensor(indices, values, shape)),
EXPECT_THAT((ValidateSparseTensor<int64_t>(indices, values, shape,
kValidateIndices)),
StatusIs(error::INVALID_ARGUMENT,
MatchesRegex("Number of elements in indices .* and "
"values .* do not match")));
......@@ -414,12 +424,50 @@ TEST(ValidateSparseTensorTest, IncompatibleShapesFails) {
const Tensor indices =
Tensor(DT_INT64, TensorShape({kNumNonZeros, kNumDims + 1}));
EXPECT_THAT(
(ValidateSparseTensor(indices, values, shape)),
(ValidateSparseTensor<int64_t>(indices, values, shape,
kValidateIndices)),
StatusIs(error::INVALID_ARGUMENT,
MatchesRegex("Index rank .* and shape rank .* do not match")));
}
}
// Corrupts a single coordinate of a generated sparse tensor so that it falls
// just below (-1) or just above (== dimension size) the valid range, and
// expects ValidateSparseTensor to reject it with INVALID_ARGUMENT each time.
TEST(ValidateSparseTensorTest, IndexOutOfBoundsFails) {
  constexpr int kNumNonZeros = 1000;
  constexpr int kNumTests = 100;
  constexpr bool kValidateIndices = true;

  const TensorShape kTensorShapes[] = {{3}, {4, 5}, {6, 7, 8}, {9, 10, 11, 12}};
  for (const TensorShape& tshape : kTensorShapes) {
    Tensor indices, values, shape;
    GenerateRandomSparseTensor(kNumNonZeros, tshape, indices, values, shape);

    // Matrix accessor used below to overwrite individual index entries.
    auto index_matrix = indices.matrix<int64_t>();

    for (int trial = 0; trial < kNumTests; ++trial) {
      // Choose a random (entry, dimension) coordinate to corrupt.
      const int64_t row = RandomPhilox().Uniform64(indices.dim_size(0));
      const int64_t dim = RandomPhilox().Uniform64(indices.dim_size(1));
      const int64_t saved = index_matrix(row, dim);

      // First one below the lower bound, then one past the upper bound.
      const int64_t out_of_range[] = {-1, tshape.dim_size(dim)};
      for (const int64_t bad_value : out_of_range) {
        index_matrix(row, dim) = bad_value;
        EXPECT_THAT(
            (ValidateSparseTensor<int64_t>(indices, values, shape,
                                           kValidateIndices)),
            StatusIs(error::INVALID_ARGUMENT,
                     MatchesRegex("Sparse index tuple .* is out of bounds")));
      }

      // Put the original coordinate back so the next trial starts from a
      // valid tensor.
      index_matrix(row, dim) = saved;
    }
  }
}
} // namespace
} // namespace sparse_utils
} // namespace tensorflow
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册