提交 ca3d4fe7 编写于 作者: E Edward Loper 提交者: TensorFlower Gardener

Add RaggedTensor.uniform_row_length property, which is a scalar Tensor...

Add RaggedTensor.uniform_row_length property, which is a scalar Tensor containing the length of each row in a RaggedTensor (if the rows are uniform), or None (if rows are ragged).

PiperOrigin-RevId: 295051896
Change-Id: Ie4ec9c33676339d81ee7ebf54bd2932d0cf81898
上级 93fa50ef
......@@ -1069,6 +1069,32 @@ class RaggedTensor(composite_tensor.CompositeTensor):
"""
return self._row_splits
@property
def uniform_row_length(self):
"""The length of each row in this ragged tensor, or None if rows are ragged.
>>> rt1 = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]])
>>> print(rt1.uniform_row_length) # rows are ragged.
None
>>> rt2 = tf.RaggedTensor.from_uniform_row_length(
... values=rt1, uniform_row_length=2)
>>> print(rt2)
<tf.RaggedTensor [[[1, 2, 3], [4]], [[5, 6], [7, 8, 9, 10]]]>
>>> print(rt2.uniform_row_length) # rows are not ragged (all have size 2).
tf.Tensor(2, shape=(), dtype=int64)
A RaggedTensor's rows are only considered to be uniform (i.e. non-ragged)
if it can be determined statically (at graph construction time) that the
rows all have the same length.
Returns:
A scalar integer `Tensor`, specifying the length of every row in this
ragged tensor (for ragged tensors whose rows are uniform); or `None`
(for ragged tensors whose rows are ragged).
"""
return self._uniform_row_length
@property
def flat_values(self):
"""The innermost `values` tensor for this ragged tensor.
......
......@@ -477,6 +477,7 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
self.assertAllEqual(a1, a2)
self.assertEqual(a1.shape.as_list(), [8, 2])
self.assertEqual(a2.shape.as_list(), [8, 2])
self.assertAllEqual(a1.uniform_row_length, 2)
b1 = RaggedTensor.from_uniform_row_length(a1, 2)
b2 = RaggedTensor.from_uniform_row_length(a1, 2, 4)
......@@ -485,6 +486,7 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
self.assertAllEqual(b1, b2)
self.assertEqual(b1.shape.as_list(), [4, 2, 2])
self.assertEqual(b2.shape.as_list(), [4, 2, 2])
self.assertAllEqual(b1.uniform_row_length, 2)
c1 = RaggedTensor.from_uniform_row_length(b1, 2)
c2 = RaggedTensor.from_uniform_row_length(b1, 2, 2)
......@@ -493,11 +495,13 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
self.assertAllEqual(c1, c2)
self.assertEqual(c1.shape.as_list(), [2, 2, 2, 2])
self.assertEqual(c2.shape.as_list(), [2, 2, 2, 2])
self.assertAllEqual(c1.uniform_row_length, 2)
def testFromUniformRowLengthWithEmptyValues(self):
empty_values = []
a = RaggedTensor.from_uniform_row_length(empty_values, 0, nrows=10)
self.assertEqual(a.shape.as_list(), [10, 0])
self.assertAllEqual(a.uniform_row_length, 0)
b = RaggedTensor.from_uniform_row_length(a, 2)
self.assertEqual(b.shape.as_list(), [5, 2, 0])
......
......@@ -27,6 +27,10 @@ tf_class {
name: "shape"
mtype: "<type \'property\'>"
}
member {
name: "uniform_row_length"
mtype: "<type \'property\'>"
}
member {
name: "values"
mtype: "<type \'property\'>"
......
......@@ -27,6 +27,10 @@ tf_class {
name: "shape"
mtype: "<type \'property\'>"
}
member {
name: "uniform_row_length"
mtype: "<type \'property\'>"
}
member {
name: "values"
mtype: "<type \'property\'>"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册