From 391ec802c94d084e7f32b7c66e712c7a68e29556 Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Fri, 23 Sep 2022 08:40:41 -0700 Subject: [PATCH] Fix missing sparse matrix crash. Calling a sparse matrix op with no matrix currently causes a crash. Here we check and return a non-ok status. PiperOrigin-RevId: 476379116 --- tensorflow/core/kernels/sparse/sparse_matrix.h | 7 +++++++ .../linalg/sparse/csr_sparse_matrix_ops_test.py | 10 ++++++++++ 2 files changed, 17 insertions(+) diff --git a/tensorflow/core/kernels/sparse/sparse_matrix.h b/tensorflow/core/kernels/sparse/sparse_matrix.h index b2ec20acc65..15fbe5df6eb 100644 --- a/tensorflow/core/kernels/sparse/sparse_matrix.h +++ b/tensorflow/core/kernels/sparse/sparse_matrix.h @@ -25,10 +25,12 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/variant.h" #include "tensorflow/core/framework/variant_encode_decode.h" #include "tensorflow/core/framework/variant_op_registry.h" +#include "tensorflow/core/platform/errors.h" namespace tensorflow { @@ -633,6 +635,11 @@ template Status ExtractVariantFromInput(OpKernelContext* ctx, int index, const T** value) { const Tensor& input_t = ctx->input(index); + if (!TensorShapeUtils::IsScalar(input_t.shape())) { + return errors::InvalidArgument( + "Invalid input matrix: Shape must be rank 0 but is rank ", + input_t.dims()); + } const Variant& input_variant = input_t.scalar()(); *value = input_variant.get(); if (*value == nullptr) { diff --git a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py index 035791ce0a5..d129bea768e 100644 --- a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py +++ b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py @@ -1313,6 +1313,16 @@ class CSRSparseMatrixOpsTest(test.TestCase): self.assertLess(cholesky_with_amd_nnz_value, cholesky_without_ordering_nnz_value) + @test_util.run_in_graph_and_eager_modes + def testNoMatrixNoCrash(self): + # Round-about way of creating an empty variant tensor that works in both + # graph and eager modes. + no_matrix = array_ops.reshape(dense_to_csr_sparse_matrix([[0.0]]), [1])[0:0] + with self.assertRaisesRegex( + (ValueError, errors.InvalidArgumentError), + "(Invalid input matrix)|(Shape must be rank 0)"): + sparse_csr_matrix_ops.sparse_matrix_nnz(no_matrix) + class CSRSparseMatrixOpsBenchmark(test.Benchmark): -- GitLab