提交 b00e9404 编写于 作者: I Illia Polosukhin 提交者: TensorFlower Gardener

Fixes #4226: Allow to specify partial shape in sparse_placeholder.

Change: 137854718
上级 6172351b
......@@ -599,7 +599,37 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertAllEqual(sp2_out.values, values)
self.assertAllEqual(sp2_out.shape, shape)
def testFeedSparePlaceholderConstantShape(self):
def testFeedSparsePlaceholderPartialShape(self):
with session.Session() as s:
indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
values = np.array([1.0, 2.0]).astype(np.float32)
shape = np.array([7, 9, 2]).astype(np.int64)
sp = array_ops.sparse_placeholder(
shape=[None, 9, 2], dtype=np.float32, name='placeholder1')
sp_indices = array_ops.identity(sp.indices)
sp_values = array_ops.identity(sp.values)
sp_shape = array_ops.identity(sp.shape)
sp2 = ops.SparseTensor(sp_indices, sp_values, sp_shape)
# Feed with tuple
indices_out, values_out, shape_out = s.run(
[sp_indices, sp_values, sp_shape], {sp: (indices, values, shape)})
self.assertAllEqual(indices_out, indices)
self.assertAllEqual(values_out, values)
self.assertAllEqual(shape_out, shape)
# Feed with SparseTensorValue
indices_out, values_out, shape_out = s.run(
[sp_indices, sp_values, sp_shape],
{sp: ops.SparseTensorValue(indices, values, shape)})
self.assertAllEqual(indices_out, indices)
self.assertAllEqual(values_out, values)
self.assertAllEqual(shape_out, shape)
# Feed with SparseTensorValue, fetch SparseTensorValue
sp2_out = s.run(sp2, {sp: ops.SparseTensorValue(indices, values, shape)})
self.assertAllEqual(sp2_out.indices, indices)
self.assertAllEqual(sp2_out.values, values)
self.assertAllEqual(sp2_out.shape, shape)
def testFeedSparsePlaceholderConstantShape(self):
with session.Session() as s:
indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
values = np.array([1.0, 2.0]).astype(np.float32)
......
......@@ -1445,6 +1445,17 @@ def placeholder(dtype, shape=None, name=None):
return ret
# pylint: disable=redefined-outer-name
def _normalize_sparse_shape(shape, name):
"""Takes numpy array or Tensor or None and returns either None or Tensor."""
if shape is None: return None
if not isinstance(shape, ops.Tensor):
for el in shape:
if el is None:
return None
return ops.convert_to_tensor(shape, name=name)
def sparse_placeholder(dtype, shape=None, name=None):
"""Inserts a placeholder for a sparse tensor that will be always fed.
......@@ -1484,13 +1495,10 @@ def sparse_placeholder(dtype, shape=None, name=None):
A `SparseTensor` that may be used as a handle for feeding a value, but not
evaluated directly.
"""
shape_name = (name + "/shape") if name is not None else None
shape = _normalize_sparse_shape(shape, shape_name)
if shape is None:
shape = placeholder(
dtypes.int64, shape=[None],
name=(name + "/shape") if name is not None else None)
else:
shape = ops.convert_to_tensor(
shape, name=(name + "/shape") if name is not None else None)
shape = placeholder(dtypes.int64, shape=[None], name=shape_name)
return ops.SparseTensor(
values=placeholder(
dtype, shape=[None],
......@@ -1500,6 +1508,7 @@ def sparse_placeholder(dtype, shape=None, name=None):
name=(name + "/indices") if name is not None else None),
shape=shape
)
# pylint: enable=redefined-outer-name
def pad(tensor, paddings, mode="CONSTANT", name=None): # pylint: disable=invalid-name
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册