提交 ec0b5d1c 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Pfor: make unsorted_segment_sum converter safe for int64 types.

PiperOrigin-RevId: 258450422
上级 02433bb9
......@@ -422,9 +422,13 @@ class MathTest(PForTestCase):
def test_unsorted_segment_sum(self):
t = random_ops.random_uniform([3, 3, 2])
segment_ids = constant_op.constant([[0, 0, 2], [0, 1, 2], [2, 2, 2]])
num_segments = 3
for segment_ids_dtype in (dtypes.int32, dtypes.int64):
for num_segments_dtype in (dtypes.int32, dtypes.int64):
segment_ids = constant_op.constant([[0, 0, 2], [0, 1, 2], [2, 2, 2]],
dtype=segment_ids_dtype)
num_segments = constant_op.constant(3, dtype=num_segments_dtype)
# pylint: disable=cell-var-from-loop
def loop_fn(i):
data = array_ops.gather(t, i)
data_0 = array_ops.gather(t, 0)
......@@ -433,6 +437,7 @@ class MathTest(PForTestCase):
return (math_ops.unsorted_segment_sum(data, seg_ids, num_segments),
math_ops.unsorted_segment_sum(data_0, seg_ids, num_segments),
math_ops.unsorted_segment_sum(data, seg_ids_0, num_segments))
# pylint: enable=cell-var-from-loop
self._test_loop_fn(loop_fn, 3, [dtypes.float32] * 3)
......
......@@ -2198,10 +2198,14 @@ def _convert_unsortedsegmentsum(pfor_input):
segment_ids = pfor_input.stacked_input(1)
# TODO(agarwal): handle stacked?
num_segments = pfor_input.unstacked_input(2)
segment_shape = array_ops.shape(segment_ids)
if segment_ids.dtype != num_segments.dtype:
segment_ids = math_ops.cast(segment_ids, dtypes.int64)
num_segments = math_ops.cast(num_segments, dtypes.int64)
dtype = segment_ids.dtype
segment_shape = array_ops.shape(segment_ids, out_type=dtype)
n = segment_shape[0]
ones = array_ops.ones_like(segment_shape)[1:]
segment_offset = num_segments * math_ops.range(n)
ones = array_ops.ones_like(segment_shape, dtype=dtype)[1:]
segment_offset = num_segments * math_ops.range(n, dtype=dtype)
segment_offset = array_ops.reshape(segment_offset,
array_ops.concat([[n], ones], axis=0))
segment_ids += segment_offset
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册