提交 10332bb8 编写于 作者: A Akshay Modi 提交者: TensorFlower Gardener

PFor inputs should be ndarrays.

PiperOrigin-RevId: 328068227
Change-Id: Ia084d946f3a0e5d071d7e8fec4263d1da26d9671
上级 01b030b7
......@@ -323,6 +323,11 @@ class InteropTest(tf.test.TestCase):
self.assertIsInstance(c, np.ndarray)
self.assertEqual(c.shape, (batch_size, 32, 32, 32, 32))
c = tf.vectorized_map(lambda x: x.T, a)
self.assertIsInstance(c, np.ndarray)
self.assertEqual(c.shape, (batch_size, 32, 32))
def testJacobian(self):
with tf.GradientTape() as g:
x = np.asarray([1., 2.])
......
......@@ -357,7 +357,10 @@ def _broadcasting_gather(x, i):
i = 0
elif static_first_dim is None:
i = array_ops.where_v2(array_ops.shape(x)[0] > 1, i, 0)
return array_ops.gather(x, i)
result = array_ops.gather(x, i)
if isinstance(x, np_arrays.ndarray):
result = np_arrays.ndarray.from_tensor(result)
return result
@tf_export("vectorized_map")
......@@ -450,7 +453,11 @@ def vectorized_map(fn, elems, fallback_to_while_loop=True):
Raises:
ValueError: If vectorization fails and fallback_to_while_loop is False.
"""
elems = nest.map_structure(ops.convert_to_tensor, elems)
def _convert_to_tensor_or_ndarray(x):
if isinstance(x, np_arrays.ndarray):
return x
return ops.convert_to_tensor(x)
elems = nest.map_structure(_convert_to_tensor_or_ndarray, elems)
def loop_fn(i):
gathered_elems = nest.map_structure(lambda x: _broadcasting_gather(x, i),
......@@ -459,9 +466,13 @@ def vectorized_map(fn, elems, fallback_to_while_loop=True):
# Extract batch size from the maximum first dimension of any element.
flat_elems = nest.flatten(elems)
static_first_dims = [elem.shape.as_list()[0]
if elem.shape.rank is not None else None
for elem in flat_elems]
def _get_shape(x):
if isinstance(x, np_arrays.ndarray):
x = x.data
if x.shape.rank is None:
return None
return x.shape.as_list()[0]
static_first_dims = [_get_shape(elem) for elem in flat_elems]
if any([s is None for s in static_first_dims]):
batch_size = math_ops.reduce_max(
[array_ops.shape(elem)[0] for elem in flat_elems])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册