提交 a571aba2 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Internal Change

PiperOrigin-RevId: 224891138
上级 c25282f9
......@@ -39,8 +39,7 @@ class RaggedMapInnerValuesOpTest(ragged_test_util.RaggedTensorTestCase):
kwargs=None):
kwargs = kwargs or {}
result = ragged.map_flat_values(op, *args, **kwargs)
with self.test_session():
self.assertRaggedEqual(result, expected)
self.assertRaggedEqual(result, expected)
def testDocStringExamples(self):
"""Test the examples in apply_op_to_ragged_values.__doc__."""
......@@ -48,10 +47,9 @@ class RaggedMapInnerValuesOpTest(ragged_test_util.RaggedTensorTestCase):
v1 = ragged.map_flat_values(array_ops.ones_like, rt)
v2 = ragged.map_flat_values(math_ops.multiply, rt, rt)
v3 = ragged.map_flat_values(math_ops.add, rt, 5)
with self.test_session():
self.assertRaggedEqual(v1, [[1, 1, 1], [], [1, 1], [1]])
self.assertRaggedEqual(v2, [[1, 4, 9], [], [16, 25], [36]])
self.assertRaggedEqual(v3, [[6, 7, 8], [], [9, 10], [11]])
self.assertRaggedEqual(v1, [[1, 1, 1], [], [1, 1], [1]])
self.assertRaggedEqual(v2, [[1, 4, 9], [], [16, 25], [36]])
self.assertRaggedEqual(v3, [[6, 7, 8], [], [9, 10], [11]])
def testOpWithSingleRaggedTensorArg(self):
tensor = ragged.constant([[1, 2, 3], [], [4, 5]])
......@@ -122,9 +120,8 @@ class RaggedMapInnerValuesOpTest(ragged_test_util.RaggedTensorTestCase):
# ragged_rank=0
x0 = [3, 1, 4, 1, 5, 9, 2, 6, 5]
y0 = [1, 2, 3, 4, 5, 6, 7, 8, 9]
with self.test_session():
self.assertRaggedEqual(
math_ops.multiply(x0, y0), [3, 2, 12, 4, 25, 54, 14, 48, 45])
self.assertRaggedEqual(
math_ops.multiply(x0, y0), [3, 2, 12, 4, 25, 54, 14, 48, 45])
# ragged_rank=1
x1 = ragged.constant([[3, 1, 4], [], [1, 5], [9, 2], [6, 5]])
......
......@@ -92,8 +92,7 @@ class RaggedUtilTest(ragged_test_util.RaggedTensorTestCase,
])
def testRepeat(self, data, repeats, expected, axis=None):
result = ragged_util.repeat(data, repeats, axis)
with self.test_session():
self.assertAllEqual(result, expected)
self.assertAllEqual(result, expected)
@parameterized.parameters([
dict(mode=mode, **args)
......@@ -158,8 +157,7 @@ class RaggedUtilTest(ragged_test_util.RaggedTensorTestCase,
repeats = array_ops.placeholder_with_default(repeats, None)
result = ragged_util.repeat(data, repeats, axis)
with self.test_session():
self.assertAllEqual(result, expected)
self.assertAllEqual(result, expected)
@parameterized.parameters([
dict(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册