From ca3d4fe7cd22d5257c31a8481bc8bbf7680b7b95 Mon Sep 17 00:00:00 2001 From: Edward Loper Date: Thu, 13 Feb 2020 19:05:02 -0800 Subject: [PATCH] Add RaggedTensor.uniform_row_length property, which is a scalar Tensor containing the length of each row in a RaggedTensor (if the rows are uniform), or None (if rows are ragged). PiperOrigin-RevId: 295051896 Change-Id: Ie4ec9c33676339d81ee7ebf54bd2932d0cf81898 --- tensorflow/python/ops/ragged/ragged_tensor.py | 26 +++++++++++++++++++ .../python/ops/ragged/ragged_tensor_test.py | 4 +++ .../golden/v1/tensorflow.-ragged-tensor.pbtxt | 4 +++ .../golden/v2/tensorflow.-ragged-tensor.pbtxt | 4 +++ 4 files changed, 38 insertions(+) diff --git a/tensorflow/python/ops/ragged/ragged_tensor.py b/tensorflow/python/ops/ragged/ragged_tensor.py index ad0c5e2a863..5746137034f 100644 --- a/tensorflow/python/ops/ragged/ragged_tensor.py +++ b/tensorflow/python/ops/ragged/ragged_tensor.py @@ -1069,6 +1069,32 @@ class RaggedTensor(composite_tensor.CompositeTensor): """ return self._row_splits + @property + def uniform_row_length(self): + """The length of each row in this ragged tensor, or None if rows are ragged. + + >>> rt1 = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]]) + >>> print(rt1.uniform_row_length) # rows are ragged. + None + + >>> rt2 = tf.RaggedTensor.from_uniform_row_length( + ... values=rt1, uniform_row_length=2) + >>> print(rt2) + + >>> print(rt2.uniform_row_length) # rows are not ragged (all have size 2). + tf.Tensor(2, shape=(), dtype=int64) + + A RaggedTensor's rows are only considered to be uniform (i.e. non-ragged) + if it can be determined statically (at graph construction time) that the + rows all have the same length. + + Returns: + A scalar integer `Tensor`, specifying the length of every row in this + ragged tensor (for ragged tensors whose rows are uniform); or `None` + (for ragged tensors whose rows are ragged). + """ + return self._uniform_row_length + @property def flat_values(self): """The innermost `values` tensor for this ragged tensor. diff --git a/tensorflow/python/ops/ragged/ragged_tensor_test.py b/tensorflow/python/ops/ragged/ragged_tensor_test.py index c74617aa314..6bc066e5d84 100644 --- a/tensorflow/python/ops/ragged/ragged_tensor_test.py +++ b/tensorflow/python/ops/ragged/ragged_tensor_test.py @@ -477,6 +477,7 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, self.assertAllEqual(a1, a2) self.assertEqual(a1.shape.as_list(), [8, 2]) self.assertEqual(a2.shape.as_list(), [8, 2]) + self.assertAllEqual(a1.uniform_row_length, 2) b1 = RaggedTensor.from_uniform_row_length(a1, 2) b2 = RaggedTensor.from_uniform_row_length(a1, 2, 4) @@ -485,6 +486,7 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, self.assertAllEqual(b1, b2) self.assertEqual(b1.shape.as_list(), [4, 2, 2]) self.assertEqual(b2.shape.as_list(), [4, 2, 2]) + self.assertAllEqual(b1.uniform_row_length, 2) c1 = RaggedTensor.from_uniform_row_length(b1, 2) c2 = RaggedTensor.from_uniform_row_length(b1, 2, 2) @@ -493,11 +495,13 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, self.assertAllEqual(c1, c2) self.assertEqual(c1.shape.as_list(), [2, 2, 2, 2]) self.assertEqual(c2.shape.as_list(), [2, 2, 2, 2]) + self.assertAllEqual(c1.uniform_row_length, 2) def testFromUniformRowLengthWithEmptyValues(self): empty_values = [] a = RaggedTensor.from_uniform_row_length(empty_values, 0, nrows=10) self.assertEqual(a.shape.as_list(), [10, 0]) + self.assertAllEqual(a.uniform_row_length, 0) b = RaggedTensor.from_uniform_row_length(a, 2) self.assertEqual(b.shape.as_list(), [5, 2, 0]) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-ragged-tensor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-ragged-tensor.pbtxt index 36e0f543540..c64909d45f5 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-ragged-tensor.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-ragged-tensor.pbtxt @@ -27,6 +27,10 @@ tf_class { name: "shape" mtype: "" } + member { + name: "uniform_row_length" + mtype: "" + } member { name: "values" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-ragged-tensor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-ragged-tensor.pbtxt index 36e0f543540..c64909d45f5 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-ragged-tensor.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-ragged-tensor.pbtxt @@ -27,6 +27,10 @@ tf_class { name: "shape" mtype: "" } + member { + name: "uniform_row_length" + mtype: "" + } member { name: "values" mtype: "" -- GitLab