提交 85f46438 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Avoid unnecessary Reshapes and Transposes in tf.tensordot.

PiperOrigin-RevId: 286304583
Change-Id: I2e3b112c8c4c0b8225fc5ade2ff48e6c83f40dfd
上级 bda1cfaa
......@@ -4133,7 +4133,14 @@ def tensordot(a, b, axes, name=None):
prod_axes = int(np.prod([shape_a[i] for i in axes]))
perm = list(axes) + free if flipped else free + list(axes)
new_shape = [prod_axes, prod_free] if flipped else [prod_free, prod_axes]
reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape)
if (perm != np.arange(len(shape_a))).any():
a_trans = array_ops.transpose(a, perm)
else:
a_trans = a
if a_trans.get_shape().as_list() != new_shape:
reshaped_a = array_ops.reshape(a_trans, new_shape)
else:
reshaped_a = a_trans
return reshaped_a, free_dims, free_dims
else:
if a.get_shape().ndims is not None and isinstance(axes, (list, tuple)):
......@@ -4209,7 +4216,12 @@ def tensordot(a, b, axes, name=None):
b, b_axes, True)
ab_matmul = matmul(a_reshape, b_reshape)
if isinstance(a_free_dims, list) and isinstance(b_free_dims, list):
return array_ops.reshape(ab_matmul, a_free_dims + b_free_dims, name=name)
if (ab_matmul.get_shape().is_fully_defined() and
ab_matmul.get_shape().as_list() == a_free_dims + b_free_dims):
return ab_matmul
else:
return array_ops.reshape(
ab_matmul, a_free_dims + b_free_dims, name=name)
else:
a_free_dims = ops.convert_to_tensor(a_free_dims, dtype=dtypes.int32)
b_free_dims = ops.convert_to_tensor(b_free_dims, dtype=dtypes.int32)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册