Commit 81e345d2 authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Fix the inferred shape of SparseTensor after a tf.train.batch dequeue.

Change: 119755164
Parent eddacad7
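
For context, the behavior this commit fixes can be sketched with the queue-based input API that the changed code and tests already use (tf.train.batch, tf.SparseTensor with a `shape` argument, TF 0.8 era). The expected static shapes in the comments mirror the tests added below; the variable names are illustrative only:

import tensorflow as tf

# One example whose dense rank is 1 (its `shape` vector has length 1).
sparse = tf.SparseTensor(indices=[[0]], values=[1.0], shape=[1])

# enqueue_many=False: each list element is a single example, so batching
# prepends a batch dimension; the dequeued SparseTensor should have dense
# rank 2, i.e. a statically known shape vector of length 2.
batched = tf.train.batch([sparse], batch_size=2)
print(batched.shape.get_shape())  # expected: (2,)

# enqueue_many=True: the leading dimension already indexes examples, so the
# dequeued SparseTensor keeps dense rank 1.
batched_many = tf.train.batch([sparse], batch_size=2, enqueue_many=True)
print(batched_many.shape.get_shape())  # expected: (1,)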
@@ -322,24 +322,23 @@ class _SparseMetaData(object):
 def _serialize_sparse_tensors(tensor_list, enqueue_many):
   """Serialize SparseTensors for feeding into batch, etc."""
-  sparse_info_list = [
-      _SparseMetaData(sparse=True,
-                      dtype=t.dtype,
-                      rank=t.shape.get_shape().with_rank(1)[0])
-      if isinstance(t, ops.SparseTensor)
-      else _SparseMetaData(False, None, None)
-      for t in tensor_list]
-  def _maybe_serialize(t, sparse):
-    if not sparse:
+  def _sparse_meta_data(t):
+    if not isinstance(t, ops.SparseTensor):
+      return _SparseMetaData(False, None, None)
+    rank = t.shape.get_shape().with_rank(1)[0]
+    if enqueue_many:
+      rank -= 1
+    return _SparseMetaData(sparse=True, dtype=t.dtype, rank=rank)
+  def _maybe_serialize(t):
+    if not isinstance(t, ops.SparseTensor):
       return t
     return (sparse_ops.serialize_many_sparse(t) if enqueue_many
             else sparse_ops.serialize_sparse(t))
-  serialized_list = [
-      _maybe_serialize(t, info.sparse) for (t, info)
-      in zip(tensor_list, sparse_info_list)]
+  serialized_list = [_maybe_serialize(t) for t in tensor_list]
+  sparse_info_list = [_sparse_meta_data(t) for t in tensor_list]
   return serialized_list, sparse_info_list
@@ -368,7 +367,7 @@ def _deserialize_sparse_tensors(serialized_list, sparse_info_list):
   if not received_sequence:
     serialized_list = (serialized_list,)
   tensors = [
-      sparse_ops.deserialize_many_sparse(s, info.dtype, info.rank.value)
+      sparse_ops.deserialize_many_sparse(s, info.dtype, info.rank.value + 1)
       if info.sparse else s
       for (s, info)
       in zip(serialized_list, sparse_info_list)]
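
The rank bookkeeping that connects the two hunks above, written out as a plain-Python sketch (the helper names are hypothetical, not part of the change): the metadata now records the per-example dense rank, and deserialize_many_sparse is given that rank plus one to account for the batch dimension added on dequeue.

def stored_rank(input_rank, enqueue_many):
  # With enqueue_many=True the input's leading dimension already indexes
  # examples, so the per-example rank is one less than the input's rank.
  return input_rank - 1 if enqueue_many else input_rank

def dequeued_rank(input_rank, enqueue_many):
  # deserialize_many_sparse(..., rank + 1) always yields a batch whose dense
  # rank is the per-example rank plus a leading batch dimension.
  return stored_rank(input_rank, enqueue_many) + 1

assert dequeued_rank(1, enqueue_many=False) == 2  # matches the first test below
assert dequeued_rank(1, enqueue_many=True) == 1   # matches the second test below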
@@ -526,6 +526,18 @@ class BatchTest(tf.test.TestCase):
     with self.assertRaisesRegexp(ValueError, "Cannot infer Tensor's rank"):
       tf.train.batch([x], batch_size=2)
+
+  def testBatchedSparseTensorInferedShape(self):
+    sparse = tf.SparseTensor(indices=[[0]], values=[1.0], shape=[1])
+    self.assertAllEqual(sparse.shape.get_shape().as_list(), [1])
+    batched = tf.train.batch([sparse], batch_size=2)
+    self.assertAllEqual(batched.shape.get_shape().as_list(), [2])
+
+  def testBatchedSparseTensorInferedShapeEnqueueMany(self):
+    sparse = tf.SparseTensor(indices=[[0]], values=[1.0], shape=[1])
+    self.assertAllEqual(sparse.shape.get_shape().as_list(), [1])
+    batched = tf.train.batch([sparse], batch_size=2, enqueue_many=True)
+    self.assertAllEqual(batched.shape.get_shape().as_list(), [1])
 class BatchJoinTest(tf.test.TestCase):