From 657c4432899cb7430e380c1d465d789ced68ae6f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 15 Aug 2017 18:32:20 -0700 Subject: [PATCH] Make indices placeholder in sparse_placeholder maintain rank information. PiperOrigin-RevId: 165389240 --- tensorflow/python/kernel_tests/sparse_ops_test.py | 3 +++ tensorflow/python/ops/array_ops.py | 3 +-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/kernel_tests/sparse_ops_test.py b/tensorflow/python/kernel_tests/sparse_ops_test.py index c70f152af8e..51bfceee01f 100644 --- a/tensorflow/python/kernel_tests/sparse_ops_test.py +++ b/tensorflow/python/kernel_tests/sparse_ops_test.py @@ -911,14 +911,17 @@ class SparsePlaceholderTest(test.TestCase): def testPlaceholder(self): foo = array_ops.sparse_placeholder(dtypes.float32, shape=(10, 47)) self.assertAllEqual([10, 47], foo.get_shape()) + self.assertAllEqual([None, 2], foo.indices.get_shape().as_list()) def testPartialShapePlaceholder(self): foo = array_ops.sparse_placeholder(dtypes.float32, shape=(None, 47)) self.assertAllEqual([None, None], foo.get_shape().as_list()) + self.assertAllEqual([None, 2], foo.indices.get_shape().as_list()) def testNoShapePlaceholder(self): foo = array_ops.sparse_placeholder(dtypes.float32, shape=None) self.assertAllEqual(None, foo.get_shape()) + self.assertAllEqual([None, None], foo.indices.get_shape().as_list()) if __name__ == "__main__": diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 0042f929ee7..0570b34a603 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -1635,8 +1635,7 @@ def sparse_placeholder(dtype, shape=None, name=None): shape=[None], name=(name + "/values") if name is not None else None), indices=placeholder( - dtypes.int64, - shape=[None, None], + dtypes.int64, shape=[None, rank], name=(name + "/indices") if name is not None else None), dense_shape=shape) -- GitLab