提交 423c1eea 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

BREAKING CHANGE: Fix semantic error in how maybe_batch* handles sparse tensors.

PiperOrigin-RevId: 163276613
上级 6028c071
......@@ -492,8 +492,15 @@ def _store_sparse_tensors(tensor_list, enqueue_many, keep_input,
lambda: -1 * array_ops.ones(array_ops.shape(t)[0:1], dtypes.int64))
out_tensor.set_shape([None]) # necessary when t.ndims is unknown
return out_tensor
def _sparse_values_to_keep(t, keep_input):
"""Convert a per-row `keep_input` vector to a per-value one."""
# Get the rows of every value in the sparse Tensor.
row_values = array_ops.reshape(
t.indices, [array_ops.shape(t.indices)[0], -1])[:, 0]
# The value should be kept iff the row should be kept.
return array_ops.gather(keep_input, row_values)
if keep_input.shape.ndims == 1:
t = sparse_ops.sparse_retain(t, keep_input)
t = sparse_ops.sparse_retain(t, _sparse_values_to_keep(t, keep_input))
store_f = lambda t, name, _: _store_many_sparse(t, shared_name=name)
elif enqueue_many:
store_f = _maybe_store_many_sparse
......
......@@ -903,6 +903,29 @@ class BatchTest(test_lib.TestCase):
[sparse], keep_input=[True, False], batch_size=2, enqueue_many=True)
self.assertIs(None, batched.dense_shape.get_shape().num_elements())
def testMaybeBatchCorrectValues(self):
sparse_t = sparse_tensor.SparseTensor(
indices=[[0, 1], [0, 2], [1, 0], [1, 3]],
dense_shape=[2, 4],
values=[5, 4, 7, 2])
keep = constant_op.constant([True, False])
batched = inp.maybe_batch(
[sparse_t], keep_input=keep, batch_size=1, enqueue_many=True)
with self.test_session():
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(coord=coord)
batched_np = batched.eval()
coord.request_stop()
for thread in threads:
thread.join()
self.assertAllEqual([[0, 1], [0, 2]], batched_np.indices)
self.assertAllEqual([5, 4], batched_np.values)
self.assertAllEqual([1, 4], batched_np.dense_shape)
class BatchJoinTest(test_lib.TestCase):
......@@ -1457,6 +1480,29 @@ class BatchJoinTest(test_lib.TestCase):
[[sparse]], keep_input=[True, False], batch_size=2, enqueue_many=True)
self.assertIs(None, batched.dense_shape.get_shape().num_elements())
def testMaybeBatchCorrectValues(self):
sparse = sparse_tensor.SparseTensor(
indices=[[0, 1], [0, 2], [1, 0], [1, 3]],
dense_shape=[2, 4],
values=[5, 4, 7, 2])
keep = constant_op.constant([True, False])
batched = inp.maybe_batch_join(
[[sparse]], keep_input=keep, batch_size=1, enqueue_many=True)
with self.test_session():
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(coord=coord)
batched_np = batched.eval()
coord.request_stop()
for thread in threads:
thread.join()
self.assertAllEqual([[0, 1], [0, 2]], batched_np.indices)
self.assertAllEqual([5, 4], batched_np.values)
self.assertAllEqual([1, 4], batched_np.dense_shape)
class ShuffleBatchTest(test_lib.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册