diff --git a/tensorflow/core/kernels/sparse/sparse_matrix.h b/tensorflow/core/kernels/sparse/sparse_matrix.h index b2ec20acc658489c81dd04a6ccee98e272af80cb..15fbe5df6ebe2936d65f373cb2f7711c8b735336 100644 --- a/tensorflow/core/kernels/sparse/sparse_matrix.h +++ b/tensorflow/core/kernels/sparse/sparse_matrix.h @@ -25,10 +25,12 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/variant.h" #include "tensorflow/core/framework/variant_encode_decode.h" #include "tensorflow/core/framework/variant_op_registry.h" +#include "tensorflow/core/platform/errors.h" namespace tensorflow { @@ -633,6 +635,11 @@ template Status ExtractVariantFromInput(OpKernelContext* ctx, int index, const T** value) { const Tensor& input_t = ctx->input(index); + if (!TensorShapeUtils::IsScalar(input_t.shape())) { + return errors::InvalidArgument( + "Invalid input matrix: Shape must be rank 0 but is rank ", + input_t.dims()); + } const Variant& input_variant = input_t.scalar()(); *value = input_variant.get(); if (*value == nullptr) { diff --git a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py index 035791ce0a5bf545d5cb05e9def9ac0398336a12..d129bea768e85fb493e16ef7c1fc3848eac6e6ab 100644 --- a/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py +++ b/tensorflow/python/kernel_tests/linalg/sparse/csr_sparse_matrix_ops_test.py @@ -1313,6 +1313,16 @@ class CSRSparseMatrixOpsTest(test.TestCase): self.assertLess(cholesky_with_amd_nnz_value, cholesky_without_ordering_nnz_value) + @test_util.run_in_graph_and_eager_modes + def testNoMatrixNoCrash(self): + # Round-about way of creating an empty variant tensor that works in both + # graph and eager modes. + no_matrix = array_ops.reshape(dense_to_csr_sparse_matrix([[0.0]]), [1])[0:0] + with self.assertRaisesRegex( + (ValueError, errors.InvalidArgumentError), + "(Invalid input matrix)|(Shape must be rank 0)"): + sparse_csr_matrix_ops.sparse_matrix_nnz(no_matrix) + class CSRSparseMatrixOpsBenchmark(test.Benchmark):