提交 0468e56d 编写于 作者: A Antonio Sanchez 提交者: TensorFlow Release Automation

Add ValidateSparseTensor function to sparse_utils.

This will be used to validate sparse input shapes (and eventually indices),
and unify many of our existing checks across sparse ops.

PiperOrigin-RevId: 450019722
上级 5e3bfa5a
......@@ -443,6 +443,7 @@ tf_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:status_matchers",
"@com_google_absl//absl/base:core_headers",
],
)
......
......@@ -16,8 +16,12 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse_utils.h"
#include <cstddef>
#include <cstdint>
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
namespace sparse_utils {
......@@ -140,6 +144,37 @@ bool ContainsEmptyRows(const std::vector<Tindices>& row_start_indices) {
return false;
}
// Checks that the three component tensors of a sparse tensor have mutually
// consistent shapes.
//
// Requirements:
//   * `indices` is a matrix (rank 2) of shape [nnz, ndims].
//   * `values` is a vector (rank 1) with nnz elements.
//   * `shape` is a vector (rank 1) with ndims elements.
//
// Only tensor shapes are inspected here; the index values stored in `indices`
// are not examined.
//
// Returns OK when every requirement holds, otherwise an InvalidArgument status
// describing the first requirement that was violated.
Status ValidateSparseTensor(const Tensor& indices, const Tensor& values,
                            const Tensor& shape) {
  // Indices must be a matrix, and values/shape must be a vector.
  if (!TensorShapeUtils::IsMatrix(indices.shape())) {
    // Report the rank via dims(), like the checks below, rather than building
    // a temporary vector of dimension sizes just to count its entries.
    return errors::InvalidArgument("Sparse indices must be rank 2 but is rank ",
                                   indices.shape().dims());
  }
  if (!TensorShapeUtils::IsVector(values.shape())) {
    return errors::InvalidArgument("Sparse values must be rank 1 but is rank ",
                                   values.shape().dims());
  }
  if (!TensorShapeUtils::IsVector(shape.shape())) {
    return errors::InvalidArgument("Sparse shape must be rank 1 but is rank ",
                                   shape.shape().dims());
  }

  // Indices shape must be compatible with the values vector and dense shape.
  const int64_t nnz = indices.dim_size(0);
  const int64_t ndims = indices.dim_size(1);

  const int64_t num_values = values.dim_size(0);
  if (num_values != nnz) {
    return errors::InvalidArgument("Number of elements in indices (", nnz,
                                   ") and values (", num_values,
                                   ") do not match");
  }

  const int64_t shape_rank = shape.NumElements();
  if (shape_rank != ndims) {
    return errors::InvalidArgument("Index rank (", ndims, ") and shape rank (",
                                   shape_rank, ") do not match");
  }

  return Status::OK();
}
#define REGISTER_SPARSE_UTIL_FUNCTIONS(TypeIndex) \
template TypeIndex FindNextDenseRowStartIndex<TypeIndex>( \
const TypeIndex sparse_index_begin, \
......@@ -151,7 +186,7 @@ bool ContainsEmptyRows(const std::vector<Tindices>& row_start_indices) {
const std::vector<TypeIndex>& row_start_indices); \
template std::vector<TypeIndex> ParseRowStartIndices<TypeIndex>( \
const tensorflow::Tensor& tensor, \
const TypeIndex num_nonzero_entries_in_sparse_mat);
const TypeIndex num_nonzero_entries_in_sparse_mat)
// Explicitly instantiate the templated sparse utility functions listed in the
// macro above for the int32 and int64 index types.
REGISTER_SPARSE_UTIL_FUNCTIONS(int32);
REGISTER_SPARSE_UTIL_FUNCTIONS(int64);
......
......@@ -65,6 +65,11 @@ std::vector<Tindices> ParseRowStartIndices(
template <typename Tindices>
bool ContainsEmptyRows(const std::vector<Tindices>& row_start_indices);
// Validates that the three component tensors of a sparse tensor have mutually
// consistent shapes: `indices` must be a rank-2 tensor of shape [nnz, ndims],
// `values` a rank-1 tensor with nnz elements, and `shape` a rank-1 tensor with
// ndims elements. Returns an InvalidArgument status describing the first
// violated requirement, or OK when all checks pass.
Status ValidateSparseTensor(const Tensor& indices, const Tensor& values,
const Tensor& shape);
} // namespace sparse_utils
} // namespace tensorflow
......
......@@ -15,27 +15,28 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse_utils.h"
#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/status_matchers.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace sparse_utils {
namespace {
using ::int64_t;
using tensorflow::DataType;
using tensorflow::int32;
using tensorflow::Tensor;
using tensorflow::TTypes;
using tensorflow::uint16;
using tensorflow::uint32;
using tensorflow::uint64;
using tensorflow::sparse_utils::ContainsEmptyRows;
using tensorflow::sparse_utils::FindNextDenseRowStartIndex;
using tensorflow::sparse_utils::GetStartIndicesOfEachDenseRow;
using tensorflow::sparse_utils::ParseRowStartIndices;
using ::tensorflow::testing::StatusIs;
using ::testing::MatchesRegex;
TEST(SparseUtilsTest, GetStartIndicesOfEachDenseRow) {
{
......@@ -260,4 +261,165 @@ TEST(SparseUtilsTest, FindNextDenseRowStartIndex) {
}
}
// Provides access to a single random number generator shared by every caller
// in this test file.
::tensorflow::random::SimplePhilox& RandomPhilox() {
  namespace rnd = ::tensorflow::random;
  // Both objects are heap-allocated on first use and never deleted, so no
  // static destructor runs for them.
  static rnd::PhiloxRandom* const base_generator =
      new rnd::PhiloxRandom(tensorflow::testing::RandomSeed());
  static rnd::SimplePhilox* const shared_generator =
      new rnd::SimplePhilox(base_generator);
  return *shared_generator;
}
// Writes a set of distinct, randomly drawn index tuples into `indices`.
//
// `SetType` must behave like a std::set of std::vector<int64_t> (for example
// flat_hash_set or btree_set). It both de-duplicates candidate tuples and
// decides the order in which tuples are written: a hash set yields unordered
// indices, a sorted set yields lexicographically ordered indices.
//
// `shape` supplies the exclusive upper bound for each index component.
// `indices` must already be allocated as an int64 matrix; its first dimension
// gives the number of tuples to generate and its second the tuple length.
template <typename SetType>
void FillIndicesWithRandomTuples(const TensorShape& shape, Tensor& indices) {
  const int64_t num_tuples = indices.dim_size(0);
  const int64_t tuple_len = indices.dim_size(1);

  // Keep drawing candidates until enough distinct tuples were collected;
  // duplicates are silently rejected by the set.
  SetType unique_tuples;
  while (static_cast<int64_t>(unique_tuples.size()) < num_tuples) {
    std::vector<int64_t> tuple;
    tuple.reserve(tuple_len);
    for (int64_t dim = 0; dim < tuple_len; ++dim) {
      tuple.push_back(RandomPhilox().Uniform64(shape.dim_size(dim)));
    }
    unique_tuples.insert(std::move(tuple));
  }

  // Emit the tuples row by row, in the set's iteration order.
  auto out = indices.matrix<int64_t>();
  int64_t out_row = 0;
  for (const std::vector<int64_t>& tuple : unique_tuples) {
    for (int64_t dim = 0; dim < tuple_len; ++dim) {
      out(out_row, dim) = tuple[dim];
    }
    ++out_row;
  }
}
// Populates components of a sparse random tensor with provided number of
// non-zeros `max_nnz` and tensor shape `shape`.
void GenerateRandomSparseTensor(int64_t max_nnz, const TensorShape& shape,
Tensor& output_indices, Tensor& output_values,
Tensor& output_shape) {
const int64_t ndims = shape.dims();
// We cannot generate more elements than the total in the tensor, so
// potentially reduce nnz.
const int64_t nnz = std::min(shape.num_elements(), max_nnz);
output_indices = Tensor(DT_INT64, TensorShape({nnz, ndims}));
output_values = Tensor(DT_FLOAT, TensorShape({nnz}));
output_shape = Tensor(DT_INT64, TensorShape({ndims}));
// Generate random unique unordered sparse indices.
FillIndicesWithRandomTuples<absl::flat_hash_set<std::vector<int64_t>>>(
shape, output_indices);
auto values_vec = output_values.vec<float>();
values_vec.setRandom();
auto shape_vec = output_shape.vec<int64_t>();
for (int i = 0; i < shape.dims(); ++i) {
shape_vec(i) = shape.dim_size(i);
}
}
// A well-formed random sparse tensor must pass validation for dense shapes of
// rank 0 through 4.
TEST(ValidateSparseTensorTest, ValidSparseTensorPasses) {
  constexpr int kMaxNonZeros = 1000;
  const TensorShape dense_shapes[] = {
      {}, {3}, {4, 5}, {6, 7, 8}, {9, 10, 11, 12}};

  for (const TensorShape& dense_shape : dense_shapes) {
    Tensor indices;
    Tensor values;
    Tensor shape;
    GenerateRandomSparseTensor(kMaxNonZeros, dense_shape, indices, values,
                               shape);
    TF_EXPECT_OK(ValidateSparseTensor(indices, values, shape));
  }
}
// The indices tensor must be rank 2; ranks 0, 1 and 3 must all be rejected.
TEST(ValidateSparseTensorTest, InvalidIndicesRankFails) {
  constexpr int kNumNonZeros = 1000;
  constexpr int kNumDims = 3;

  // Values and shape are well-formed and identical for every case, so build
  // them once.
  const Tensor values(DT_FLOAT, TensorShape({kNumNonZeros}));
  const Tensor shape(DT_INT64, TensorShape({kNumDims}));

  const TensorShape bad_indices_shapes[] = {
      {}, {kNumNonZeros}, {kNumNonZeros, kNumDims, 4}};
  for (const TensorShape& bad_shape : bad_indices_shapes) {
    const Tensor indices(DT_INT64, bad_shape);
    EXPECT_THAT(ValidateSparseTensor(indices, values, shape),
                StatusIs(error::INVALID_ARGUMENT,
                         MatchesRegex("Sparse indices must be rank 2 .*")));
  }
}
// The values tensor must be rank 1; ranks 0 and 2 must be rejected.
TEST(ValidateSparseTensorTest, InvalidValuesRankFails) {
  constexpr int kNumNonZeros = 1000;
  constexpr int kNumDims = 3;

  // Indices and shape are well-formed and identical for every case, so build
  // them once.
  const Tensor indices(DT_INT64, TensorShape({kNumNonZeros, kNumDims}));
  const Tensor shape(DT_INT64, TensorShape({kNumDims}));

  const TensorShape bad_values_shapes[] = {{}, {kNumNonZeros, 2}};
  for (const TensorShape& bad_shape : bad_values_shapes) {
    const Tensor values(DT_FLOAT, bad_shape);
    EXPECT_THAT(ValidateSparseTensor(indices, values, shape),
                StatusIs(error::INVALID_ARGUMENT,
                         MatchesRegex("Sparse values must be rank 1 .*")));
  }
}
// The dense-shape tensor must be rank 1; ranks 0 and 2 must be rejected.
TEST(ValidateSparseTensorTest, InvalidShapeRankFails) {
  constexpr int kNumNonZeros = 1000;
  constexpr int kNumDims = 3;

  // Indices and values are well-formed and identical for every case, so build
  // them once.
  const Tensor indices(DT_INT64, TensorShape({kNumNonZeros, kNumDims}));
  const Tensor values(DT_FLOAT, TensorShape({kNumNonZeros}));

  const TensorShape bad_shape_shapes[] = {{}, {kNumDims, 2}};
  for (const TensorShape& bad_shape : bad_shape_shapes) {
    const Tensor shape(DT_INT64, bad_shape);
    EXPECT_THAT(ValidateSparseTensor(indices, values, shape),
                StatusIs(error::INVALID_ARGUMENT,
                         MatchesRegex("Sparse shape must be rank 1 .*")));
  }
}
// Components with correct ranks but mismatched sizes must be rejected.
TEST(ValidateSparseTensorTest, IncompatibleShapesFails) {
  constexpr int kNumNonZeros = 1000;
  constexpr int kNumDims = 3;

  const Tensor values(DT_FLOAT, TensorShape({kNumNonZeros}));
  const Tensor shape(DT_INT64, TensorShape({kNumDims}));

  // The indices row count (nnz) must equal the number of values; use one row
  // too many.
  {
    const Tensor extra_row_indices(
        DT_INT64, TensorShape({kNumNonZeros + 1, kNumDims}));
    EXPECT_THAT(ValidateSparseTensor(extra_row_indices, values, shape),
                StatusIs(error::INVALID_ARGUMENT,
                         MatchesRegex("Number of elements in indices .* and "
                                      "values .* do not match")));
  }

  // The indices column count (ndims) must equal the length of the dense shape;
  // use one column too many.
  {
    const Tensor extra_col_indices(
        DT_INT64, TensorShape({kNumNonZeros, kNumDims + 1}));
    EXPECT_THAT(
        ValidateSparseTensor(extra_col_indices, values, shape),
        StatusIs(error::INVALID_ARGUMENT,
                 MatchesRegex("Index rank .* and shape rank .* do not match")));
  }
}
} // namespace
} // namespace sparse_utils
} // namespace tensorflow
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册